diff --git a/.gitignore b/.gitignore index 46ee510..e446405 100644 --- a/.gitignore +++ b/.gitignore @@ -3,8 +3,6 @@ checkpoints/* datasets/data_for_SC_SfMLearner datasets/data_for_SC_SfMLearner/* -datasets/train_videos/* -datasets/val_videos/* datasets/__pycache__/* datasets_module/__pycache__/* models/__pycache__/* diff --git a/datasets/train_videos/example.mp4 b/datasets/train_videos/example.mp4 new file mode 100644 index 0000000..f8962f6 --- /dev/null +++ b/datasets/train_videos/example.mp4 Binary files differ diff --git a/datasets/val_videos/example.mp4 b/datasets/val_videos/example.mp4 new file mode 100644 index 0000000..f8962f6 --- /dev/null +++ b/datasets/val_videos/example.mp4 Binary files differ diff --git a/datasets_module/sequence_folders.py b/datasets_module/sequence_folders.py index dae3498..8303b54 100644 --- a/datasets_module/sequence_folders.py +++ b/datasets_module/sequence_folders.py @@ -22,7 +22,7 @@ transform functions must take in a list a images and a numpy array (usually intrinsics matrix) """ - def __init__(self, root, seed=None, train=True, sequence_length=3, transform=None, skip_frames=1, dataset='kitti'): + def __init__(self, root, seed=None, train=True, sequence_length=3, transform=None, skip_frames=1, dataset='kitti', environment=None): np.random.seed(seed) random.seed(seed) self.root = Path(root) @@ -32,8 +32,10 @@ self.transform = transform self.dataset = dataset self.k = skip_frames + self.environment = environment self.crawl_folders(sequence_length) + try: self.brightness = (0.8, 1.2) self.contrast = (0.8, 1.2) @@ -54,7 +56,7 @@ shifts = list(range(-demi_length * self.k, demi_length * self.k + 1, self.k)) shifts.pop(demi_length) for scene in self.scenes: - intrinsics = np.array([[490.135, 0, 142.842], [0, 492.224, 253.201], [0, 0, 1]]).astype(np.float32) + intrinsics = np.array(self.environment["camera_info"]["intrinsic"]).astype(np.float32) # intrinsics = np.genfromtxt(scene/'cam.txt').astype(np.float32).reshape((3, 3)) imgs = 
sorted(scene.files('*.jpg')) diff --git a/prediction/input/00001614.jpg b/prediction/input/00001614.jpg new file mode 100644 index 0000000..628e0aa --- /dev/null +++ b/prediction/input/00001614.jpg Binary files differ diff --git a/prediction/input/test.mp4 b/prediction/input/test.mp4 new file mode 100644 index 0000000..f8962f6 --- /dev/null +++ b/prediction/input/test.mp4 Binary files differ diff --git a/prediction/output/00001614.jpg b/prediction/output/00001614.jpg new file mode 100644 index 0000000..516ea0b --- /dev/null +++ b/prediction/output/00001614.jpg Binary files differ diff --git a/prediction/output/test.mp4 b/prediction/output/test.mp4 new file mode 100644 index 0000000..575fc79 --- /dev/null +++ b/prediction/output/test.mp4 Binary files differ diff --git a/prediction/pred_disp.py b/prediction/pred_disp.py new file mode 100644 index 0000000..443efd2 --- /dev/null +++ b/prediction/pred_disp.py @@ -0,0 +1,121 @@ +import sys, os + +sys.path.append(os.path.abspath("..")) + +import torch +import cv2 +import numpy as np +import models +import argparse +from utils import tensor2array +from glob import glob +import os.path as osp +import yaml +from tqdm import tqdm + +parser = argparse.ArgumentParser(description='Structure from Motion Learner', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +parser.add_argument('--weight_pth', required=True, help="path to disp model's pth file") +parser.add_argument('--path_to_yaml', default="../datasets/data_for_SC_SfMLearner/environment.yaml", help="path to environment") +parser.add_argument('--use_camera', type=int, default=-9999, help="カメラを使うならポート番号を指定") +args = parser.parse_args() + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +disp_net = models.DispResNet(18, False).to(device) +disp_net = torch.nn.DataParallel(disp_net) +weights = torch.load(args.weight_pth) +disp_net.load_state_dict(weights) +disp_net.eval() +
+with open(args.path_to_yaml) as f: + environment = 
yaml.safe_load(f) +r_h, r_w = environment["dataset_info"]["height"], environment["dataset_info"]["width"] + +if args.use_camera != -9999: + cap = cv2.VideoCapture(args.use_camera) + w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + print("「q」キーを押すことでプログラムが終了します.") + + while True: + ret, input_array = cap.read() + + if not ret: + break + cv2.imshow("src", input_array) + + if cv2.waitKey(1) & 0xFF == ord('q'): + break + + input_array = cv2.resize(input_array, (r_w, r_h), interpolation=cv2.INTER_LINEAR) + input_array = cv2.cvtColor(input_array, cv2.COLOR_BGR2RGB) + input_array = np.transpose(input_array, (2, 0, 1)).astype(np.float32) + + input_tensor = torch.from_numpy(input_array).unsqueeze(0) + input_tensor = ((input_tensor / 255 - 0.45) / 0.225).to(device) + + output = disp_net(input_tensor)[0] + + disp = (255 * tensor2array(output, max_value=None, colormap='bone')).astype(np.uint8) + disp = np.transpose(disp, (1, 2, 0)) + disp = cv2.resize(disp, (w, h), interpolation=cv2.INTER_LINEAR) + cv2.imshow("disp", disp) + cv2.waitKey(1) + + cap.release() + +else: + target_files_list = glob(osp.join("input", "*.*")) + img_ext = ["jpg", "png"] + video_ext = ["mp4"] + + for target_file in tqdm(target_files_list): + ext = target_file.lower().split('.')[-1] + + if ext in img_ext: + input_array = cv2.imread(target_file) + h, w, _ = input_array.shape + input_array = cv2.resize(input_array, (r_w, r_h), interpolation=cv2.INTER_LINEAR) + input_array = cv2.cvtColor(input_array, cv2.COLOR_BGR2RGB) + input_array = np.transpose(input_array, (2, 0, 1)).astype(np.float32) + + input_tensor = torch.from_numpy(input_array).unsqueeze(0) + input_tensor = ((input_tensor/255 - 0.45)/0.225).to(device) + + output = disp_net(input_tensor)[0] + + disp = (255*tensor2array(output, max_value=None)).astype(np.uint8) + disp = np.transpose(disp, (1, 2, 0)) + disp = cv2.resize(disp, (w, h), interpolation=cv2.INTER_LINEAR) + cv2.imwrite(osp.join("./output", 
osp.basename(target_file)), disp) + + elif ext in video_ext: + cap = cv2.VideoCapture(target_file) + + w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) + fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') + writer = cv2.VideoWriter(osp.join("./output", osp.basename(target_file)), fourcc, fps, (w, h)) + + while True: + ret, input_array = cap.read() + + if not ret: + break + + input_array = cv2.resize(input_array, (r_w, r_h), interpolation=cv2.INTER_LINEAR) + input_array = cv2.cvtColor(input_array, cv2.COLOR_BGR2RGB) + input_array = np.transpose(input_array, (2, 0, 1)).astype(np.float32) + + input_tensor = torch.from_numpy(input_array).unsqueeze(0) + input_tensor = ((input_tensor / 255 - 0.45) / 0.225).to(device) + + output = disp_net(input_tensor)[0] + + disp = (255 * tensor2array(output, max_value=None, colormap='bone')).astype(np.uint8) + disp = np.transpose(disp, (1, 2, 0)) + disp = cv2.resize(disp, (w, h), interpolation=cv2.INTER_LINEAR) + writer.write(cv2.cvtColor(disp, cv2.COLOR_GRAY2BGR)) + + cap.release() + writer.release() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d81625f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +PyYAML==5.3.1 +tensorboardX==2.1 +opencv-python-headless==4.4.0.40 +path==15.0.0 +matplotlib==3.3.1 +imageio==2.9.0 +tqdm==4.48.2 diff --git a/train.py b/train.py index fa84a95..8b4c38c 100644 --- a/train.py +++ b/train.py @@ -20,6 +20,7 @@ from loss_functions import compute_smooth_loss, compute_photo_and_geometry_loss, compute_errors from tensorboardX import SummaryWriter import os +import yaml parser = argparse.ArgumentParser(description='Structure from Motion Learner training on KITTI and CityScapes Dataset', formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -80,6 +81,8 @@ def main(): global best_error, n_iter, device args = parser.parse_args() + with open(os.path.join(args.data, "environment.yaml")) as f: + 
environment = yaml.safe_load(f) timestamp = Path(datetime.datetime.now().strftime("%m-%d-%H-%M")) args.save_path = 'checkpoints' / timestamp @@ -111,6 +114,7 @@ valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize]) + print("=> fetching scenes in '{}'".format(args.data)) if args.folder_type == 'sequence': train_set = SequenceFolder( @@ -119,7 +123,8 @@ seed=args.seed, train=True, sequence_length=args.sequence_length, - dataset=args.dataset + dataset=args.dataset, + environment=environment ) else: train_set = PairFolder( @@ -144,7 +149,8 @@ seed=args.seed, train=False, sequence_length=args.sequence_length, - dataset=args.dataset + dataset=args.dataset, + environment=environment ) print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes))) print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))