diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..46ee510 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.idea/* +__pycache__/* +checkpoints/* +datasets/data_for_SC_SfMLearner +datasets/data_for_SC_SfMLearner/* +datasets/train_videos/* +datasets/val_videos/* +datasets/__pycache__/* +datasets_module/__pycache__/* +models/__pycache__/* diff --git a/custom_transforms.py b/custom_transforms.py new file mode 100644 index 0000000..9c9a9c5 --- /dev/null +++ b/custom_transforms.py @@ -0,0 +1,84 @@ +from __future__ import division +import torch +import random +import numpy as np +from PIL import Image + +'''Set of tranform random routines that takes list of inputs as arguments, +in order to have random but coherent transformations.''' + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, images, intrinsics): + for t in self.transforms: + images, intrinsics = t(images, intrinsics) + return images, intrinsics + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, images, intrinsics): + for tensor in images: + for t, m, s in zip(tensor, self.mean, self.std): + t.sub_(m).div_(s) + return images, intrinsics + + +class ArrayToTensor(object): + """Converts a list of numpy.ndarray (H x W x C) along with a intrinsics matrix to a list of torch.FloatTensor of shape (C x H x W) with a intrinsics tensor.""" + + def __call__(self, images, intrinsics): + tensors = [] + for im in images: + # put it from HWC to CHW format + im = np.transpose(im, (2, 0, 1)) + # handle numpy array + tensors.append(torch.from_numpy(im).float()/255) + return tensors, intrinsics + + +class RandomHorizontalFlip(object): + """Randomly horizontally flips the given numpy array with a probability of 0.5""" + + def __call__(self, images, intrinsics): + assert intrinsics is not None + if random.random() < 0.5: + output_intrinsics = np.copy(intrinsics) + output_images = [np.copy(np.fliplr(im)) for im in images] + w = output_images[0].shape[1] + output_intrinsics[0, 2] = w - output_intrinsics[0, 2] + else: + output_images = images + output_intrinsics = intrinsics + return output_images, output_intrinsics + + +class RandomScaleCrop(object): + """Randomly zooms images up to 15% and crop them to keep same size as before.""" + + def __call__(self, images, intrinsics): + assert intrinsics is not None + output_intrinsics = np.copy(intrinsics) + + in_h, in_w, _ = images[0].shape + x_scaling, y_scaling = np.random.uniform(1, 1.15, 2) + scaled_h, scaled_w = int(in_h * y_scaling), int(in_w * x_scaling) + + output_intrinsics[0] *= x_scaling + output_intrinsics[1] *= y_scaling + scaled_images = [np.array(Image.fromarray(im.astype(np.uint8)).resize((scaled_w, scaled_h))).astype(np.float32) for im in images] + + offset_y = np.random.randint(scaled_h - in_h + 1) + offset_x = np.random.randint(scaled_w - in_w + 1) + cropped_images = [im[offset_y:offset_y + in_h, offset_x:offset_x + in_w] for im in scaled_images] + + output_intrinsics[0, 2] -= offset_x + output_intrinsics[1, 2] -= offset_y + + return cropped_images, output_intrinsics diff --git a/datasets/make_datasets.py b/datasets/make_datasets.py new file mode 100644 index 0000000..c5b63b0 --- /dev/null +++ b/datasets/make_datasets.py @@ -0,0 +1,149 @@ +import numpy as np +import cv2 +import os +import os.path as osp +import argparse +import tkinter as tk +import yaml +from glob import glob + +file_dir = os.path.dirname(__file__) + +parser = argparse.ArgumentParser() + +parser.add_argument("--out_dir", + type=str, + help="データセットの出力先", + default=osp.join(file_dir, "data_for_SC_SfMLearner")) + +parser.add_argument("--save_frequency", + type=int, + help="動画からどれくらいの周期でフレームを画像として保存するかの指定", + default=10) + +parser.add_argument("--no_make_val", + help="評価用データを作成するか否か", + action="store_true") + +parser.add_argument("--save_height", + type=int, + help="画像データセットに変換する際のリサイズ後の画像の高さ.32の倍数でないといけない", + default=512) + +parser.add_argument("--save_width", + type=int, + help="画像データセットに変換する際のリサイズ後の画像の幅.32の倍数でないといけない", + default=288) + +options = parser.parse_args() + +K = [] +root = tk.Tk() +root.geometry("250x250") +root.title("monodepth2 dataset GUI") + +entry_boxs = {} +label1 = tk.Label(text="Please input your camera's intrinsics") +label1.place(x=30, y=20) + +entry_num = 0 +init_x, init_y = 55, 60 +offset_x, offset_y = 50, 30 +for col in range(3): + for row in range(3): + cur_key = "entry{}".format(entry_num) + entry_boxs[cur_key] = tk.Entry(width=7) + entry_boxs[cur_key].place(x=(init_x + row * offset_x), y=(init_y + col * offset_y)) + if entry_num in [1, 3, 6, 7]: + entry_boxs[cur_key].insert(tk.END, "0") + entry_num += 1 +entry_boxs["entry0"].insert(tk.END, "f_x") +entry_boxs["entry2"].insert(tk.END, "c_x") +entry_boxs["entry4"].insert(tk.END, "f_y") +entry_boxs["entry5"].insert(tk.END, "c_y") +entry_boxs["entry8"].insert(tk.END, "1") + + +def end_tk_process(): + global root + global K + K = [float(entry_boxs[key].get()) for key in entry_boxs] + root.destroy() + + +ok_button = tk.Button(text="finish", command=end_tk_process) +ok_button.place(x=100, y=180) + +resized_calb_bin = tk.BooleanVar() +resized_calb_bin.set(False) +resized_calb_box = tk.Checkbutton(root, variable=resized_calb_bin, text="Is resized image's intrinsic") +resized_calb_box.place(x=45, y=150) + +root.mainloop() + +Is_resized_intrinsic = resized_calb_bin.get() + +def make_monodepth2_dataset(mode="train"): + assert mode in ["train", "val"], "function make_monodepth2_dataset's mode must be 'train' or 'val' " + assert options.save_height % 32 == 0, "'height' must be a multiple of 32" + assert options.save_width % 32 == 0, "'width' must be a multiple of 32" + + video_paths = glob("./{}_videos/*".format(mode)) + + assert len(video_paths) != 0, "ファイル'{}_videos'に動画ファイルが入っていません" + + for path in video_paths: + assert path[-4:] == '.mp4', "動画はmp4ファイルのみに対応しています" + + os.makedirs(options.out_dir, exist_ok=True) + os.makedirs(osp.join(options.out_dir, "{}".format(mode)), exist_ok=True) + + image_save_num = 0 + dataset_indication_list = [] + ful_res_w, ful_res_h = cv2.VideoCapture(video_paths[0]).get(cv2.CAP_PROP_FRAME_WIDTH), cv2.VideoCapture(video_paths[0]).get(cv2.CAP_PROP_FRAME_HEIGHT) + + intrinsic = [[K[0], K[1], K[2]], + [K[3], K[4], K[5]], + [K[6], K[7], K[8]]] + if not Is_resized_intrinsic: + intrinsic[0][0] *= options.save_width / ful_res_w + intrinsic[0][2] *= options.save_width / ful_res_w + intrinsic[1][1] *= options.save_height / ful_res_h + intrinsic[1][2] *= options.save_height / ful_res_h + for sequence_num, path in enumerate(video_paths): + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise RuntimeError("無効なmp4ファイルが発見されました.") + + dataset_indication_list.append("sequence_{}\n".format(sequence_num)) + save_dir = "sequence_{}".format(sequence_num) + os.makedirs(osp.join(options.out_dir, "{}".format(mode), save_dir), exist_ok=True) + while_count = 0 + while True: + ret, frame = cap.read() + + if not ret: + break + + if while_count % options.save_frequency == 0: + frame = cv2.resize(frame, (options.save_width, options.save_height), interpolation=cv2.INTER_LINEAR) + cv2.imwrite(osp.join(options.out_dir, "{}".format(mode), save_dir, "{:08}.jpg".format(image_save_num)), + frame) + image_save_num += 1 + while_count += 1 + + with open(osp.join(options.out_dir, "{}.txt".format(mode)), "w") as f: + f.writelines(dataset_indication_list) + + with open(osp.join(options.out_dir, "environment.yaml"), "w") as f: + dataset_info = {"height": options.save_height, + "width": options.save_width} + camera_info = {"intrinsic": intrinsic} + environment = {"dataset_info": dataset_info, + "camera_info": camera_info} + f.write(yaml.dump(environment)) + + +make_monodepth2_dataset(mode="train") +if not options.no_make_val: + make_monodepth2_dataset(mode="val") diff --git a/datasets_module/pair_folders.py b/datasets_module/pair_folders.py new file mode 100644 index 0000000..0fadb92 --- /dev/null +++ b/datasets_module/pair_folders.py @@ -0,0 +1,60 @@ +import torch.utils.data as data +import numpy as np +from imageio import imread +from path import Path +import random +import os + + +def load_as_float(path): + return imread(path).astype(np.float32) + + +class PairFolder(data.Dataset): + """A sequence data loader where the files are arranged in this way: + root/scene_1/0000000_0.jpg + root/scene_1/0000001_1.jpg + .. + root/scene_1/cam.txt + . + transform functions must take in a list a images and a numpy array (usually intrinsics matrix) + """ + + def __init__(self, root, seed=None, train=True, transform=None): + np.random.seed(seed) + random.seed(seed) + self.root = Path(root) + scene_list_path = self.root/'train.txt' if train else self.root/'val.txt' + self.scenes = [self.root/folder[:-1] for folder in open(scene_list_path)] + self.transform = transform + self.crawl_folders() + + def crawl_folders(self,): + pair_set = [] + for scene in self.scenes: + # intrinsics = np.genfromtxt(scene/'cam.txt').astype(np.float32).reshape((3, 3)) + + imgs = sorted(scene.files('*.jpg')) + intrinsics = sorted(scene.files('*.txt')) + + for i in range(0, len(imgs)-1, 2): + intrinsic = np.genfromtxt(intrinsics[int(i/2)]).astype(np.float32).reshape((3, 3)) + sample = {'intrinsics': intrinsic, 'tgt': imgs[i], 'ref_imgs': [imgs[i+1]]} + pair_set.append(sample) + random.shuffle(pair_set) + self.samples = pair_set + + def __getitem__(self, index): + sample = self.samples[index] + tgt_img = load_as_float(sample['tgt']) + ref_imgs = [load_as_float(ref_img) for ref_img in sample['ref_imgs']] + if self.transform is not None: + imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(sample['intrinsics'])) + tgt_img = imgs[0] + ref_imgs = imgs[1:] + else: + intrinsics = np.copy(sample['intrinsics']) + return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics) + + def __len__(self): + return len(self.samples) diff --git a/datasets_module/sequence_folders.py b/datasets_module/sequence_folders.py new file mode 100644 index 0000000..dae3498 --- /dev/null +++ b/datasets_module/sequence_folders.py @@ -0,0 +1,99 @@ +import torch.utils.data as data +import numpy as np +from imageio import imread +from path import Path +import random +import os +from torchvision import transforms + + +def load_as_float(path): + return imread(path).astype(np.float32) + + +class SequenceFolder(data.Dataset): + """A sequence data loader where the files are arranged in this way: + root/scene_1/0000000.jpg + root/scene_1/0000001.jpg + .. + root/scene_1/cam.txt + root/scene_2/0000000.jpg + . + transform functions must take in a list a images and a numpy array (usually intrinsics matrix) + """ + + def __init__(self, root, seed=None, train=True, sequence_length=3, transform=None, skip_frames=1, dataset='kitti'): + np.random.seed(seed) + random.seed(seed) + self.root = Path(root) + scene_list_path = self.root/'train.txt' if train else self.root/'val.txt' + self.scenes = [self.root/'train'/folder[:-1] for folder in open(scene_list_path)] if train else \ + [self.root / 'val' / folder[:-1] for folder in open(scene_list_path)] + self.transform = transform + self.dataset = dataset + self.k = skip_frames + self.crawl_folders(sequence_length) + + try: + self.brightness = (0.8, 1.2) + self.contrast = (0.8, 1.2) + self.saturation = (0.8, 1.2) + self.hue = (-0.1, 0.1) + transforms.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + except TypeError: + self.brightness = 0.2 + self.contrast = 0.2 + self.saturation = 0.2 + self.hue = 0.1 + + def crawl_folders(self, sequence_length): + # k skip frames + sequence_set = [] + demi_length = (sequence_length-1)//2 + shifts = list(range(-demi_length * self.k, demi_length * self.k + 1, self.k)) + shifts.pop(demi_length) + for scene in self.scenes: + intrinsics = np.array([[490.135, 0, 142.842], [0, 492.224, 253.201], [0, 0, 1]]).astype(np.float32) + # intrinsics = np.genfromtxt(scene/'cam.txt').astype(np.float32).reshape((3, 3)) + imgs = sorted(scene.files('*.jpg')) + + if len(imgs) < sequence_length: + continue + for i in range(demi_length * self.k, len(imgs)-demi_length * self.k): + sample = {'intrinsics': intrinsics, 'tgt': imgs[i], 'ref_imgs': []} + for j in shifts: + sample['ref_imgs'].append(imgs[i+j]) + sequence_set.append(sample) + random.shuffle(sequence_set) + self.samples = sequence_set + + def __getitem__(self, index): + ''' + do_color_aug = random.random() > 0.5 + do_flip = random.random() > 0.5 + ''' + + sample = self.samples[index] + tgt_img = load_as_float(sample['tgt']) + ref_imgs = [load_as_float(ref_img) for ref_img in sample['ref_imgs']] + ''' + if do_color_aug: + color_aug = transforms.ColorJitter.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + else: + color_aug = (lambda x: x) + ''' + + if self.transform is not None: + imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(sample['intrinsics'])) + tgt_img = imgs[0] + ref_imgs = imgs[1:] + else: + intrinsics = np.copy(sample['intrinsics']) + + return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics) + + def __len__(self): + return len(self.samples) diff --git a/datasets_module/validation_folders.py b/datasets_module/validation_folders.py new file mode 100644 index 0000000..d231272 --- /dev/null +++ b/datasets_module/validation_folders.py @@ -0,0 +1,59 @@ +import torch.utils.data as data +import numpy as np +from imageio import imread +from path import Path +import os +import torch + +def crawl_folders(folders_list, dataset='nyu'): + imgs = [] + depths = [] + for folder in folders_list: + current_imgs = sorted(folder.files('*.jpg')) + if dataset == 'nyu': + current_depth = sorted((folder/'depth/').files('*.png')) + elif dataset == 'kitti': + current_depth = sorted(folder.files('*.npy')) + imgs.extend(current_imgs) + depths.extend(current_depth) + return imgs, depths + + +class ValidationSet(data.Dataset): + """A sequence data loader where the files are arranged in this way: + root/scene_1/0000000.jpg + root/scene_1/0000000.npy + root/scene_1/0000001.jpg + root/scene_1/0000001.npy + .. + root/scene_2/0000000.jpg + root/scene_2/0000000.npy + . + + transform functions must take in a list a images and a numpy array which can be None + """ + + def __init__(self, root, transform=None, dataset='nyu'): + self.root = Path(root) + scene_list_path = self.root/'val.txt' + self.scenes = [self.root/folder[:-1] for folder in open(scene_list_path)] + self.transform = transform + self.dataset = dataset + self.imgs, self.depth = crawl_folders(self.scenes, self.dataset) + + def __getitem__(self, index): + img = imread(self.imgs[index]).astype(np.float32) + + if self.dataset=='nyu': + depth = torch.from_numpy(imread(self.depth[index]).astype(np.float32)).float()/5000 + elif self.dataset=='kitti': + depth = torch.from_numpy(np.load(self.depth[index]).astype(np.float32)) + + if self.transform is not None: + img, _ = self.transform([img], None) + img = img[0] + + return img, depth + + def __len__(self): + return len(self.imgs) diff --git a/inverse_warp.py b/inverse_warp.py new file mode 100644 index 0000000..f9789a1 --- /dev/null +++ b/inverse_warp.py @@ -0,0 +1,269 @@ +from __future__ import division +import torch +import torch.nn.functional as F + +pixel_coords = None + + +def set_id_grid(depth): + global pixel_coords + b, h, w = depth.size() + i_range = torch.arange(0, h).view(1, h, 1).expand( + 1, h, w).type_as(depth) # [1, H, W] + j_range = torch.arange(0, w).view(1, 1, w).expand( + 1, h, w).type_as(depth) # [1, H, W] + ones = torch.ones(1, h, w).type_as(depth) + + pixel_coords = torch.stack((j_range, i_range, ones), dim=1) # [1, 3, H, W] + + +def check_sizes(input, input_name, expected): + condition = [input.ndimension() == len(expected)] + for i, size in enumerate(expected): + if size.isdigit(): + condition.append(input.size(i) == int(size)) + assert(all(condition)), "wrong size for {}, expected {}, got {}".format( + input_name, 'x'.join(expected), list(input.size())) + + +def pixel2cam(depth, intrinsics_inv): + global pixel_coords + """Transform coordinates in the pixel frame to the camera frame. + Args: + depth: depth maps -- [B, H, W] + intrinsics_inv: intrinsics_inv matrix for each element of batch -- [B, 3, 3] + Returns: + array of (u,v,1) cam coordinates -- [B, 3, H, W] + """ + b, h, w = depth.size() + if (pixel_coords is None) or pixel_coords.size(2) < h: + set_id_grid(depth) + current_pixel_coords = pixel_coords[:, :, :h, :w].expand( + b, 3, h, w).reshape(b, 3, -1) # [B, 3, H*W] + cam_coords = (intrinsics_inv @ current_pixel_coords).reshape(b, 3, h, w) + return cam_coords * depth.unsqueeze(1) + + +def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode): + """Transform coordinates in the camera frame to the pixel frame. + Args: + cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 4, H, W] + proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4] + proj_c2p_tr: translation vectors of cameras -- [B, 3, 1] + Returns: + array of [-1,1] coordinates -- [B, 2, H, W] + """ + b, _, h, w = cam_coords.size() + cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W] + if proj_c2p_rot is not None: + pcoords = proj_c2p_rot @ cam_coords_flat + else: + pcoords = cam_coords_flat + + if proj_c2p_tr is not None: + pcoords = pcoords + proj_c2p_tr # [B, 3, H*W] + X = pcoords[:, 0] + Y = pcoords[:, 1] + Z = pcoords[:, 2].clamp(min=1e-3) + + # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W] + X_norm = 2*(X / Z)/(w-1) - 1 + Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W] + + pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2] + return pixel_coords.reshape(b, h, w, 2) + + +def euler2mat(angle): + """Convert euler angles to rotation matrix. + Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174 + Args: + angle: rotation angle along 3 axis (in radians) -- size = [B, 3] + Returns: + Rotation matrix corresponding to the euler angles -- size = [B, 3, 3] + """ + B = angle.size(0) + x, y, z = angle[:, 0], angle[:, 1], angle[:, 2] + + cosz = torch.cos(z) + sinz = torch.sin(z) + + zeros = z.detach()*0 + ones = zeros.detach()+1 + zmat = torch.stack([cosz, -sinz, zeros, + sinz, cosz, zeros, + zeros, zeros, ones], dim=1).reshape(B, 3, 3) + + cosy = torch.cos(y) + siny = torch.sin(y) + + ymat = torch.stack([cosy, zeros, siny, + zeros, ones, zeros, + -siny, zeros, cosy], dim=1).reshape(B, 3, 3) + + cosx = torch.cos(x) + sinx = torch.sin(x) + + xmat = torch.stack([ones, zeros, zeros, + zeros, cosx, -sinx, + zeros, sinx, cosx], dim=1).reshape(B, 3, 3) + + rotMat = xmat @ ymat @ zmat + return rotMat + + +def quat2mat(quat): + """Convert quaternion coefficients to rotation matrix. + Args: + quat: first three coeff of quaternion of rotation. fourht is then computed to have a norm of 1 -- size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = torch.cat([quat[:, :1].detach()*0 + 1, quat], dim=1) + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, + 1], norm_quat[:, 2], norm_quat[:, 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3) + return rotMat + + +def pose_vec2mat(vec, rotation_mode='euler'): + """ + Convert 6DoF parameters to transformation matrix. + Args:s + vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz -- [B, 6] + Returns: + A transformation matrix -- [B, 3, 4] + """ + translation = vec[:, :3].unsqueeze(-1) # [B, 3, 1] + rot = vec[:, 3:] + if rotation_mode == 'euler': + rot_mat = euler2mat(rot) # [B, 3, 3] + elif rotation_mode == 'quat': + rot_mat = quat2mat(rot) # [B, 3, 3] + transform_mat = torch.cat([rot_mat, translation], dim=2) # [B, 3, 4] + return transform_mat + + +def inverse_warp(img, depth, pose, intrinsics, rotation_mode='euler', padding_mode='zeros'): + """ + Inverse warp a source image to the target image plane. + Args: + img: the source image (where to sample pixels) -- [B, 3, H, W] + depth: depth map of the target image -- [B, H, W] + pose: 6DoF pose parameters from target to source -- [B, 6] + intrinsics: camera intrinsic matrix -- [B, 3, 3] + Returns: + projected_img: Source image warped to the target image plane + valid_points: Boolean array indicating point validity + """ + check_sizes(img, 'img', 'B3HW') + check_sizes(depth, 'depth', 'BHW') + check_sizes(pose, 'pose', 'B6') + check_sizes(intrinsics, 'intrinsics', 'B33') + + batch_size, _, img_height, img_width = img.size() + + cam_coords = pixel2cam(depth, intrinsics.inverse()) # [B,3,H,W] + + pose_mat = pose_vec2mat(pose, rotation_mode) # [B,3,4] + + # Get projection matrix for tgt camera frame to source pixel frame + proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4] + + rot, tr = proj_cam_to_src_pixel[:, :, :3], proj_cam_to_src_pixel[:, :, -1:] + src_pixel_coords = cam2pixel( + cam_coords, rot, tr, padding_mode) # [B,H,W,2] + projected_img = F.grid_sample( + img, src_pixel_coords, padding_mode=padding_mode) + + valid_points = src_pixel_coords.abs().max(dim=-1)[0] <= 1 + + return projected_img, valid_points + + +def cam2pixel2(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode): + """Transform coordinates in the camera frame to the pixel frame. + Args: + cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 4, H, W] + proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4] + proj_c2p_tr: translation vectors of cameras -- [B, 3, 1] + Returns: + array of [-1,1] coordinates -- [B, 2, H, W] + """ + b, _, h, w = cam_coords.size() + cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W] + if proj_c2p_rot is not None: + pcoords = proj_c2p_rot @ cam_coords_flat + else: + pcoords = cam_coords_flat + + if proj_c2p_tr is not None: + pcoords = pcoords + proj_c2p_tr # [B, 3, H*W] + X = pcoords[:, 0] + Y = pcoords[:, 1] + Z = pcoords[:, 2].clamp(min=1e-3) + + # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W] + X_norm = 2*(X / Z)/(w-1) - 1 + Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W] + if padding_mode == 'zeros': + X_mask = ((X_norm > 1)+(X_norm < -1)).detach() + # make sure that no point in warped image is a combinaison of im and gray + X_norm[X_mask] = 2 + Y_mask = ((Y_norm > 1)+(Y_norm < -1)).detach() + Y_norm[Y_mask] = 2 + + pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2] + return pixel_coords.reshape(b, h, w, 2), Z.reshape(b, 1, h, w) + + +def inverse_warp2(img, depth, ref_depth, pose, intrinsics, padding_mode='zeros'): + """ + Inverse warp a source image to the target image plane. + Args: + img: the source image (where to sample pixels) -- [B, 3, H, W] + depth: depth map of the target image -- [B, 1, H, W] + ref_depth: the source depth map (where to sample depth) -- [B, 1, H, W] + pose: 6DoF pose parameters from target to source -- [B, 6] + intrinsics: camera intrinsic matrix -- [B, 3, 3] + Returns: + projected_img: Source image warped to the target image plane + valid_mask: Float array indicating point validity + projected_depth: sampled depth from source image + computed_depth: computed depth of source image using the target depth + """ + check_sizes(img, 'img', 'B3HW') + check_sizes(depth, 'depth', 'B1HW') + check_sizes(ref_depth, 'ref_depth', 'B1HW') + check_sizes(pose, 'pose', 'B6') + check_sizes(intrinsics, 'intrinsics', 'B33') + + batch_size, _, img_height, img_width = img.size() + + cam_coords = pixel2cam(depth.squeeze(1), intrinsics.inverse()) # [B,3,H,W] + + pose_mat = pose_vec2mat(pose) # [B,3,4] + + # Get projection matrix for tgt camera frame to source pixel frame + proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4] + + rot, tr = proj_cam_to_src_pixel[:, :, :3], proj_cam_to_src_pixel[:, :, -1:] + src_pixel_coords, computed_depth = cam2pixel2(cam_coords, rot, tr, padding_mode) # [B,H,W,2] + projected_img = F.grid_sample(img, src_pixel_coords, padding_mode=padding_mode, align_corners=False) + + valid_points = src_pixel_coords.abs().max(dim=-1)[0] <= 1 + valid_mask = valid_points.unsqueeze(1).float() + + projected_depth = F.grid_sample(ref_depth, src_pixel_coords, padding_mode=padding_mode, align_corners=False) + + return projected_img, valid_mask, projected_depth, computed_depth diff --git a/loss_functions.py b/loss_functions.py new file mode 100644 index 0000000..cfc1bb2 --- /dev/null +++ b/loss_functions.py @@ -0,0 +1,287 @@ +from __future__ import division +import torch +from torch import nn +import torch.nn.functional as F +from inverse_warp import inverse_warp2, inverse_warp +import math +import cv2 +import numpy as np + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + +class SSIM(nn.Module): + """Layer to compute the SSIM loss between a pair of images + """ + + def __init__(self): + super(SSIM, self).__init__() + self.mu_x_pool = nn.AvgPool2d(3, 1) + self.mu_y_pool = nn.AvgPool2d(3, 1) + self.sig_x_pool = nn.AvgPool2d(3, 1) + self.sig_y_pool = nn.AvgPool2d(3, 1) + self.sig_xy_pool = nn.AvgPool2d(3, 1) + + self.refl = nn.ReflectionPad2d(1) + + self.C1 = 0.01 ** 2 + self.C2 = 0.03 ** 2 + + def forward(self, x, y): + x = self.refl(x) + y = self.refl(y) + + mu_x = self.mu_x_pool(x) + mu_y = self.mu_y_pool(y) + + sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 + sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 + sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y + + SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) + SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) + + return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) + + +compute_ssim_loss = SSIM().to(device) + + +# photometric loss +# geometry consistency loss + +#### Our addition +def brightnes_equator(source, target): + def image_stats(image): + # compute the mean and standard deviation of each channel + + l = image[:, 0, :, :] + a = image[:, 1, :, :] + b = image[:, 2, :, :] + + (lMean, lStd) = (torch.mean(torch.squeeze(l)), torch.std(torch.squeeze(l))) + + (aMean, aStd) = (torch.mean(torch.squeeze(a)), torch.std(torch.squeeze(a))) + + (bMean, bStd) = (torch.mean(torch.squeeze(b)), torch.std(torch.squeeze(b))) + + # return the color statistics + return (lMean, lStd, aMean, aStd, bMean, bStd) + + def color_transfer(source, target): + # convert the images from the RGB to L*ab* color space, being + # sure to utilizing the floating point data type (note: OpenCV + # expects floats to be 32-bit, so use that instead of 64-bit) + + # compute color statistics for the source and target images + (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = image_stats(source) + (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = image_stats(target) + + # subtract the means from the target image + l = target[:, 0, :, :] + a = target[:, 1, :, :] + b = target[:, 2, :, :] + + l = l - lMeanTar + # print("after l",torch.isnan(l)) + a = a - aMeanTar + b = b - bMeanTar + # scale by the standard deviations + l = (lStdTar / lStdSrc) * l + a = (aStdTar / aStdSrc) * a + b = (bStdTar / bStdSrc) * b + # add in the source mean + l = l + lMeanSrc + a = a + aMeanSrc + b = b + bMeanSrc + transfer = torch.cat((l.unsqueeze(1), a.unsqueeze(1), b.unsqueeze(1)), 1) + # print(torch.isnan(transfer)) + return transfer + + # return the color transferred image + transfered_image = color_transfer(target, source) + return transfered_image + + +def compute_photo_and_geometry_loss(tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, poses, poses_inv, max_scales, + with_ssim, with_mask, with_auto_mask, padding_mode): + photo_loss = 0 + geometry_loss = 0 + + num_scales = min(len(tgt_depth), max_scales) + warp_img_dict = {} + for ref_img, ref_depth, pose, pose_inv in zip(ref_imgs, ref_depths, poses, poses_inv): + for s in range(num_scales): + + # # downsample img + # b, _, h, w = tgt_depth[s].size() + # downscale = tgt_img.size(2)/h + # if s == 0: + # tgt_img_scaled = tgt_img + # ref_img_scaled = ref_img + # else: + # tgt_img_scaled = F.interpolate(tgt_img, (h, w), mode='area') + # ref_img_scaled = F.interpolate(ref_img, (h, w), mode='area') + # intrinsic_scaled = torch.cat((intrinsics[:, 0:2]/downscale, intrinsics[:, 2:]), dim=1) + # tgt_depth_scaled = tgt_depth[s] + # ref_depth_scaled = ref_depth[s] + + # upsample depth + b, _, h, w = tgt_img.size() + tgt_img_scaled = tgt_img + ref_img_scaled = ref_img + intrinsic_scaled = intrinsics + if s == 0: + tgt_depth_scaled = tgt_depth[s] + ref_depth_scaled = ref_depth[s] + else: + tgt_depth_scaled = F.interpolate(tgt_depth[s], (h, w), mode='nearest') + ref_depth_scaled = F.interpolate(ref_depth[s], (h, w), mode='nearest') + + photo_loss1, geometry_loss1, ref_img_warped, ref_img_warped2 = compute_pairwise_loss(tgt_img_scaled, + ref_img_scaled, + tgt_depth_scaled, + ref_depth_scaled, pose, + intrinsic_scaled, + with_ssim, with_mask, + with_auto_mask, + padding_mode) + warp_img_dict["ref_img_warped"] = ref_img_warped + warp_img_dict["ref_img_warped2"] = ref_img_warped2 + photo_loss2, geometry_loss2, tgt_img_warped, tgt_img_warped2 = compute_pairwise_loss(ref_img_scaled, + tgt_img_scaled, + ref_depth_scaled, + tgt_depth_scaled, + pose_inv, + intrinsic_scaled, + with_ssim, with_mask, + with_auto_mask, + padding_mode) + warp_img_dict["tgt_img_warped"] = tgt_img_warped + warp_img_dict["tgt_img_warped2"] = tgt_img_warped2 + + photo_loss += (photo_loss1 + photo_loss2) + geometry_loss += (geometry_loss1 + geometry_loss2) + + return photo_loss, geometry_loss, warp_img_dict + + +def compute_pairwise_loss(tgt_img, ref_img, tgt_depth, ref_depth, pose, intrinsic, with_ssim, with_mask, with_auto_mask, + padding_mode): + ref_img_warped, valid_mask, projected_depth, computed_depth = inverse_warp2(ref_img, tgt_depth, ref_depth, pose, + intrinsic, padding_mode) + + + # print("ref_image_warped",ref_img_warped.shape) + + ref_img_warped2 = brightnes_equator(ref_img_warped, tgt_img) #### Our addition + + diff_img = (tgt_img - ref_img_warped2).abs().clamp(0, 1) + + diff_depth = ((computed_depth - projected_depth).abs() / (computed_depth + projected_depth)).clamp(0, 1) + + if with_auto_mask == True: + auto_mask = (diff_img.mean(dim=1, keepdim=True) < (tgt_img - ref_img).abs().mean(dim=1, + keepdim=True)).float() * valid_mask + valid_mask = auto_mask + + if with_ssim == True: + ssim_map = compute_ssim_loss(tgt_img, ref_img_warped2) #### Our addition + diff_img = (0.15 * diff_img + 0.85 * ssim_map) + + if with_mask == True: + weight_mask = (1 - diff_depth) + diff_img = diff_img * weight_mask + + # compute all loss + reconstruction_loss = mean_on_mask(diff_img, valid_mask) + geometry_consistency_loss = mean_on_mask(diff_depth, valid_mask) + + return reconstruction_loss, geometry_consistency_loss, ref_img_warped, ref_img_warped2 + + +# compute mean value given a binary mask +def mean_on_mask(diff, valid_mask): + mask = valid_mask.expand_as(diff) + if mask.sum() > 10000: + mean_value = (diff * mask).sum() / mask.sum() + else: + mean_value = torch.tensor(0).float().to(device) + return mean_value + + +def compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs): + def get_smooth_loss(disp, img): + """Computes the smoothness loss for a disparity image + The color image is used for edge-aware smoothness + """ + + # normalize + mean_disp = disp.mean(2, True).mean(3, True) + norm_disp = disp / (mean_disp + 1e-7) + disp = norm_disp + + grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) + grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) + + grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) + grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) + + grad_disp_x *= torch.exp(-grad_img_x) + grad_disp_y *= torch.exp(-grad_img_y) + + return grad_disp_x.mean() + grad_disp_y.mean() + + loss = get_smooth_loss(tgt_depth[0], tgt_img) + + for ref_depth, ref_img in zip(ref_depths, ref_imgs): + loss += get_smooth_loss(ref_depth[0], ref_img) + + return loss + + +@torch.no_grad() +def compute_errors(gt, pred, dataset): + abs_diff, abs_rel, sq_rel, a1, a2, a3 = 0, 0, 0, 0, 0, 0 + batch_size, h, w = gt.size() + + ''' + crop used by Garg ECCV16 to reprocude Eigen NIPS14 results + construct a mask of False values, with the same size as target + and then set to True values inside the crop + ''' + if dataset == 'kitti': + crop_mask = gt[0] != gt[0] + y1, y2 = int(0.40810811 * gt.size(1)), int(0.99189189 * gt.size(1)) + x1, x2 = int(0.03594771 * gt.size(2)), int(0.96405229 * gt.size(2)) + crop_mask[y1:y2, x1:x2] = 1 + max_depth = 80 + + if dataset == 'nyu': + crop_mask = gt[0] != gt[0] + y1, y2 = int(0.09375 * gt.size(1)), int(0.98125 * gt.size(1)) + x1, x2 = int(0.0640625 * gt.size(2)), int(0.9390625 * gt.size(2)) + crop_mask[y1:y2, x1:x2] = 1 + max_depth = 10 + + for current_gt, current_pred in zip(gt, pred): + valid = (current_gt > 0.1) & (current_gt < max_depth) + valid = valid & crop_mask + + valid_gt = current_gt[valid] + valid_pred = current_pred[valid].clamp(1e-3, max_depth) + + valid_pred = valid_pred * torch.median(valid_gt) / torch.median(valid_pred) + + thresh = torch.max((valid_gt / valid_pred), (valid_pred / valid_gt)) + a1 += (thresh < 1.25).float().mean() + a2 += (thresh < 1.25 ** 2).float().mean() + a3 += (thresh < 1.25 ** 3).float().mean() + + abs_diff += torch.mean(torch.abs(valid_gt - valid_pred)) + abs_rel += torch.mean(torch.abs(valid_gt - valid_pred) / valid_gt) + + sq_rel += torch.mean(((valid_gt - valid_pred) ** 2) / valid_gt) + + return [metric.item() / batch_size for metric in [abs_diff, abs_rel, sq_rel, a1, a2, a3]] + diff --git a/models/DispResNet.py b/models/DispResNet.py new file mode 100644 index 0000000..7df15c5 --- /dev/null +++ b/models/DispResNet.py @@ -0,0 +1,140 @@ +from __future__ import absolute_import, division, print_function +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .resnet_encoder2 import * + + +import numpy as np +from collections import OrderedDict + +class ConvBlock(nn.Module): + """Layer to perform a convolution followed by ELU + """ + def __init__(self, in_channels, out_channels): + super(ConvBlock, self).__init__() + + self.conv = Conv3x3(in_channels, out_channels) + self.nonlin = nn.ELU(inplace=False) + + def forward(self, x): + out = self.conv(x) + out = self.nonlin(out) + return out + +class Conv3x3(nn.Module): + """Layer to pad and convolve input + """ + def __init__(self, in_channels, out_channels, use_refl=True): + super(Conv3x3, self).__init__() + + if use_refl: + self.pad = nn.ReflectionPad2d(1) + else: + self.pad = nn.ZeroPad2d(1) + self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) + + def forward(self, x): + out = self.pad(x) + out = self.conv(out) + return out + +def upsample(x): + """Upsample input tensor by a factor of 2 + """ + return F.interpolate(x, scale_factor=2, mode="nearest") + +class DepthDecoder(nn.Module): + def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True): + super(DepthDecoder, self).__init__() + + self.alpha = 10 + self.beta = 0.01 + + self.num_output_channels = num_output_channels + self.use_skips = use_skips + self.upsample_mode = 'nearest' + self.scales = scales + + self.num_ch_enc = num_ch_enc + self.num_ch_dec = np.array([16, 32, 64, 128, 256]) + + # decoder + self.convs = OrderedDict() + for i in range(4, -1, -1): + # upconv_0 + num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1] + num_ch_out = self.num_ch_dec[i] + self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out) + + # upconv_1 + num_ch_in = self.num_ch_dec[i] + if self.use_skips and i > 0: + num_ch_in += self.num_ch_enc[i - 1] + num_ch_out = self.num_ch_dec[i] + self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out) + + for s in self.scales: + self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels) + + self.decoder = nn.ModuleList(list(self.convs.values())) + self.sigmoid = nn.Sigmoid() + + def forward(self, input_features): + self.outputs = [] + + # decoder + x = input_features[-1] + for i in range(4, -1, -1): + x = self.convs[("upconv", i, 0)](x) + x = [upsample(x)] + if self.use_skips and i > 0: + x += [input_features[i - 1]] + x = torch.cat(x, 1) + x = self.convs[("upconv", i, 1)](x) + if i in self.scales: + self.outputs.append(self.alpha * self.sigmoid(self.convs[("dispconv", i)](x)) + self.beta) + + self.outputs = self.outputs[::-1] + return self.outputs + + +class DispResNet(nn.Module): + + def __init__(self, num_layers = 18, pretrained = True): + super(DispResNet, self).__init__() + self.encoder = ResnetEncoder(num_layers = num_layers, pretrained = pretrained, num_input_images=1) + self.decoder = DepthDecoder(self.encoder.num_ch_enc) + + def init_weights(self): + pass + + def forward(self, x): + features = self.encoder(x) + outputs = self.decoder(features) + + if self.training: + return outputs + else: + return outputs[0] + + +if __name__ == "__main__": + + torch.backends.cudnn.benchmark = True + + model = DispResNet().cuda() + model.train() + + B = 12 + + tgt_img = torch.randn(B, 3, 256, 832).cuda() + ref_imgs = [torch.randn(B, 3, 256, 832).cuda() for i in range(2)] + + tgt_depth = model(tgt_img) + + print(tgt_depth[0].size()) + + diff --git a/models/PoseResNet.py b/models/PoseResNet.py new file mode 100644 index 0000000..2c0859d --- /dev/null +++ b/models/PoseResNet.py @@ -0,0 +1,82 @@ +# Copyright Niantic 2019. Patent Pending. All rights reserved. +# +# This software is licensed under the terms of the Monodepth2 licence +# which allows for non-commercial use only, the full terms of which are made +# available in the LICENSE file. + +from __future__ import absolute_import, division, print_function + +import torch +import torch.nn as nn +from collections import OrderedDict +from .resnet_encoder import * + +class PoseDecoder(nn.Module): + def __init__(self, num_ch_enc, num_input_features=1, num_frames_to_predict_for=1, stride=1): + super(PoseDecoder, self).__init__() + + self.num_ch_enc = num_ch_enc + self.num_input_features = num_input_features + + if num_frames_to_predict_for is None: + num_frames_to_predict_for = num_input_features - 1 + self.num_frames_to_predict_for = num_frames_to_predict_for + + self.convs = OrderedDict() + self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1) + self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1) + self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1) + self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1) + + self.relu = nn.ReLU(inplace=False) + + self.net = nn.ModuleList(list(self.convs.values())) + + def forward(self, input_features): + last_features = [f[-1] for f in input_features] + + cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features] + cat_features = torch.cat(cat_features, 1) + + out = cat_features + for i in range(3): + out = self.convs[("pose", i)](out) + if i != 2: + out = self.relu(out) + + out = out.mean(3).mean(2) + + pose = 0.01 * out.view(-1, 6) + + return pose + + +class PoseResNet(nn.Module): + + def __init__(self, num_layers = 18, pretrained = True): + super(PoseResNet, self).__init__() + self.encoder = ResnetEncoder(num_layers = num_layers, pretrained = pretrained, num_input_images=2) + self.decoder = PoseDecoder(self.encoder.num_ch_enc) + + def init_weights(self): + pass + + def forward(self, img1, img2): + x = torch.cat([img1,img2],1) + features = self.encoder(x) + pose = self.decoder([features]) + return pose + +if __name__ == "__main__": + + torch.backends.cudnn.benchmark = True + + model = PoseResNet().cuda() + model.train() + + tgt_img = torch.randn(4, 3, 256, 832).cuda() + ref_imgs = [torch.randn(4, 3, 256, 832).cuda() for i in range(2)] + + pose = model(tgt_img, ref_imgs[0]) + + print(pose.size()) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..97d21f8 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,2 @@ +from .DispResNet import DispResNet +from .PoseResNet import PoseResNet diff --git a/models/resnet_encoder.py b/models/resnet_encoder.py new file mode 100644 index 0000000..660fc6b --- /dev/null +++ b/models/resnet_encoder.py @@ -0,0 +1,138 @@ +# Copyright Niantic 2019. Patent Pending. All rights reserved. +# +# This software is licensed under the terms of the Monodepth2 licence +# which allows for non-commercial use only, the full terms of which are made +# available in the LICENSE file. + +from __future__ import absolute_import, division, print_function + +import numpy as np + +import torch +import torch.nn as nn +import torchvision.models as models +import torch.utils.model_zoo as model_zoo + + +class ResNetMultiImageInput(models.ResNet): + """Constructs a resnet model with varying number of input images. + Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py + """ + def __init__(self, block, layers, num_classes=1000, num_input_images=1): + super(ResNetMultiImageInput, self).__init__(block, layers) + self.inplanes = 64 + self.conv1 = nn.Conv2d( + num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=False) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0.2) + + +def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): + """Constructs a ResNet model. + Args: + num_layers (int): Number of resnet layers. Must be 18 or 50 + pretrained (bool): If True, returns a model pre-trained on ImageNet + num_input_images (int): Number of frames stacked as input + """ + assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" + blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] + block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] + model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) + + if pretrained: + loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) + loaded['conv1.weight'] = torch.cat( + [loaded['conv1.weight']] * num_input_images, 1) / num_input_images + model.load_state_dict(loaded) + return model + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=1): + super(SpatialAttention, self).__init__() + kernel_size = 1 + padding = 3 if kernel_size == 7 else 0 + + self.conv1 = nn.Conv2d(64, 4, kernel_size, padding=padding, bias=False) + self.conv2 = nn.Conv2d(4, 4, kernel_size , padding=padding, bias=False) + self.conv3 = nn.Conv2d(4, 64, kernel_size, padding=padding, bias=False) + self.maxPooling = nn.MaxPool2d(4,stride=4) + self.relu = nn.ReLU() + self.softmax = nn.Softmax(dim=1) + self.upsample = nn.Upsample(scale_factor=4) + + def forward(self, x): + + x1 = self.conv1(x) + x2 = self.maxPooling(x1) + reshaped1 = torch.reshape(x2,(x2.shape[0],x2.shape[1],-1,x2.shape[2])) + y = torch.matmul(reshaped1,x2) + z = self.relu(y) + z = self.conv2(z) + t = self.softmax(z) + out1 = torch.matmul(t,reshaped1) + conv3_out = self.conv3(out1) + upsample_out = self.upsample(conv3_out) + k = torch.reshape(upsample_out,(upsample_out.shape[0],upsample_out.shape[1],-1,upsample_out.shape[2])) + output = k + x + + return output + + +class ResnetEncoder(nn.Module): + """Pytorch module for a resnet encoder + """ + def __init__(self, num_layers, pretrained, num_input_images=1): + super(ResnetEncoder, self).__init__() + + self.num_ch_enc = np.array([64, 64, 128, 256, 512]) + + resnets = {18: models.resnet18, + 34: models.resnet34, + 50: models.resnet50, + 101: models.resnet101, + 152: models.resnet152} + + if num_layers not in resnets: + raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) + + if num_input_images > 1: + self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) + else: + self.encoder = resnets[num_layers](pretrained) + + if num_layers > 34: + self.num_ch_enc[1:] *= 4 + + self.SAB = SpatialAttention() + + def forward(self, input_image): + self.features = [] + x = input_image + x = self.encoder.conv1(x) + x = self.encoder.bn1(x) + self.features.append(self.encoder.relu(x)) + + #denemee + + self.features.append(self.SAB(self.features[-1])) + + + self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) + self.features.append(self.encoder.layer2(self.features[-1])) + self.features.append(self.encoder.layer3(self.features[-1])) + self.features.append(self.encoder.layer4(self.features[-1])) + + return self.features + diff --git a/models/resnet_encoder2.py b/models/resnet_encoder2.py new file mode 100644 index 0000000..432bbba --- /dev/null +++ b/models/resnet_encoder2.py @@ -0,0 +1,103 @@ +# Copyright Niantic 2019. Patent Pending. All rights reserved. +# +# This software is licensed under the terms of the Monodepth2 licence +# which allows for non-commercial use only, the full terms of which are made +# available in the LICENSE file. + +from __future__ import absolute_import, division, print_function + +import numpy as np + +import torch +import torch.nn as nn +import torchvision.models as models +import torch.utils.model_zoo as model_zoo + + +class ResNetMultiImageInput(models.ResNet): + """Constructs a resnet model with varying number of input images. + Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py + """ + def __init__(self, block, layers, num_classes=1000, num_input_images=1): + super(ResNetMultiImageInput, self).__init__(block, layers) + self.inplanes = 64 + self.conv1 = nn.Conv2d( + num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=False) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): + """Constructs a ResNet model. + Args: + num_layers (int): Number of resnet layers. Must be 18 or 50 + pretrained (bool): If True, returns a model pre-trained on ImageNet + num_input_images (int): Number of frames stacked as input + """ + assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" + blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] + block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] + model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) + + if pretrained: + loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) + loaded['conv1.weight'] = torch.cat( + [loaded['conv1.weight']] * num_input_images, 1) / num_input_images + model.load_state_dict(loaded) + return model + + + +class ResnetEncoder(nn.Module): + """Pytorch module for a resnet encoder + """ + def __init__(self, num_layers, pretrained, num_input_images=1): + super(ResnetEncoder, self).__init__() + + self.num_ch_enc = np.array([64, 64, 128, 256, 512]) + + resnets = {18: models.resnet18, + 34: models.resnet34, + 50: models.resnet50, + 101: models.resnet101, + 152: models.resnet152} + + if num_layers not in resnets: + raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) + + if num_input_images > 1: + self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) + else: + self.encoder = resnets[num_layers](pretrained) + + if num_layers > 34: + self.num_ch_enc[1:] *= 4 + + + def forward(self, input_image): + self.features = [] + x = input_image + + x = self.encoder.conv1(x) + + x = self.encoder.bn1(x) + self.features.append(self.encoder.relu(x)) + + self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) + self.features.append(self.encoder.layer2(self.features[-1])) + self.features.append(self.encoder.layer3(self.features[-1])) + self.features.append(self.encoder.layer4(self.features[-1])) + + return self.features diff --git a/train.py b/train.py new file mode 100644 index 0000000..fa84a95 --- /dev/null +++ b/train.py @@ -0,0 +1,439 @@ +import argparse +import time +import csv +import datetime +from path import Path +import cv2 + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data + +import models + +import custom_transforms +from utils import tensor2array, save_checkpoint +from datasets_module.sequence_folders import SequenceFolder +from datasets_module.pair_folders import PairFolder +from loss_functions import compute_smooth_loss, compute_photo_and_geometry_loss, compute_errors +from tensorboardX import SummaryWriter +import os + +parser = argparse.ArgumentParser(description='Structure from Motion Learner training on KITTI and CityScapes Dataset', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + +parser.add_argument('--data', metavar='DIR', type=str, default='./datasets/data_for_SC_SfMLearner', help='path to dataset') +parser.add_argument('--folder-type', type=str, choices=['sequence', 'pair'], default='sequence', + help='the dataset dype to train') +parser.add_argument('--sequence-length', type=int, metavar='N', help='sequence length for training', default=3) +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers') +parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run') +parser.add_argument('--epoch-size', default=0, type=int, metavar='N', + help='manual epoch size (will match dataset size if not set)') +parser.add_argument('-b', '--batch-size', default=4, type=int, metavar='N', help='mini-batch size') +parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, metavar='LR', help='initial learning rate') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum for sgd, alpha parameter for adam') +parser.add_argument('--beta', default=0.999, type=float, metavar='M', help='beta parameters for adam') +parser.add_argument('--weight-decay', '--wd', default=0, type=float, metavar='W', help='weight decay') +parser.add_argument('--print-freq', default=1000, type=int, metavar='N', help='print frequency') +parser.add_argument('--seed', default=0, type=int, help='seed for random functions, and network initialization') +parser.add_argument('--log-summary', default='progress_log_summary.csv', metavar='PATH', + help='csv where to save per-epoch train and valid stats') +parser.add_argument('--log-full', default='progress_log_full.csv', metavar='PATH', + help='csv where to save per-gradient descent train stats') +parser.add_argument('--log-output', action='store_true', help='will log dispnet outputs at validation step') +parser.add_argument('--resnet-layers', type=int, default=18, choices=[18, 50], + help='number of ResNet layers for depth estimation.') +parser.add_argument('--num-scales', '--number-of-scales', type=int, help='the number of scales', metavar='W', default=1) +parser.add_argument('-p', '--photo-loss-weight', type=float, help='weight for photometric loss', metavar='W', default=1) +parser.add_argument('-s', '--smooth-loss-weight', type=float, help='weight for disparity smoothness loss', metavar='W', + default=0.1) +parser.add_argument('-c', '--geometry-consistency-weight', type=float, help='weight for depth consistency loss', + metavar='W', default=0.5) +parser.add_argument('--with-ssim', type=int, default=1, help='with ssim or not') +parser.add_argument('--with-mask', type=int, default=1, + help='with the the mask for moving objects and occlusions or not') +parser.add_argument('--with-auto-mask', type=int, default=0, help='with the the mask for stationary points') +parser.add_argument('--with-pretrain', type=int, default=1, help='with or without imagenet pretrain for resnet') +parser.add_argument('--dataset', type=str, choices=['kitti', 'nyu'], default='kitti', help='the dataset to train') +parser.add_argument('--pretrained-disp', dest='pretrained_disp', default=None, metavar='PATH', + help='path to pre-trained dispnet model') +parser.add_argument('--pretrained-pose', dest='pretrained_pose', default=None, metavar='PATH', + help='path to pre-trained Pose net model') +parser.add_argument('--padding-mode', type=str, choices=['zeros', 'border'], default='zeros', + help='padding mode for image warping : this is important for photometric differenciation when going outside target image.' + ' zeros will null gradients outside target image.' + ' border will only null gradients of the coordinate outside (x or y)') +parser.add_argument('--with-gt', action='store_true', help='use ground truth for validation. \ + You need to store it in npy 2D arrays see data/kitti_raw_loader.py for an example') + +best_error = -1 +n_iter = 0 + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +torch.autograd.set_detect_anomaly(True) + + +def main(): + global best_error, n_iter, device + args = parser.parse_args() + + timestamp = Path(datetime.datetime.now().strftime("%m-%d-%H-%M")) + args.save_path = 'checkpoints' / timestamp + print('=> will save everything to {}'.format(args.save_path)) + args.save_path.makedirs_p() + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + cudnn.deterministic = True + cudnn.benchmark = True + + training_writer = SummaryWriter(args.save_path) + ''' + if args.log_output: + for i in range(3): + output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i))) + ''' + + # Data loading code + normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45], + std=[0.225, 0.225, 0.225]) + + train_transform = custom_transforms.Compose([ + custom_transforms.RandomHorizontalFlip(), + custom_transforms.RandomScaleCrop(), + custom_transforms.ArrayToTensor(), + normalize + ]) + + valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize]) + + print("=> fetching scenes in '{}'".format(args.data)) + if args.folder_type == 'sequence': + train_set = SequenceFolder( + args.data, + transform=train_transform, + seed=args.seed, + train=True, + sequence_length=args.sequence_length, + dataset=args.dataset + ) + else: + train_set = PairFolder( + args.data, + seed=args.seed, + train=True, + transform=train_transform + ) + + # if no Groundtruth is avalaible, Validation set is the same type as training set to measure photometric loss from warping + if args.with_gt: + from datasets_module.validation_folders import ValidationSet + val_set = ValidationSet( + args.data, + transform=valid_transform, + dataset=args.dataset + ) + else: + val_set = SequenceFolder( + args.data, + transform=valid_transform, + seed=args.seed, + train=False, + sequence_length=args.sequence_length, + dataset=args.dataset + ) + print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes))) + print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes))) + train_loader = torch.utils.data.DataLoader( + train_set, batch_size=args.batch_size, shuffle=True, + num_workers=0, pin_memory=True, drop_last=True) + val_loader = torch.utils.data.DataLoader( + val_set, batch_size=args.batch_size, shuffle=False, + num_workers=0, pin_memory=True, drop_last=True) + + if args.epoch_size == 0: + args.epoch_size = len(train_loader) + + # create model + print("=> creating model") + disp_net = models.DispResNet(args.resnet_layers, True).to(device) + pose_net = models.PoseResNet(18, True).to(device) + + # load parameters + if args.pretrained_disp: + print("=> using pre-trained weights for DispResNet") + weights = torch.load(args.pretrained_disp) + disp_net.load_state_dict(weights['state_dict'], strict=False) + + if args.pretrained_pose: + print("=> using pre-trained weights for PoseResNet") + weights = torch.load(args.pretrained_pose) + pose_net.load_state_dict(weights['state_dict'], strict=False) + + print('=> setting adam solver') + optim_params = [ + {'params': disp_net.parameters(), 'lr': args.lr}, + {'params': pose_net.parameters(), 'lr': args.lr} + ] + optimizer = torch.optim.Adam(optim_params, + betas=(args.momentum, args.beta), + weight_decay=args.weight_decay) + + + for epoch in range(args.epochs): + train_loss = train(args, train_loader, disp_net, pose_net, optimizer, args.epoch_size, None, training_writer) + + ''' + save_checkpoint( + args.save_path, { + 'epoch': epoch + 1, + 'state_dict': disp_net.module.state_dict() + }, { + 'epoch': epoch + 1, + 'state_dict': pose_net.module.state_dict() + }, + is_best) + ''' + # logger.epoch_bar.finish() + + +def normalize_image(x): + """Rescale image pixels to span range [0, 1] + """ + ma = float(x.max().cpu().data) + mi = float(x.min().cpu().data) + d = ma - mi if ma != mi else 1e5 + return (x - mi) / d + + +def train(args, train_loader, disp_net, pose_net, optimizer, epoch_size, logger, train_writer): + global n_iter, device + w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.geometry_consistency_weight + + # switch to train mode + disp_net.train() + pose_net.train() + + end = time.time() + + for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(train_loader): + log_losses = n_iter % args.print_freq == 0 + + # measure data loading time + # data_time.update(time.time() - end) + tgt_img = tgt_img.to(device) + ref_imgs = [img.to(device) for img in ref_imgs] + intrinsics = intrinsics.to(device) + + # compute output + tgt_depth, ref_depths, tgt_disp, ref_disp = compute_depth(disp_net, tgt_img, ref_imgs) + poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs) + + loss_1, loss_3, warped_img_dict = compute_photo_and_geometry_loss(tgt_img, ref_imgs, intrinsics, tgt_depth, + ref_depths, + poses, poses_inv, args.num_scales, + args.with_ssim, + args.with_mask, args.with_auto_mask, + args.padding_mode) + + loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs) + + loss = w1 * loss_1 + w2 * loss_2 + w3 * loss_3 + + if log_losses: + print(f"n_iter:{n_iter} || p_loss:{loss_1.item()} || d_loss:{loss_2.item()} || g_loss:{loss_3.item()}") + + torch.save(disp_net.state_dict(), os.path.join(args.save_path, f'dispnet_{n_iter}.pth')) + torch.save(pose_net.state_dict(), os.path.join(args.save_path, f'posenet_{n_iter}.pth')) + + train_writer.add_scalar('photometric_error', loss_1.item(), n_iter) + train_writer.add_scalar('disparity_smoothness_loss', loss_2.item(), n_iter) + train_writer.add_scalar('geometry_consistency_loss', loss_3.item(), n_iter) + train_writer.add_scalar('total_loss', loss.item(), n_iter) + + for j in range(min(4, args.batch_size)): + train_writer.add_image("tgt_img/{}".format(j), normalize_image(tgt_img[j].data), n_iter) + train_writer.add_image("tgt_disp/{}".format(j), normalize_image(tgt_disp[0][j].data), n_iter) + for k, ref_img in enumerate(ref_imgs): + train_writer.add_image("ref_img_{}/{}".format(k, j), normalize_image(ref_img[j].data), n_iter) + for k, r_disp in enumerate(ref_disp): + train_writer.add_image("ref_disp_{}/{}".format(k, j), normalize_image(r_disp[0][j].data), n_iter) + train_writer.add_image("warped_tgt/{}".format(j), normalize_image(warped_img_dict["tgt_img_warped"][j]), + n_iter) + train_writer.add_image("warped_tgt2/{}".format(j), + normalize_image(warped_img_dict["tgt_img_warped2"][j]), n_iter) + train_writer.add_image("warped_ref/{}".format(j), normalize_image(warped_img_dict["ref_img_warped"][j]), + n_iter) + train_writer.add_image("warped_ref2/{}".format(j), + normalize_image(warped_img_dict["ref_img_warped2"][j]), n_iter) + + # compute gradient and do Adam step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if i >= epoch_size - 1: + break + + n_iter += 1 + + # return losses.avg[0] + + +@torch.no_grad() +def validate_without_gt(args, val_loader, disp_net, pose_net, epoch, logger, output_writers=[]): + global device + # batch_time = AverageMeter() + # losses = AverageMeter(i=4, precision=4) + log_outputs = len(output_writers) > 0 + + # switch to evaluate mode + disp_net.eval() + pose_net.eval() + + end = time.time() + # logger.valid_bar.update(0) + for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(val_loader): + tgt_img = tgt_img.to(device) + ref_imgs = [img.to(device) for img in ref_imgs] + intrinsics = intrinsics.to(device) + intrinsics_inv = intrinsics_inv.to(device) + + # compute output + tgt_depth = [1 / disp_net(tgt_img)] + ref_depths = [] + for ref_img in ref_imgs: + ref_depth = [1 / disp_net(ref_img)] + ref_depths.append(ref_depth) + + if log_outputs and i < len(output_writers): + if epoch == 0: + output_writers[i].add_image('val Input', tensor2array(tgt_img[0]), 0) + + output_writers[i].add_image('val Dispnet Output Normalized', + tensor2array(1 / tgt_depth[0][0], max_value=None, colormap='magma'), + epoch) + output_writers[i].add_image('val Depth Output', + tensor2array(tgt_depth[0][0], max_value=10), + epoch) + + poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs) + + loss_1, loss_3 = compute_photo_and_geometry_loss(tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths, + poses, poses_inv, args.num_scales, args.with_ssim, + args.with_mask, False, args.padding_mode) + + loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs) + + loss_1 = loss_1.item() + loss_2 = loss_2.item() + loss_3 = loss_3.item() + + loss = loss_1 + # losses.update([loss, loss_1, loss_2, loss_3]) + + # measure elapsed time + # batch_time.update(time.time() - end) + end = time.time() + # logger.valid_bar.update(i+1) + # if i % args.print_freq == 0: + # logger.valid_writer.write('valid: Time {} Loss {}'.format(batch_time, losses)) + + # logger.valid_bar.update(len(val_loader)) + # return losses.avg, ['Total loss', 'Photo loss', 'Smooth loss', 'Consistency loss'] + + +''' +@torch.no_grad() +def validate_with_gt(args, val_loader, disp_net, epoch, logger, output_writers=[]): + global device + batch_time = AverageMeter() + error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3'] + errors = AverageMeter(i=len(error_names)) + log_outputs = len(output_writers) > 0 + + # switch to evaluate mode + disp_net.eval() + + end = time.time() + logger.valid_bar.update(0) + for i, (tgt_img, depth) in enumerate(val_loader): + tgt_img = tgt_img.to(device) + depth = depth.to(device) + + # check gt + if depth.nelement() == 0: + continue + + # compute output + output_disp = disp_net(tgt_img) + output_depth = 1/output_disp[:, 0] + + if log_outputs and i < len(output_writers): + if epoch == 0: + output_writers[i].add_image('val Input', tensor2array(tgt_img[0]), 0) + depth_to_show = depth[0] + output_writers[i].add_image('val target Depth', + tensor2array(depth_to_show, max_value=10), + epoch) + depth_to_show[depth_to_show == 0] = 1000 + disp_to_show = (1/depth_to_show).clamp(0, 10) + output_writers[i].add_image('val target Disparity Normalized', + tensor2array(disp_to_show, max_value=None, colormap='magma'), + epoch) + + output_writers[i].add_image('val Dispnet Output Normalized', + tensor2array(output_disp[0], max_value=None, colormap='magma'), + epoch) + output_writers[i].add_image('val Depth Output', + tensor2array(output_depth[0], max_value=10), + epoch) + + if depth.nelement() != output_depth.nelement(): + b, h, w = depth.size() + output_depth = torch.nn.functional.interpolate(output_depth.unsqueeze(1), [h, w]).squeeze(1) + + errors.update(compute_errors(depth, output_depth, args.dataset)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + logger.valid_bar.update(i+1) + if i % args.print_freq == 0: + logger.valid_writer.write('valid: Time {} Abs Error {:.4f} ({:.4f})'.format(batch_time, errors.val[0], errors.avg[0])) + logger.valid_bar.update(len(val_loader)) + return errors.avg, error_names +''' + + +def compute_depth(disp_net, tgt_img, ref_imgs): + tgt_disp = disp_net(tgt_img) + tgt_depth = [1 / disp for disp in tgt_disp] + + ref_depths = [] + ref_disps = [] + for ref_img in ref_imgs: + disps = disp_net(ref_img) + ref_disps.append(disps) + ref_depth = [1 / disp for disp in disps] + ref_depths.append(ref_depth) + + return tgt_depth, ref_depths, tgt_disp, ref_disps + + +def compute_pose_with_inv(pose_net, tgt_img, ref_imgs): + poses = [] + poses_inv = [] + for ref_img in ref_imgs: + poses.append(pose_net(tgt_img, ref_img)) + poses_inv.append(pose_net(ref_img, tgt_img)) + + return poses, poses_inv + + +if __name__ == '__main__': + main() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..0031fd7 --- /dev/null +++ b/utils.py @@ -0,0 +1,67 @@ +from __future__ import division +import shutil +import numpy as np +import torch +from path import Path +import datetime +from collections import OrderedDict +from matplotlib import cm +from matplotlib.colors import ListedColormap, LinearSegmentedColormap + + +def high_res_colormap(low_res_cmap, resolution=1000, max_value=1): + # Construct the list colormap, with interpolated values for higer resolution + # For a linear segmented colormap, you can just specify the number of point in + # cm.get_cmap(name, lutsize) with the parameter lutsize + x = np.linspace(0, 1, low_res_cmap.N) + low_res = low_res_cmap(x) + new_x = np.linspace(0, max_value, resolution) + high_res = np.stack([np.interp(new_x, x, low_res[:, i]) + for i in range(low_res.shape[1])], axis=1) + return ListedColormap(high_res) + + +def opencv_rainbow(resolution=1000): + # Construct the opencv equivalent of Rainbow + opencv_rainbow_data = ( + (0.000, (1.00, 0.00, 0.00)), + (0.400, (1.00, 1.00, 0.00)), + (0.600, (0.00, 1.00, 0.00)), + (0.800, (0.00, 0.00, 1.00)), + (1.000, (0.60, 0.00, 1.00)) + ) + + return LinearSegmentedColormap.from_list('opencv_rainbow', opencv_rainbow_data, resolution) + + +COLORMAPS = {'rainbow': opencv_rainbow(), + 'magma': high_res_colormap(cm.get_cmap('magma')), + 'bone': cm.get_cmap('bone', 10000)} + + +def tensor2array(tensor, max_value=None, colormap='rainbow'): + tensor = tensor.detach().cpu() + if max_value is None: + max_value = tensor.max().item() + if tensor.ndimension() == 2 or tensor.size(0) == 1: + norm_array = tensor.squeeze().numpy()/max_value + # array = COLORMAPS[colormap](norm_array).astype(np.float32) + array = norm_array[:, :, np.newaxis].astype(np.float32) + array = array.transpose(2, 0, 1) + + elif tensor.ndimension() == 3: + assert(tensor.size(0) == 3) + array = 0.45 + tensor.numpy()*0.225 + return array + + +def save_checkpoint(save_path, dispnet_state, exp_pose_state, is_best, filename='checkpoint.pth.tar'): + file_prefixes = ['dispnet', 'exp_pose'] + states = [dispnet_state, exp_pose_state] + for (prefix, state) in zip(file_prefixes, states): + torch.save(state, save_path/'{}_{}'.format(prefix, filename)) + + if is_best: + for prefix in file_prefixes: + shutil.copyfile(save_path/'{}_{}'.format(prefix, filename), + save_path/'{}_model_best.pth.tar'.format(prefix))