# Source: MyEndoSfMLearner / prediction / predict_pose.py
# Last commit: @planck on 17 Nov 2020 (2 KB) - "initial commit"
import torch
import cv2

import torch
import cv2
import numpy as np
import networks
import pandas as pd

from PIL import Image
from torchvision import transforms

# --- Run configuration ---------------------------------------------------
# Frame sampling stride: the main loop handles one frame out of every
# `pass_num` frames it reads from the video.
pass_num = 30

# Video file opened with cv2.VideoCapture.
video_path = r"D:\Deep_Learning\MonoDepth2\esophagus\movies\trimed\0.mp4"

# Checkpoint path template; "{}" is filled with the network name
# ("pose_encoder" / "pose") when the weights are loaded.
pth_base_path = r"D:\Deep_Learning\SC_SfMLearner\esophagus\log\mdp(SC-SfMLearner_2回目)\models\weights_99\{}.pth"

# Output CSV path template; "{}" is filled with a running index, one file
# per predicted pose.
out_path = r"D:\Deep_Learning\SC_SfMLearner\esophagus\pred_poses\pred_pose{}.csv"

# Open the input video. VideoCapture does not raise when the file is missing
# or unreadable, so check explicitly instead of failing later on a None frame.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise IOError("Could not open video: {}".format(video_path))

# Camera matrix and distortion coefficients passed to cv2.undistort by the
# preprocessing helpers. Paths are relative to the current working directory.
# NOTE(review): presumably the intrinsics are scaled to the resized 480x352
# frame -- confirm against how the .npy files were produced.
scaled_intrinsic = np.load("./params/intrinsics_scaled.npy")
dist_coeffs = np.load("./params/dist_coeffs.npy")

# Prefer the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pose network: an encoder plus a decoder that is sized from the encoder's
# per-stage channel counts (`num_ch_enc`).
models = {"pose_encoder": networks.ResnetEncoder(18, False, 2)}
models["pose"] = networks.PoseDecoder(models["pose_encoder"].num_ch_enc, 1, 2)

# Move both networks onto the selected device.
for _net in models.values():
    _net.to(device)

# Restore the saved weights of both networks and switch them to eval mode.
for n in ["pose_encoder", "pose"]:
    print("Loading {} weights...".format(n))
    path = pth_base_path.format(n)
    model_dict = models[n].state_dict()
    # map_location keeps the CPU fallback usable: without it a checkpoint
    # saved from CUDA tensors cannot be deserialized on a machine with no GPU.
    pretrained_dict = torch.load(path, map_location=device)
    # Drop checkpoint entries whose keys the model does not have, then
    # overlay the remaining ones on the model's current state.
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    models[n].load_state_dict(model_dict)
    models[n].eval()

def trim_endo(eso_img):
    """Crop, resize and undistort a raw video frame.

    The frame is cropped to rows 32..988 and columns 323..1598, resized to
    480x352 (width x height) with bilinear interpolation, and undistorted
    using the module-level ``scaled_intrinsic`` and ``dist_coeffs``.

    Args:
        eso_img: Image array indexed as [row, column, channel].

    Returns:
        The processed image array.
    """
    cropped = eso_img[32:989, 323:1599, :]
    resized = cv2.resize(cropped, (480, 352), interpolation=cv2.INTER_LINEAR)
    return cv2.undistort(resized, scaled_intrinsic, dist_coeffs)

def to_endo_tensor(eso_img):
    """Convert a raw video frame into a batched input tensor on ``device``.

    The frame is preprocessed by ``trim_endo`` (reused here rather than
    repeating its crop/resize/undistort steps, so the displayed image and the
    network input always share one pipeline), converted from BGR to RGB,
    passed through ``transforms.ToTensor`` and given a leading batch
    dimension.

    Args:
        eso_img: Raw frame as read from the video, indexed [row, column, channel].

    Returns:
        The tensor produced by ``ToTensor`` with an extra dimension 0,
        moved to ``device``.
    """
    rgb_img = cv2.cvtColor(trim_endo(eso_img), cv2.COLOR_BGR2RGB)
    tensor = transforms.ToTensor()(Image.fromarray(rgb_img))
    return tensor.unsqueeze(0).to(device)

# Predict the pose between consecutive sampled frames and write every
# prediction to its own CSV file. The try/finally guarantees that the video
# handle and the preview windows are released even if a step fails.
try:
    ret, frame = cap.read()
    if not ret:
        # A failed read leaves `frame` as None; stop with a clear message
        # instead of crashing inside the preprocessing.
        raise RuntimeError(
            "Could not read the first frame from: {}".format(video_path)
        )
    prev_frame = trim_endo(frame)
    prev_tensor = to_endo_tensor(frame)

    csv_num = 0   # index of the next output CSV
    iter_num = 0  # number of frames read after the first one
    while True:
        ret, frame = cap.read()

        if not ret:
            break

        # Only every `pass_num`-th frame is compared with the previous
        # sampled frame; the frames in between are skipped.
        if iter_num % pass_num == (pass_num - 1):
            cur_frame = trim_endo(frame)
            cv2.imshow("cur_frame", cur_frame)
            cv2.imshow("prev_frame", prev_frame)
            cur_tensor = to_endo_tensor(frame)
            # Inference only: no autograd graph needs to be recorded.
            with torch.no_grad():
                # The two frames are concatenated along dimension 1 before
                # being fed to the encoder.
                pose_features = models["pose_encoder"](torch.cat((prev_tensor, cur_tensor), 1))
                poses = models["pose"]([pose_features])[0]
            poses = [item.cpu().item() for item in list(poses)]
            print(poses)
            # Written without a header row; the index column is still
            # emitted (pandas default).
            pd.DataFrame([poses]).to_csv(out_path.format(csv_num), header=None, columns=None)
            csv_num += 1

            prev_tensor = cur_tensor
            prev_frame = cur_frame
            # Gives the GUI a chance to refresh the preview windows.
            cv2.waitKey(1)
        iter_num += 1
finally:
    cap.release()
    cv2.destroyAllWindows()