import json
import struct
import cv2
import numpy as np
import os
import tempfile
import torch
from modules.posenet import MobileNetV1, MOBILENET_V1_CHECKPOINTS
# Default directory for checkpoint files: a "_posenet_weights" folder inside
# the system temporary directory. Used as the default `base_dir` of
# `load_variables`.
BASE_DIR = os.path.join(tempfile.gettempdir(), "_posenet_weights")
def to_torch_name(tf_name):
    """Map a TensorFlow checkpoint variable name to a PyTorch state-dict key.

    The name is lower-cased and split on "/" into
    ``<scope>/<layer>/<variable type>``; the layer segment is split again
    on "_".

    * ``conv2d_<n>`` maps to ``features.conv<n>.conv`` and
      ``conv2d_<n>_<kind>`` maps to ``features.conv<n>.<kind>``.
    * Layers starting with ``offset``, ``displacement`` or ``heatmap`` whose
      last "_"-separated token is ``2`` keep their name minus that token.
    * A variable type of ``weights`` or ``depthwise_weights`` appends
      ``.weight``; ``biases`` appends ``.bias``; anything else appends
      nothing.

    Args:
        tf_name: Variable name as listed in the checkpoint manifest.

    Returns:
        The matching state-dict key, or ``""`` when the variable has no
        counterpart. Names that do not have at least three "/"-separated
        segments, or a ``conv2d`` layer without an index, also yield ``""``
        instead of raising ``IndexError``.
    """
    parts = tf_name.lower().split("/")
    if len(parts) < 3:
        # Not of the form <scope>/<layer>/<variable type>: nothing to map.
        return ""

    layer_tokens = parts[1].split("_")
    variable_type = parts[2]

    if variable_type in ("weights", "depthwise_weights"):
        suffix = ".weight"
    elif variable_type == "biases":
        suffix = ".bias"
    else:
        suffix = ""

    layer_kind = layer_tokens[0]
    if layer_kind == "conv2d":
        if len(layer_tokens) < 2:
            # A conv2d layer without an index cannot be placed in `features`.
            return ""
        sub_module = layer_tokens[2] if len(layer_tokens) > 2 else "conv"
        return "features.conv" + layer_tokens[1] + "." + sub_module + suffix

    if layer_kind in ("offset", "displacement", "heatmap") and layer_tokens[-1] == "2":
        # Drop the trailing "2" token, e.g. displacement_fwd_2 -> displacement_fwd.
        return "_".join(layer_tokens[:-1]) + suffix

    return ""
def load_variables(chkpoint, base_dir=BASE_DIR):
    """Load a downloaded checkpoint into a PyTorch-style state dict.

    Reads ``<base_dir>/<chkpoint>/manifest.json``; if it is missing, the
    checkpoint is downloaded first. Every manifest entry whose name maps to
    a non-empty key via ``to_torch_name`` is read from its raw float32 file
    and converted to a ``torch.Tensor``.

    Args:
        chkpoint: Checkpoint name; also the sub-directory of ``base_dir``
            that holds the manifest and weight files.
        base_dir: Directory containing (or receiving) the checkpoint folders.

    Returns:
        dict mapping state-dict keys to ``torch.Tensor`` values.

    Raises:
        FileNotFoundError: If the manifest is still missing after the
            download attempt.
    """
    checkpoint_dir = os.path.join(base_dir, chkpoint)
    manifest_path = os.path.join(checkpoint_dir, "manifest.json")
    if not os.path.exists(manifest_path):
        print("Weights for checkpoint %s are not downloaded. Downloading to %s ..." % (chkpoint, base_dir))
        # Imported lazily so the downloader is only needed on a cache miss.
        from modules.posenet.converter.wget import download
        download(chkpoint, base_dir)
        # Explicit raise rather than `assert`, which is stripped under -O.
        if not os.path.exists(manifest_path):
            raise FileNotFoundError(
                "Manifest %s is missing after downloading checkpoint %s" % (manifest_path, chkpoint)
            )

    with open(manifest_path) as manifest:
        variables = json.load(manifest)

    state_dict = {}
    for tf_name, entry in variables.items():
        torch_name = to_torch_name(tf_name)
        if not torch_name:
            # Variable has no counterpart in the PyTorch model.
            continue

        filename = entry["filename"]
        with open(os.path.join(checkpoint_dir, filename), "rb") as weight_file:
            raw = weight_file.read()

        # Interpret the bytes directly as native-order float32 values; the
        # copy yields a writable array that owns its memory.
        values = np.frombuffer(raw, dtype=np.float32).copy()

        shape = entry["shape"]
        if len(shape) == 4:
            # Reorder 4-D kernels so the last two stored axes come first.
            # NOTE(review): presumably TensorFlow (H, W, in, out) ->
            # PyTorch (out, in, H, W), with in/out swapped for depthwise
            # kernels -- confirm against the model definition.
            axes = (2, 3, 0, 1) if "depthwise" in filename else (3, 2, 0, 1)
            values = np.reshape(values, shape).transpose(axes)

        state_dict[torch_name] = torch.Tensor(values)

    return state_dict
def _read_imgfile(path, width, height):
    """Read an image file into a normalised, channel-first float32 array.

    The image is resized to ``(width, height)``, converted from BGR to RGB,
    cast to float32, rescaled with ``x * 2 / 255 - 1`` (so 0 maps to -1 and
    255 maps to 1) and transposed from (row, column, channel) to
    (channel, row, column) order.

    Args:
        path: Path of the image file to read.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        The processed ``numpy.ndarray`` of dtype float32.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside cv2.resize.
        raise OSError("Could not read image file: %s" % path)

    img = cv2.resize(img, (width, height))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)
    img = img * (2.0 / 255.0) - 1.0
    return img.transpose((2, 0, 1))
def convert(model_id, model_dir, output_stride=16, image_size=513, check=True):
    """Convert a MobileNetV1 PoseNet checkpoint and save it as a ``.pth`` file.

    Looks up the checkpoint name for ``model_id``, loads its variables,
    loads them into a ``MobileNetV1`` instance and saves that model's state
    dict to ``<model_dir>/<checkpoint name>.pth``. When ``check`` is true
    and the sample image ``./images/tennis_in_crowd.jpg`` exists, the model
    is run once on it and summary information about the heatmap output is
    printed.

    Args:
        model_id: Key into ``MOBILENET_V1_CHECKPOINTS``; also passed to the
            ``MobileNetV1`` constructor.
        model_dir: Output directory; created if it does not exist.
        output_stride: Forwarded to the ``MobileNetV1`` constructor.
        image_size: Width and height the sample image is resized to.
        check: Whether to run the sanity-check forward pass.
    """
    checkpoint_name = MOBILENET_V1_CHECKPOINTS[model_id]

    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(model_dir, exist_ok=True)

    state_dict = load_variables(checkpoint_name)
    m = MobileNetV1(model_id, output_stride=output_stride)
    m.load_state_dict(state_dict)
    checkpoint_path = os.path.join(model_dir, checkpoint_name) + ".pth"
    torch.save(m.state_dict(), checkpoint_path)

    sample_path = "./images/tennis_in_crowd.jpg"
    if not (check and os.path.exists(sample_path)):
        return

    # Result
    width = height = image_size
    input_image = _read_imgfile(sample_path, width, height)
    # Add the leading batch axis; also fails loudly if the image is not
    # 3-channel.
    input_image = torch.Tensor(input_image.reshape(1, 3, height, width))

    # This is only a smoke test, so skip autograd bookkeeping.
    with torch.no_grad():
        # The model returns four outputs; only the heatmaps are reported.
        heatmaps_result, _offset_result, _displacement_fwd_result, _displacement_bwd_result = m(input_image)

    print("Heatmaps")
    print(heatmaps_result.shape)
    print(heatmaps_result[:, 0:1, 0:1])
    print(torch.mean(heatmaps_result))