import cv2
import numpy as np
import PIL.Image as image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os.path as osp
from EsoMovieConverter import EsoMovieConverter
from models import CycleGenerator
# Single-image inference: read one image, pass it through EsoMovieConverter,
# translate it with the `ord2bli` CycleGenerator (G_AB weights), then briefly
# display the result and save it next to the input with an "r" suffix.
path = "single_img2/5_.png"
weight_path = r"./model_weights/G_AB_199.pth"

eso_movie_converter = EsoMovieConverter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ord2bli_generator = CycleGenerator(3, 3, 9).to(device)
ord2bli_generator.load_state_dict(torch.load(weight_path, map_location=device))
# Inference only: put any normalisation / dropout layers into evaluation mode.
ord2bli_generator.eval()

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

src_frame = cv2.imread(path)
# cv2.imread returns None (instead of raising) when the file is missing or unreadable.
if src_frame is None:
    raise FileNotFoundError(f"Could not read input image: {path}")
eso_frame_array = eso_movie_converter(src_frame)
ord_eso_array = eso_frame_array.copy()  # independent copy of the converted frame
# The converted frame is treated as BGR (OpenCV order); the model is fed RGB.
pil_img = image.fromarray(cv2.cvtColor(eso_frame_array, cv2.COLOR_BGR2RGB))
input_transform = torch.unsqueeze(transform(pil_img), 0).to(device)  # add batch dim

# No autograd graph is needed for inference.
with torch.no_grad():
    out = ord2bli_generator(input_transform)[0]

# Undo the input normalisation: [-1, 1] -> [0, 255] via (x + 1) * 127.5, the
# exact inverse of ToTensor + Normalize(0.5, 0.5). The +0.5 rounds to nearest,
# because the uint8 cast truncates.
ndarr = (
    out.add(1.0)
    .mul(127.5)
    .add_(0.5)
    .clamp_(0, 255)
    .permute(1, 2, 0)  # CHW -> HWC for OpenCV
    .to("cpu", torch.uint8)
    .numpy()
)
ndarr = cv2.cvtColor(ndarr, cv2.COLOR_RGB2BGR)

cv2.imshow("a", ndarr)
cv2.waitKey(10)

# Save next to the input as "<stem>r.png"; splitext keeps dots inside the stem.
out_stem = osp.splitext(osp.basename(path))[0]
out_path = osp.join(osp.dirname(path), f"{out_stem}r.png")
# cv2.imwrite reports failure through its return value rather than raising.
if not cv2.imwrite(out_path, ndarr):
    raise OSError(f"Failed to write output image: {out_path}")
# Disabled alternative display path, kept for reference as comments rather
# than as a no-op string literal. It refers to `bli_frame_tensor`, which this
# script does not define, and indexes [0] then transposes (1, 2, 0), so it
# expects a batched channel-first tensor; it scales by 255, so presumably the
# values are in [0, 1] — confirm before re-enabling.
# The clip is applied *before* the uint8 cast; clipping after the cast is a
# no-op because out-of-range values have already wrapped around.
#
# bli_frame_array = bli_frame_tensor.cpu()[0].detach().numpy().transpose((1, 2, 0))
# bli_frame_array = np.clip(255 * bli_frame_array, 0, 255).astype(np.uint8)
# cv2.imshow("bli", bli_frame_array)
# cv2.waitKey(1)