# SC-SfMLearner_for_NLab / prediction / pred_pose.py
# NOTE: the original paste carried repository-browser chrome ("Newer"/"Older"
# links and commit metadata: @planck, 7 Dec 2020, "predictionの更新" = "update
# of prediction"). It is preserved here as comments so the file is valid Python.
import sys, os

sys.path.append(os.path.abspath(".."))

import torch
import cv2
import numpy as np
import pandas as pd
import argparse
import models
import yaml
from tqdm import tqdm
from glob import glob

# Command-line interface: the script needs a trained pose-network checkpoint
# and an environment YAML describing the dataset geometry.
parser = argparse.ArgumentParser(description='Structure from Motion Learner',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--weight_pth', required=True, help="path to disp model's pth file")
parser.add_argument('--path_to_yaml', default="../datasets/data_for_SC_SfMLearner/environment.yaml", help="path to environent")
# Help text (Japanese): "specify the port number if using a camera".
# -9999 is the sentinel for "no camera — process video files from ./input".
parser.add_argument('--use_camera', type=int, default=-9999, help="カメラをq使うならポート番号を指定")
args = parser.parse_args()

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset configuration: target frame size for resizing, and how many raw
# frames lie between consecutive pose predictions (frame_freq).
with open(args.path_to_yaml) as f:
    environment = yaml.safe_load(f)
r_h, r_w = environment["dataset_info"]["height"], environment["dataset_info"]["width"]
pass_num = environment["dataset_info"]["frame_freq"]

def cvmat2tensor(cvmat):
    """Convert a BGR OpenCV image to a normalized (1, 3, H, W) float tensor.

    The image is converted to RGB, laid out channel-first, scaled to [0, 1],
    then normalized with mean 0.45 / std 0.225 and moved to the module-level
    `device`.
    """
    rgb = cv2.cvtColor(cvmat, cv2.COLOR_BGR2RGB)
    chw = rgb.transpose(2, 0, 1).astype(np.float32)
    batched = torch.from_numpy(chw)[None]
    normalized = (batched / 255 - 0.45) / 0.225
    return normalized.to(device)

def run_inference_pose(cap, save_name, pose_net):
    """Predict relative camera poses for a video stream and save them as CSV.

    Frames are read from `cap`, resized to the configured (r_w, r_h), and
    every `pass_num`-th frame is paired with the previously sampled frame to
    query `pose_net`.  Each prediction is printed and the collected poses are
    written to ./output/<save_name>.csv with columns [tx, ty, tz, Rx, Ry, Rz].

    Args:
        cap: opened cv2.VideoCapture (camera or video file).
        save_name: basename (without extension) of the output CSV.
        pose_net: pose network in eval mode, called as pose_net(cur, prev);
            assumed to return a batch whose first element holds 6 values
            (translation + rotation) — TODO confirm against models.PoseResNet.

    Returns:
        None.  Side effects: an imshow window, stdout prints, a CSV file.
    """
    print("start processing => {}".format(save_name))
    ret, prev_frame = cap.read()
    if not ret:
        # Fix: the original crashed in cv2.resize when the very first read
        # failed (empty/corrupt file, unavailable camera).
        print("no frame could be read => {}".format(save_name))
        return
    prev_frame = cv2.resize(prev_frame, (r_w, r_h), interpolation=cv2.INTER_LINEAR)
    prev_tensor = cvmat2tensor(prev_frame)

    iter_num = 0
    save_poses = list()
    # Fix: inference needs no autograd; no_grad avoids accumulating graph
    # memory over a long video.
    with torch.no_grad():
        while True:
            ret, cur_frame = cap.read()

            if not ret:
                break

            # Allow the user to abort with the 'q' key.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            cur_frame = cv2.resize(cur_frame, (r_w, r_h), interpolation=cv2.INTER_LINEAR)
            # Sample one frame every pass_num frames for pose estimation.
            if iter_num % pass_num == (pass_num - 1):
                cur_tensor = cvmat2tensor(cur_frame)
                poses = pose_net(cur_tensor, prev_tensor)[0]
                poses = [item.cpu().item() for item in list(poses)]
                print(poses)
                save_poses.append(poses)
                prev_tensor = cur_tensor

            cv2.imshow("current frame", cur_frame)
            cv2.waitKey(1)
            iter_num += 1

    print("finish processing => {}".format(save_name))
    # Fix: passing columns= to the constructor also handles an empty result;
    # assigning .columns afterwards raised a length-mismatch ValueError when
    # no pose was collected.
    save_df = pd.DataFrame(save_poses, columns=["tx", "ty", "tz", "Rx", "Ry", "Rz"])
    # Fix: ensure the output directory exists before writing.
    os.makedirs("./output", exist_ok=True)
    save_df.to_csv("./output/{}.csv".format(save_name), index=False)



# Build the pose network (ResNet-18 encoder) and load the trained weights
# from the checkpoint given on the command line.
pose_net = models.PoseResNet(18, False).to(device)
# pose_net = torch.nn.DataParallel(pose_net)  # enable if the checkpoint was saved via DataParallel
weights = torch.load(args.weight_pth)
pose_net.load_state_dict(weights)
pose_net.eval()
# (Japanese) "Camera poses are displayed as [tx, ty, tz, Rx, Ry, Rz]."
print("カメラのポーズは[tx, ty, tz, Rx, Ry, Rz]のように表示されます.")

if args.use_camera != -9999:
    # Live mode: stream frames from the given camera port until 'q' is pressed.
    # (Japanese) "Press the 'q' key to terminate the program."
    print("「q」キーを押すことでプログラムが終了します.")
    cap = cv2.VideoCapture(args.use_camera)
    run_inference_pose(cap, "usb", pose_net)
else:
    # Batch mode: process every .mp4 found directly under ./input.
    target_file_list = glob(os.path.join("input", "*.*"))

    for target_file in tqdm(target_file_list):
        ext = target_file.lower().split('.')[-1]

        if ext == "mp4":
            cap = cv2.VideoCapture(target_file)
            # The output CSV is named after the video file (without extension).
            save_name = os.path.splitext(os.path.basename(target_file))[0]
            run_inference_pose(cap, save_name, pose_net)