diff --git a/datasets/make_datasets.py b/datasets/make_datasets.py index c5b63b0..1194d9a 100644 --- a/datasets/make_datasets.py +++ b/datasets/make_datasets.py @@ -137,7 +137,8 @@ with open(osp.join(options.out_dir, "environment.yaml"), "w") as f: dataset_info = {"height": options.save_height, - "width": options.save_width} + "width": options.save_width, + "frame_freq": options.save_frequency} camera_info = {"intrinsic": intrinsic} environment = {"dataset_info": dataset_info, "camera_info": camera_info} diff --git a/prediction/pred_disp.py b/prediction/pred_disp.py index 443efd2..f091123 100644 --- a/prediction/pred_disp.py +++ b/prediction/pred_disp.py @@ -23,7 +23,7 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") disp_net = models.DispResNet(18, False).to(device) -disp_net = torch.nn.DataParallel(disp_net) +# disp_net = torch.nn.DataParallel(disp_net) weights = torch.load(args.weight_pth) disp_net.load_state_dict(weights) disp_net.eval() diff --git a/prediction/pred_pose.py b/prediction/pred_pose.py index e170828..f6aa77a 100644 --- a/prediction/pred_pose.py +++ b/prediction/pred_pose.py @@ -8,11 +8,88 @@ import pandas as pd import argparse import models +import yaml +from tqdm import tqdm +from glob import glob parser = argparse.ArgumentParser(description='Structure from Motion Learner', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--weight_pth', required=True, help="path to disp model's pth file") parser.add_argument('--path_to_yaml', default="../datasets/data_for_SC_SfMLearner/environment.yaml", help="path to environent") -parser.add_argument('--use_camera', type=int, default=-9999, help="カメラを使うならポート番号を指定") +parser.add_argument('--use_camera', type=int, default=-9999, help="カメラを使うならポート番号を指定") args = parser.parse_args() + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +with open(args.path_to_yaml) as f: + environment = yaml.safe_load(f) +r_h, r_w = 
environment["dataset_info"]["height"], environment["dataset_info"]["width"] +pass_num = environment["dataset_info"]["frame_freq"] + +def cvmat2tensor(cvmat): + cvmat = cv2.cvtColor(cvmat, cv2.COLOR_BGR2RGB) + cvmat = np.transpose(cvmat, (2, 0, 1)).astype(np.float32) + input_tensor = torch.from_numpy(cvmat).unsqueeze(0) + input_tensor = ((input_tensor / 255 - 0.45) / 0.225).to(device) + + return input_tensor + +def run_inference_pose(cap, save_name, pose_net): + print("start processing => {}".format(save_name)) + ret, prev_frame = cap.read() + prev_frame = cv2.resize(prev_frame, (r_w, r_h), interpolation=cv2.INTER_LINEAR) + prev_tensor = cvmat2tensor(prev_frame) + + iter_num = 0 + save_poses = list() + while True: + ret, cur_frame = cap.read() + + if not ret: + break + + if cv2.waitKey(1) & 0xFF == ord('q'): + break + + cur_frame = cv2.resize(cur_frame, (r_w, r_h), interpolation=cv2.INTER_LINEAR) + if iter_num % pass_num == (pass_num - 1): + cur_tensor = cvmat2tensor(cur_frame) + poses = pose_net(cur_tensor, prev_tensor)[0] + poses = [item.cpu().item() for item in list(poses)] + print(poses) + save_poses.append(poses) + prev_tensor = cur_tensor + + cv2.imshow("current frame", cur_frame) + cv2.waitKey(1) + iter_num += 1 + + print("finish processing => {}".format(save_name)) + save_df = pd.DataFrame(save_poses) + save_df.columns = ["tx", "ty", "tz", "Rx", "Ry", "Rz"] + save_df.to_csv("./output/{}.csv".format(save_name), index=False) + + + +pose_net = models.PoseResNet(18, False).to(device) +pose_net = torch.nn.DataParallel(pose_net) +weights = torch.load(args.weight_pth) +pose_net.load_state_dict(weights) +pose_net.eval() +print("カメラのポーズは[tx, ty, tz, Rx, Ry, Rz]のように表示されます.") + +if args.use_camera != -9999: + print("「q」キーを押すことでプログラムが終了します.") + cap = cv2.VideoCapture(args.use_camera) + run_inference_pose(cap, "usb", pose_net) +else: + target_file_list = glob(os.path.join("input", "*.*")) + + for target_file in tqdm(target_file_list): + ext = 
target_file.lower().split('.')[-1] + + if ext == "mp4": + cap = cv2.VideoCapture(target_file) + save_name = os.path.splitext(os.path.basename(target_file))[0] + run_inference_pose(cap, save_name, pose_net) diff --git a/requirements.txt b/requirements.txt index d81625f..67684cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ PyYAML==5.3.1 tensorboardX==2.1 +tensorboard==2.3.0 opencv-python-headless==4.4.0.40 path==15.0.0 matplotlib==3.3.1 diff --git a/train.py b/train.py index 8b4c38c..5cf6bb4 100644 --- a/train.py +++ b/train.py @@ -85,6 +85,7 @@ environment = yaml.safe_load(f) timestamp = Path(datetime.datetime.now().strftime("%m-%d-%H-%M")) + os.makedirs("./checkpoints", exist_ok=True) args.save_path = 'checkpoints' / timestamp print('=> will save everything to {}'.format(args.save_path)) args.save_path.makedirs_p()