import sys, os

sys.path.append(os.path.abspath(".."))

import torch
import cv2
import numpy as np
import models
import argparse
from utils import tensor2array
from glob import glob
import os.path as osp
import yaml
from tqdm import tqdm

# Command-line interface: disparity checkpoint path, environment YAML path,
# and an optional camera index for live mode.
parser = argparse.ArgumentParser(description='Structure from Motion Learner',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--weight_pth', required=True, help="path to disp model's pth file")
parser.add_argument('--path_to_yaml', default="../datasets/data_for_SC_SfMLearner/environment.yaml", help="path to environment")
# -9999 is a sentinel meaning "no camera": the script then processes the files
# in ./input instead. (Help text: "specify the port number if using a camera".)
parser.add_argument('--use_camera', type=int, default=-9999, help="カメラを使うならポート番号を指定")
args = parser.parse_args()

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Disparity network; 18 presumably selects a ResNet-18 encoder and False
# presumably disables pretrained initialization — confirm in models.DispResNet.
disp_net = models.DispResNet(18, False).to(device)
# disp_net = torch.nn.DataParallel(disp_net)
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint)
# onto `device`, so a GPU-trained checkpoint also loads on a CPU-only machine.
weights = torch.load(args.weight_pth, map_location=device)
# Accept both a bare state_dict and a training checkpoint of the form
# {'state_dict': ..., ...}; a bare state_dict has only parameter-name keys.
if isinstance(weights, dict) and "state_dict" in weights:
    weights = weights["state_dict"]
disp_net.load_state_dict(weights)
# Switch layers such as BatchNorm/Dropout to inference behavior.
disp_net.eval()

# Load the environment YAML. dataset_info.height/width are the resolution
# every frame is resized to before it is fed to disp_net.
# Decode explicitly as UTF-8 so parsing does not depend on the OS locale
# (e.g. cp932 on Japanese Windows); safe_load avoids arbitrary object construction.
with open(args.path_to_yaml, encoding="utf-8") as f:
    environment = yaml.safe_load(f)
r_h, r_w = environment["dataset_info"]["height"], environment["dataset_info"]["width"]

def _to_input_tensor(frame_bgr):
    """Convert one BGR uint8 image (as read by cv2) into disp_net's input tensor.

    The image is resized to (r_w, r_h) from the environment YAML, converted
    BGR -> RGB, laid out as a 1x3xHxW float32 tensor, scaled to [0, 1],
    normalized with mean 0.45 / std 0.225, and moved to ``device``.
    """
    resized = cv2.resize(frame_bgr, (r_w, r_h), interpolation=cv2.INTER_LINEAR)
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    chw = np.transpose(rgb, (2, 0, 1)).astype(np.float32)
    tensor = torch.from_numpy(chw).unsqueeze(0)
    return ((tensor / 255 - 0.45) / 0.225).to(device)


def _predict_disp(frame_bgr, out_size, colormap=None):
    """Predict disparity for one BGR frame and return it as a uint8 image.

    Args:
        frame_bgr: image as produced by cv2.imread / VideoCapture.read.
        out_size: (width, height) the visualization is resized back to.
        colormap: forwarded to ``tensor2array``; None keeps that helper's
            own default colormap.

    Returns:
        The ``tensor2array`` visualization scaled to 0-255 as uint8,
        transposed from channels-first to channels-last and resized to
        ``out_size``.
    """
    input_tensor = _to_input_tensor(frame_bgr)
    # Inference only: no_grad avoids building an autograd graph per frame.
    with torch.no_grad():
        output = disp_net(input_tensor)[0]
    extra = {} if colormap is None else {"colormap": colormap}
    disp = (255 * tensor2array(output, max_value=None, **extra)).astype(np.uint8)
    disp = np.transpose(disp, (1, 2, 0))
    return cv2.resize(disp, out_size, interpolation=cv2.INTER_LINEAR)


if args.use_camera != -9999:
    # Live mode: show camera frames ("src") next to their disparity ("disp").
    cap = cv2.VideoCapture(args.use_camera)
    if not cap.isOpened():
        sys.exit(f"Failed to open camera {args.use_camera}")
    print("「q」キーを押すことでプログラムが終了します．")

    try:
        while True:
            ret, frame = cap.read()
            # Check ret first: on failure frame is None and imshow would raise.
            if not ret:
                break

            h, w = frame.shape[:2]
            cv2.imshow("src", frame)
            cv2.imshow("disp", _predict_disp(frame, (w, h), colormap='bone'))

            # One waitKey per frame both refreshes the windows and polls for 'q'.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()

else:
    # Batch mode: convert every image/video in ./input and write the result
    # to ./output under the same file name.
    output_dir = "./output"
    # cv2.imwrite and cv2.VideoWriter fail silently if the directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    target_files_list = glob(osp.join("input", "*.*"))
    img_ext = ["jpg", "png"]
    video_ext = ["mp4"]

    for target_file in tqdm(target_files_list):
        ext = osp.splitext(target_file)[1].lower().lstrip('.')
        out_path = osp.join(output_dir, osp.basename(target_file))

        if ext in img_ext:
            frame = cv2.imread(target_file)
            if frame is None:
                tqdm.write(f"Skipping unreadable image: {target_file}")
                continue
            h, w = frame.shape[:2]
            # NOTE(review): unlike the camera/video paths this uses
            # tensor2array's default colormap; if that output is RGB,
            # cv2.imwrite (which expects BGR) swaps R and B — confirm.
            cv2.imwrite(out_path, _predict_disp(frame, (w, h)))

        elif ext in video_ext:
            cap = cv2.VideoCapture(target_file)
            if not cap.isOpened():
                tqdm.write(f"Skipping unreadable video: {target_file}")
                continue

            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
            writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))

            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    disp = _predict_disp(frame, (w, h), colormap='bone')
                    # COLOR_GRAY2BGR needs a single-channel input; presumably
                    # the 'bone' visualization is single-channel here — confirm.
                    writer.write(cv2.cvtColor(disp, cv2.COLOR_GRAY2BGR))
            finally:
                # Release even if inference raises, so the file is finalized.
                cap.release()
                writer.release()
