# Newer
# Older
# Demo-Maker / modules / posenet / converter / tfjs2pytorch.py
# @mikado-4410 mikado-4410 on 10 Oct 2024 3 KB Initial commit
import json
import struct
import cv2
import numpy as np
import os
import tempfile
import torch

from modules.posenet import MobileNetV1, MOBILENET_V1_CHECKPOINTS


BASE_DIR = os.path.join(tempfile.gettempdir(), "_posenet_weights")


def to_torch_name(tf_name):
    """Map a TensorFlow.js variable name to the matching PyTorch state-dict key.

    The name is lower-cased and split on "/". The second segment names the
    layer (its parts are separated by "_") and the third segment names the
    variable kind.

    * ``conv2d_<n>[_<sub>]`` layers become ``features.conv<n>.<sub>``, with
      ``<sub>`` defaulting to ``conv`` when the layer name has no third part.
    * ``offset``/``displacement``/``heatmap`` layers whose last part is ``2``
      keep their name minus that trailing part.
    * Anything else maps to the empty string, which callers treat as "skip".

    ``weights`` and ``depthwise_weights`` add a ``.weight`` suffix, ``biases``
    adds ``.bias``; any other variable kind adds no suffix.

    Raises:
        IndexError: if the name has fewer than three "/"-separated segments.
    """
    segments = tf_name.lower().split("/")
    layer_parts = segments[1].split("_")
    variable_kind = segments[2]

    suffix_by_kind = {
        "weights": ".weight",
        "depthwise_weights": ".weight",
        "biases": ".bias",
    }
    suffix = suffix_by_kind.get(variable_kind, "")

    layer_type = layer_parts[0]

    if layer_type == "conv2d":
        layer_index = layer_parts[1]
        sub_module = layer_parts[2] if len(layer_parts) > 2 else "conv"
        return "features.conv" + layer_index + "." + sub_module + suffix

    is_output_head = layer_type in ("offset", "displacement", "heatmap")
    if is_output_head and layer_parts[-1] == "2":
        # Drop the trailing "2" and keep the rest of the head name.
        return "_".join(layer_parts[:-1]) + suffix

    return ""


def load_variables(chkpoint, base_dir=BASE_DIR):
    """Read a downloaded TF.js checkpoint and return it as a PyTorch state dict.

    The checkpoint is expected under ``<base_dir>/<chkpoint>/`` and must
    contain a ``manifest.json`` mapping variable names to entries with a
    ``filename`` and a ``shape``. If the manifest is absent, the checkpoint is
    downloaded first via ``modules.posenet.converter.wget.download``.

    Args:
        chkpoint: Checkpoint name (sub-directory of ``base_dir``).
        base_dir: Directory holding the downloaded checkpoints.

    Returns:
        dict mapping PyTorch parameter names (see ``to_torch_name``) to
        ``torch.Tensor`` values. Variables whose name maps to "" are skipped.

    Raises:
        FileNotFoundError: if the manifest is still missing after the download.
    """
    checkpoint_dir = os.path.join(base_dir, chkpoint)
    manifest_path = os.path.join(checkpoint_dir, "manifest.json")
    if not os.path.exists(manifest_path):
        print("Weights for checkpoint %s are not downloaded. Downloading to %s ..." % (chkpoint, base_dir))
        from modules.posenet.converter.wget import download

        download(chkpoint, base_dir)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(manifest_path):
            raise FileNotFoundError(
                "Manifest %s is missing after downloading checkpoint %s" % (manifest_path, chkpoint)
            )

    # Context manager guarantees the handle is released even if parsing fails.
    with open(manifest_path) as manifest:
        variables = json.load(manifest)

    float_size = struct.calcsize("f")
    state_dict = {}
    for tf_name, entry in variables.items():
        torch_name = to_torch_name(tf_name)
        if not torch_name:
            continue

        filename = entry["filename"]
        # Close each weight file right away; the original leaked one handle
        # per variable by calling open(...).read() without closing.
        with open(os.path.join(checkpoint_dir, filename), "rb") as weight_file:
            raw = weight_file.read()

        # The file is a flat run of native-format 32-bit floats.
        values = struct.unpack("%df" % (len(raw) // float_size), raw)
        d = np.array(values, dtype=np.float32)

        shape = entry["shape"]
        if len(shape) == 4:
            # Reorder kernel axes for PyTorch; depthwise kernels use a
            # different permutation than regular ones.
            # NOTE(review): presumably TF (H, W, in, out) -> torch
            # (out, in, H, W) — confirm against the model definition.
            tpt = (2, 3, 0, 1) if "depthwise" in filename else (3, 2, 0, 1)
            d = np.reshape(d, shape).transpose(tpt)
        state_dict[torch_name] = torch.Tensor(d)

    return state_dict


def _read_imgfile(path, width, height):
    """Load an image file and prepare it as a channel-first float32 array.

    The image is read with OpenCV, resized to ``(width, height)``, converted
    from BGR to RGB, cast to float32, scaled by ``2 / 255`` and shifted by
    ``-1`` (so 0..255 maps onto -1..1), then transposed from
    (height, width, channel) to (channel, height, width).

    Args:
        path: Path of the image file to read.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        numpy.ndarray of dtype float32 with axes ordered (channel, height, width).

    Raises:
        IOError: if OpenCV cannot read or decode the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of inside cv2.resize.
        raise IOError("Could not read image file: %s" % path)
    img = cv2.resize(img, (width, height))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)
    img = img * (2.0 / 255.0) - 1.0
    img = img.transpose((2, 0, 1))
    return img


def convert(model_id, model_dir, output_stride=16, image_size=513, check=True):
    """Convert a TF.js MobileNetV1 PoseNet checkpoint into a PyTorch ``.pth`` file.

    Args:
        model_id: Key into ``MOBILENET_V1_CHECKPOINTS`` selecting the checkpoint.
        model_dir: Directory the ``.pth`` file is written to; created if missing.
        output_stride: Forwarded to the ``MobileNetV1`` constructor.
        image_size: Width and height used for the sanity-check input image.
        check: When true and ``./images/tennis_in_crowd.jpg`` exists, run one
            forward pass on that image and print heatmap diagnostics.
    """
    checkpoint_name = MOBILENET_V1_CHECKPOINTS[model_id]
    # The sanity-check image is square: both sides use image_size.
    width = image_size
    height = image_size

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Weights are read from (and, if needed, downloaded to) load_variables'
    # default base directory.
    state_dict = load_variables(checkpoint_name)
    m = MobileNetV1(model_id, output_stride=output_stride)
    m.load_state_dict(state_dict)
    # Saved as <model_dir>/<checkpoint_name>.pth
    checkpoint_path = os.path.join(model_dir, checkpoint_name) + ".pth"
    torch.save(m.state_dict(), checkpoint_path)

    # The check is silently skipped when the sample image is not present
    # relative to the current working directory.
    if check and os.path.exists("./images/tennis_in_crowd.jpg"):
        # Build a single-image batch shaped (1, 3, height, width).
        input_image = _read_imgfile("./images/tennis_in_crowd.jpg", width, height)
        input_image = np.array(input_image, dtype=np.float32)
        input_image = input_image.reshape(1, 3, height, width)
        input_image = torch.Tensor(input_image)

        # The model returns four outputs; only the heatmaps are printed below.
        heatmaps_result, offset_result, displacement_fwd_result, displacement_bwd_result = m(input_image)

        print("Heatmaps")
        print(heatmaps_result.shape)
        print(heatmaps_result[:, 0:1, 0:1])
        print(torch.mean(heatmaps_result))