import sys, os
sys.path.append(os.path.abspath(".."))
import torch
import cv2
import numpy as np
import models
import argparse
from utils import tensor2array
from glob import glob
import os.path as osp
import yaml
from tqdm import tqdm
# --- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser(
    description='Structure from Motion Learner',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--weight_pth', required=True,
                    help="path to disp model's pth file")
# Fix: help text typo "environent" -> "environment".
parser.add_argument('--path_to_yaml',
                    default="../datasets/data_for_SC_SfMLearner/environment.yaml",
                    help="path to environment")
# -9999 is the sentinel for "no camera": process files under ./input instead.
parser.add_argument('--use_camera', type=int, default=-9999,
                    help="カメラを使うならポート番号を指定")
args = parser.parse_args()
# --- Device, model, and environment setup -----------------------------------
# Prefer GPU when available; everything downstream sends tensors to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# DispResNet(18, False): ResNet-18 encoder, pretrained flag off.
disp_net = models.DispResNet(18, False).to(device)
# disp_net = torch.nn.DataParallel(disp_net)
# Fix: map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# without it torch.load raises when CUDA deserialization is unavailable.
weights = torch.load(args.weight_pth, map_location=device)
disp_net.load_state_dict(weights)
disp_net.eval()  # inference mode: freezes batch-norm stats / disables dropout

# The yaml supplies the resolution the network expects its input resized to.
with open(args.path_to_yaml) as f:
    environment = yaml.safe_load(f)
r_h, r_w = environment["dataset_info"]["height"], environment["dataset_info"]["width"]
if args.use_camera != -9999:
    # --- Live camera mode: show source and predicted disparity side by side --
    cap = cv2.VideoCapture(args.use_camera)
    w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print("「q」キーを押すことでプログラムが終了します.")
    try:
        while True:
            ret, input_array = cap.read()
            # Fix: check `ret` BEFORE using the frame — cap.read() returns
            # (False, None) on failure and imshow(None) would crash.
            if not ret:
                break
            cv2.imshow("src", input_array)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            # Preprocess: resize to network resolution, BGR->RGB, HWC->CHW float.
            input_array = cv2.resize(input_array, (r_w, r_h), interpolation=cv2.INTER_LINEAR)
            input_array = cv2.cvtColor(input_array, cv2.COLOR_BGR2RGB)
            input_array = np.transpose(input_array, (2, 0, 1)).astype(np.float32)
            input_tensor = torch.from_numpy(input_array).unsqueeze(0)
            # Normalization constants match training (mean 0.45, std 0.225).
            input_tensor = ((input_tensor / 255 - 0.45) / 0.225).to(device)
            # Fix: no_grad — inference only; the original built an autograd
            # graph every frame for no reason, wasting memory.
            with torch.no_grad():
                output = disp_net(input_tensor)[0]
            # Visualize disparity; channel-first array from tensor2array is
            # transposed to HWC and scaled back to the capture resolution.
            disp = (255 * tensor2array(output, max_value=None, colormap='bone')).astype(np.uint8)
            disp = np.transpose(disp, (1, 2, 0))
            disp = cv2.resize(disp, (w, h), interpolation=cv2.INTER_LINEAR)
            cv2.imshow("disp", disp)
            cv2.waitKey(1)
    finally:
        # Fix: release the camera and windows even if the loop raises.
        cap.release()
        cv2.destroyAllWindows()
else:
target_files_list = glob(osp.join("input", "*.*"))
img_ext = ["jpg", "png"]
video_ext = ["mp4"]
for target_file in tqdm(target_files_list):
ext = target_file.lower().split('.')[-1]
if ext in img_ext:
input_array = cv2.imread(target_file)
h, w, _ = input_array.shape
input_array = cv2.resize(input_array, (r_w, r_h), cv2.INTER_LINEAR)
input_array = cv2.cvtColor(input_array, cv2.COLOR_BGR2RGB)
input_array = np.transpose(input_array, (2, 0, 1)).astype(np.float32)
input_tensor = torch.from_numpy(input_array).unsqueeze(0)
input_tensor = ((input_tensor/255 - 0.45)/0.225).to(device)
output = disp_net(input_tensor)[0]
disp = (255*tensor2array(output, max_value=None)).astype(np.uint8)
disp = np.transpose(disp, (1, 2, 0))
disp = cv2.resize(disp, (w, h), interpolation=cv2.INTER_LINEAR)
cv2.imwrite(osp.join("./output", osp.basename(target_file)), disp)
elif ext in video_ext:
cap = cv2.VideoCapture(target_file)
w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
writer = cv2.VideoWriter(osp.join("./output", osp.basename(target_file)), fourcc, fps, (w, h))
while True:
ret, input_array = cap.read()
if not ret:
break
input_array = cv2.resize(input_array, (r_w, r_h), interpolation=cv2.INTER_LINEAR)
input_array = cv2.cvtColor(input_array, cv2.COLOR_BGR2RGB)
input_array = np.transpose(input_array, (2, 0, 1)).astype(np.float32)
input_tensor = torch.from_numpy(input_array).unsqueeze(0)
input_tensor = ((input_tensor / 255 - 0.45) / 0.225).to(device)
output = disp_net(input_tensor)[0]
disp = (255 * tensor2array(output, max_value=None, colormap='bone')).astype(np.uint8)
disp = np.transpose(disp, (1, 2, 0))
disp = cv2.resize(disp, (w, h), interpolation=cv2.INTER_LINEAR)
writer.write(cv2.cvtColor(disp, cv2.COLOR_GRAY2BGR))
cap.release()
writer.release()