# Source: ConvertOrd2NBICycleGAN/make_bli_dataset/make_dataset2.py
# Last commit: @sato, 1 Mar 2022 (1 KB) — "READMEの更新" ("Update README")
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os
import os.path as osp
import PIL.Image as Image
from glob import glob


from EsoMovieConverter import EsoMovieConverter
from models import CycleGenerator

# ---- Configuration (machine-specific paths; edit before running) ----
# Root directory whose sub-folders contain the input images.
input_imgs_dir = r"D:\Deep_Learning\Endo-SfMLearner\dataset3"
# Translated images are written here, mirroring the input sub-folder layout.
output_imgs_dir = r"D:\Deep_Learning\Endo-SfMLearner\dataset_bli"
# Generator state_dict that is loaded into ord2bli_generator below.
model_weight_path = r"C:\Users\Planck\PycharmProjects\endo_ord2bli_CycleGAN\model_weights\G_AB_199.pth"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): CycleGenerator(3, 3, 9) presumably means
# (in_channels, out_channels, n_residual_blocks) — TODO confirm against models.py.
ord2bli_generator = CycleGenerator(3, 3, 9).to(device)
ord2bli_generator.load_state_dict(torch.load(model_weight_path, map_location=device))
# This script only runs inference: switch any dropout / batch-norm layers the
# generator may contain to evaluation mode so translations are deterministic.
# For layers without train/eval-specific behaviour this is a no-op.
ord2bli_generator.eval()
# PIL image -> float tensor: ToTensor scales 8-bit pixels to [0, 1], then
# Normalize applies (x - 0.5) / 0.5 per channel, giving values in [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Recreate the first-level sub-folders of input_imgs_dir under output_imgs_dir
# so every translated image has an existing destination folder.
# glob('*') also matches plain files sitting directly in input_imgs_dir; those
# are skipped so they do not turn into empty output folders.
sub_dir_names = [osp.basename(path)
                 for path in glob(osp.join(input_imgs_dir, '*'))
                 if osp.isdir(path)]
for name in sub_dir_names:
    os.makedirs(osp.join(output_imgs_dir, name), exist_ok=True)

# Every file with an extension one level below input_imgs_dir is translated
# and saved under output_imgs_dir/<same sub-folder>/<same file name>.
imgs_path = glob(osp.join(input_imgs_dir, '*', '*.*'))

# Inference only: disabling autograd avoids building a computation graph for
# every image, lowering peak (GPU) memory without changing the outputs.
with torch.no_grad():
    for img_path in imgs_path:
        base_name = osp.basename(img_path)
        subdir_name = osp.basename(osp.dirname(img_path))
        # Force 3 channels: grayscale / RGBA / palette files would otherwise give
        # a tensor whose channel count mismatches the 3-channel Normalize in
        # `transform` and the generator's 3-channel input. The `with` closes the
        # file handle as soon as the pixels have been read.
        with Image.open(img_path) as pil_img:
            rgb_img = pil_img.convert('RGB')
        input_transform = torch.unsqueeze(transform(rgb_img), 0).to(device)
        fake_bli_img = ord2bli_generator(input_transform)[0]
        # Invert transform's Normalize((0.5,)*3, (0.5,)*3) with the fixed mapping
        # x * 0.5 + 0.5 instead of per-image min-max scaling, so brightness stays
        # consistent across output frames.
        # NOTE(review): assumes the generator emits images in the same [-1, 1]
        # space as its normalized input (the usual CycleGAN tanh output) —
        # TODO confirm against the training script.
        fake_bli_img = (fake_bli_img * 0.5 + 0.5).clamp(0, 1)
        out_path = osp.join(output_imgs_dir, subdir_name, base_name)
        save_image(fake_bli_img, out_path, nrow=1)