Newer
Older
Demo-Maker / modules / posenet / models / model_factory.py
@mikado-4410 mikado-4410 on 22 Nov 2024 756 bytes [update]LightGBMの読み込みを修正
import os

import torch

from modules.posenet.models.mobilenet_v1 import MOBILENET_V1_CHECKPOINTS, MobileNetV1

# Default directory for PoseNet ``.pth`` checkpoints. It is load_model()'s default
# ``model_dir``: checkpoints are read from here and, when missing, the tfjs converter
# is given this directory and is expected to write the checkpoint into it.
# Relative path: resolved against the current working directory, not this file.
MODEL_DIR: str = "./models/posenet"
# Not referenced in this section of the file; presumably a switch for debug output
# elsewhere in the posenet module — TODO confirm before relying on it.
DEBUG_OUTPUT: bool = False


def load_model(model_id, output_stride=16, model_dir=MODEL_DIR):
    """Build a MobileNetV1 PoseNet model and load its PyTorch weights.

    The checkpoint path is ``<model_dir>/<MOBILENET_V1_CHECKPOINTS[model_id]>.pth``.
    If that file does not exist yet, it is generated by the tfjs -> PyTorch
    converter first, then loaded.

    Args:
        model_id: Key into ``MOBILENET_V1_CHECKPOINTS``; also forwarded to
            ``MobileNetV1`` and to the converter.
        output_stride: Output stride forwarded to ``MobileNetV1``.
        model_dir: Directory that holds (or will receive) the ``.pth`` checkpoint.

    Returns:
        The ``MobileNetV1`` instance with the checkpoint's state dict loaded.

    Raises:
        KeyError: If ``model_id`` is not a key of ``MOBILENET_V1_CHECKPOINTS``.
        FileNotFoundError: If the checkpoint is still missing after conversion.
    """
    model_path = os.path.join(model_dir, MOBILENET_V1_CHECKPOINTS[model_id] + ".pth")
    if not os.path.exists(model_path):
        print("Cannot find models file %s, converting from tfjs..." % model_path)
        # Imported lazily: the converter is only needed when the checkpoint is missing.
        from modules.posenet.converter.tfjs2pytorch import convert

        convert(model_id, model_dir, check=False)
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                "Conversion finished but checkpoint %s was not created" % model_path
            )

    model = MobileNetV1(model_id, output_stride=output_stride)
    # map_location="cpu" lets a checkpoint that was saved with CUDA tensors load on
    # a machine without a GPU. load_state_dict copies the values into the model's
    # own parameters, so where the model's parameters live is not affected.
    load_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(load_dict)

    return model