import torch
import cv2
import numpy as np
import networks

from PIL import Image
from torchvision import transforms

# Template for per-network checkpoint files; formatted with the network name below.
pth_base_path = r"D:\Deep_Learning\SC_SfMLearner\esophagus\log\mdp(EndoSLAM)\models\weights_99\{}.pth"
# Input endoscopy clip to run depth inference on.
video_path = r"D:\Deep_Learning\MonoDepth2\esophagus\movies\trimed\0.mp4"

# MP4 writer for the disparity visualization: 30 fps, 480x352 (width, height).
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
disp_out = cv2.VideoWriter(r"C:\Users\Planck\OneDrive - chiba-u.jp\デスクトップ\0_次のミーティングにつかうやつ\20201120\disp.mp4", fourcc, 30.0, (480, 352))

def normalize_image(x):
    """Rescale image pixels to span range [0, 1]

    When the tensor is constant (max == min), divides by 1e5 instead,
    yielding an all-(near-)zero result rather than a division by zero.
    """
    hi = float(x.max().cpu().data)
    lo = float(x.min().cpu().data)
    span = (hi - lo) if hi != lo else 1e5
    return (x - lo) / span

# Open the source video and load the camera calibration produced offline.
cap = cv2.VideoCapture(video_path)
scaled_intrinsic = np.load("./params/intrinsics_scaled.npy")
dist_coeffs = np.load("./params/dist_coeffs.npy")

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Depth network: ResNet-18 encoder feeding a decoder that predicts disparity at scale 0.
depth_encoder = networks.ResnetEncoder(18, False)
depth_decoder = networks.DepthDecoder(depth_encoder.num_ch_enc, [0])

# Pose network: ESAB encoder over 2 input frames, decoder emitting 1 pose from 2 features.
pose_encoder = networks.ESAB_Encoder(18, False, 2)
pose_decoder = networks.PoseDecoder(pose_encoder.num_ch_enc, 1, 2)

models = {
    "encoder": depth_encoder,
    "depth": depth_decoder,
    "pose_encoder": pose_encoder,
    "pose": pose_decoder,
}
for net in models.values():
    net.to(device)

# Load pretrained weights into each network, keeping only keys that exist in
# the current model (checkpoints may carry extra entries such as image size).
for n in ["encoder", "depth", "pose_encoder", "pose"]:
    print("Loading {} weights...".format(n))
    path = pth_base_path.format(n)
    model_dict = models[n].state_dict()
    # map_location=device: without it, GPU-saved checkpoints fail to load on a
    # CPU-only machine (fix — the original relied on CUDA being present).
    pretrained_dict = torch.load(path, map_location=device)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    models[n].load_state_dict(model_dict)
    models[n].eval()  # inference mode: freezes batch-norm stats / disables dropout

# Frame-by-frame inference: read, undistort, predict disparity, preview and record.
while True:
    ret, frame = cap.read()

    if not ret:
        break  # end of stream

    # Crop the region of interest, resize to the network input size (480x352),
    # then undistort using the pre-scaled intrinsics.
    frame = frame[32:989, 323:1599, :]
    frame = cv2.resize(frame, (480, 352), interpolation=cv2.INTER_LINEAR)
    frame = cv2.undistort(frame, scaled_intrinsic, dist_coeffs)
    cv2.imshow("source_video", frame)
    cv2.waitKey(1)

    # OpenCV is BGR; the network was trained on RGB tensors in [0, 1].
    pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    input_tensor = transforms.ToTensor()(pil_img).unsqueeze(0).to(device)

    # Fix: run inference under no_grad — the original built an autograd graph
    # for every frame, wasting memory for no benefit during pure inference.
    with torch.no_grad():
        depth_features = models["encoder"](input_tensor)
        disp = models["depth"](depth_features)[("disp", 0)][0]

    disp = normalize_image(disp)
    # NOTE(review): mul(128) maps the normalized [0, 1] disparity to [0, 128],
    # using only half the uint8 range — kept as-is; confirm whether 255 was intended.
    ndarr = disp.mul(128).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    cv2.imshow("disparity", ndarr)
    disp_out.write(cv2.cvtColor(ndarr, cv2.COLOR_GRAY2BGR))
    cv2.waitKey(1)

# Fix: release the capture and writer and close preview windows — the original
# leaked both handles, which can leave the output MP4 unfinalized.
cap.release()
disp_out.release()
cv2.destroyAllWindows()