import cv2
import numpy as np
import PIL.Image as image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from EsoMovieConverter import EsoMovieConverter
from models import CycleGenerator
# Convert an esophagus endoscopy video from ordinary white-light imaging to
# BLI-style frames using a pretrained CycleGAN generator (G_AB), previewing
# each converted frame in a window and writing the result to 'bli.mp4'.
cap = cv2.VideoCapture(r'D:\Deep_Learning\MonoDepth2\esophagus\movies\trimed\0.mp4')
eso_movie_converter = EsoMovieConverter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
fps = int(cap.get(cv2.CAP_PROP_FPS))
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
# NOTE(review): output size is hard-coded; cv2.VideoWriter silently drops any
# frame whose size differs from (480, 352) — confirm EsoMovieConverter
# actually emits 480x352 frames.
video_writer = cv2.VideoWriter('bli.mp4', fourcc, fps, (480, 352))
ord2bli_generator = CycleGenerator(3, 3, 9).to(device)
ord2bli_generator.load_state_dict(torch.load(r'./model_weights/G_AB_199.pth', map_location=device))
# Inference mode: without eval(), BatchNorm/Dropout layers behave as in
# training and corrupt the generated frames.
ord2bli_generator.eval()
# Matches the training-time preprocessing: [0, 1] tensor -> [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# No gradients are needed for inference; no_grad avoids building an autograd
# graph for every frame (large memory/time savings on long videos).
with torch.no_grad():
    while True:
        ret, frame = cap.read()
        if not ret:  # end of video
            break
        eso_frame_array = eso_movie_converter(frame)
        # OpenCV delivers BGR; the generator expects RGB input.
        pil_img = image.fromarray(cv2.cvtColor(eso_frame_array, cv2.COLOR_BGR2RGB))
        input_transform = torch.unsqueeze(transform(pil_img), 0).to(device)
        out = ord2bli_generator(input_transform)[0]
        # Map generator output from [-1, 1] back to uint8 [0, 255] (CHW -> HWC).
        ndarr = out.add_(1.0).mul(128).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        ndarr = cv2.cvtColor(ndarr, cv2.COLOR_RGB2BGR)
        cv2.imshow("a", ndarr)
        cv2.waitKey(1)
        video_writer.write(ndarr)
cap.release()
video_writer.release()
cv2.destroyAllWindows()