Newer
Older
ConvertOrd2NBICycleGAN / predict_single_img.py
@sato sato on 1 Mar 2022 1 KB READMEの更新
import cv2
import numpy as np
import PIL.Image as image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os.path as osp

from EsoMovieConverter import EsoMovieConverter
from models import CycleGenerator

# Single-image prediction script: reads one image, runs it through the trained
# generator and writes the translated result next to the inputs.
path = "single_img2/5_.png"
weight_path = r"./model_weights/G_AB_199.pth"

eso_movie_converter = EsoMovieConverter()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the generator and load the checkpoint onto whichever device was selected.
ord2bli_generator = CycleGenerator(3, 3, 9).to(device)
ord2bli_generator.load_state_dict(torch.load(weight_path, map_location=device))
# Inference only: switch any dropout / batch-norm layers the model may contain
# to evaluation mode so the output is deterministic.
ord2bli_generator.eval()

# ToTensor scales 8-bit pixels to [0, 1]; Normalize with mean = std = 0.5 then
# maps each channel to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# cv2.imread returns None (it does not raise) when the file is missing or
# cannot be decoded, so fail here with a clear message.
source_frame = cv2.imread(path)
if source_frame is None:
    raise FileNotFoundError(f"Could not read input image: {path}")

# NOTE(review): the converter's output is treated as a BGR image array below —
# confirm against EsoMovieConverter.
eso_frame_array = eso_movie_converter(source_frame)
# Independent copy of the converted frame (not used further in this script).
ord_eso_array = eso_frame_array.copy()
pil_img = image.fromarray(cv2.cvtColor(eso_frame_array, cv2.COLOR_BGR2RGB))
# Add a leading batch dimension of size 1 before feeding the generator.
input_transform = torch.unsqueeze(transform(pil_img), 0).to(device)

# No gradients are needed for prediction; skipping autograd saves memory.
with torch.no_grad():
    out = ord2bli_generator(input_transform)[0]
    # Assumes the generator output lies in [-1, 1] (mirrors the input
    # normalization): shift/scale to [0, 256], clamp to the 8-bit range, and
    # move channels last for OpenCV.
    ndarr = out.add_(1.0).mul(128).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
ndarr = cv2.cvtColor(ndarr, cv2.COLOR_RGB2BGR)
cv2.imshow("a", ndarr)
cv2.waitKey(10)

# splitext strips only the final extension, so names containing extra dots
# (e.g. "a.b.png") keep their full stem.
stem, _ = osp.splitext(osp.basename(path))
output_path = f"single_img2/{stem}r.png"
# cv2.imwrite reports failure through its return value rather than raising.
if not cv2.imwrite(output_path, ndarr):
    raise OSError(f"Could not write output image: {output_path}")