import sys, os
sys.path.append(os.path.abspath(".."))
import torch
import cv2
import numpy as np
import pandas as pd
import argparse
import models
import yaml
from tqdm import tqdm
from glob import glob
# Command-line interface: model weights, environment config, optional camera port.
parser = argparse.ArgumentParser(
    description='Structure from Motion Learner',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--weight_pth', required=True,
                    help="path to disp model's pth file")
parser.add_argument('--path_to_yaml',
                    default="../datasets/data_for_SC_SfMLearner/environment.yaml",
                    help="path to environent")
parser.add_argument('--use_camera', type=int, default=-9999,
                    help="Specify the port number to use a camera")
args = parser.parse_args()

# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The environment YAML carries the model input size and the frame-sampling rate.
with open(args.path_to_yaml) as f:
    environment = yaml.safe_load(f)
dataset_info = environment["dataset_info"]
r_h, r_w = dataset_info["height"], dataset_info["width"]
pass_num = dataset_info["frame_freq"]
def cvmat2tensor(cvmat):
    """Convert a BGR OpenCV image into a normalized NCHW float tensor on `device`.

    The image is converted to RGB, reordered HWC -> CHW, given a batch axis,
    then scaled to [0, 1] and normalized with mean 0.45 / std 0.225.
    """
    rgb = cv2.cvtColor(cvmat, cv2.COLOR_BGR2RGB)
    chw = rgb.transpose(2, 0, 1).astype(np.float32)
    batched = torch.from_numpy(chw)[None]
    return ((batched / 255 - 0.45) / 0.225).to(device)
def run_inference_pose(cap, save_name, pose_net):
    """Estimate frame-to-frame camera poses for a video stream and save them.

    Reads frames from `cap`, resizes them to the configured model input size,
    and every `pass_num`-th frame runs `pose_net(cur, prev)` to get a 6-DoF
    pose [tx, ty, tz, Rx, Ry, Rz]. All poses are written to
    ./output/<save_name>.csv. Pressing 'q' in the preview window stops early.
    """
    print("start processing => {}".format(save_name))
    ret, prev_frame = cap.read()
    if not ret:
        # Failed capture (bad file / unavailable camera): bail out instead of
        # crashing inside cv2.resize on a None frame.
        print("could not read a frame from => {}".format(save_name))
        return
    prev_frame = cv2.resize(prev_frame, (r_w, r_h), interpolation=cv2.INTER_LINEAR)
    prev_tensor = cvmat2tensor(prev_frame)
    iter_num = 0
    save_poses = list()
    while True:
        ret, cur_frame = cap.read()
        if not ret:
            break
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        cur_frame = cv2.resize(cur_frame, (r_w, r_h), interpolation=cv2.INTER_LINEAR)
        # Only run the network every pass_num-th frame (frame skipping).
        if iter_num % pass_num == (pass_num - 1):
            cur_tensor = cvmat2tensor(cur_frame)
            # Inference only: no_grad avoids building autograd graphs.
            with torch.no_grad():
                poses = pose_net(cur_tensor, prev_tensor)[0]
            poses = [item.cpu().item() for item in list(poses)]
            print(poses)
            save_poses.append(poses)
            prev_tensor = cur_tensor
        cv2.imshow("current frame", cur_frame)
        cv2.waitKey(1)
        iter_num += 1
    print("finish processing => {}".format(save_name))
    # Passing columns= to the constructor also works when save_poses is empty,
    # unlike assigning .columns afterwards (length-mismatch ValueError).
    save_df = pd.DataFrame(save_poses, columns=["tx", "ty", "tz", "Rx", "Ry", "Rz"])
    os.makedirs("./output", exist_ok=True)
    save_df.to_csv("./output/{}.csv".format(save_name), index=False)
# Build the pose network, load the trained weights, and switch to eval mode.
pose_net = models.PoseResNet(18, False).to(device)
pose_net = torch.nn.DataParallel(pose_net)
# map_location lets weights saved on a GPU machine load on a CPU-only one.
weights = torch.load(args.weight_pth, map_location=device)
pose_net.load_state_dict(weights)
pose_net.eval()
print("カメラのポーズは[tx, ty, tz, Rx, Ry, Rz]のように表示されます.")
if args.use_camera != -9999:
    # Live camera mode: -9999 is the "no camera" sentinel default.
    print("「q」キーを押すことでプログラムが終了します.")
    cap = cv2.VideoCapture(args.use_camera)
    try:
        run_inference_pose(cap, "usb", pose_net)
    finally:
        # Release the device even if inference raises.
        cap.release()
else:
    # Batch mode: process every .mp4 file under ./input.
    target_file_list = glob(os.path.join("input", "*.*"))
    for target_file in tqdm(target_file_list):
        ext = target_file.lower().split('.')[-1]
        if ext == "mp4":
            cap = cv2.VideoCapture(target_file)
            save_name = os.path.splitext(os.path.basename(target_file))[0]
            try:
                run_inference_pose(cap, save_name, pose_net)
            finally:
                cap.release()
# Close any preview windows opened by cv2.imshow.
cv2.destroyAllWindows()