diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..51194ad --- /dev/null +++ b/.gitignore @@ -0,0 +1,134 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + + +.vscode/ +.mypy_cache/ +.idea/ diff --git a/EsoMovieConverter.py b/EsoMovieConverter.py new file mode 100644 index 0000000..ba6e59c --- /dev/null +++ b/EsoMovieConverter.py @@ -0,0 +1,23 @@ +import numpy as np +import cv2 + +# 食道の映像を前処理することが多いような気がするので作ったクラス +# Undistort用のparamを同ディレクトリに配置する必要あり +class EsoMovieConverter: + + def __init__(self): + self.dist_coeffs = np.load('./params/dist_coeffs.npy') + self.intrinsics_scaled = np.array([[221.8766, 0, 232.2143], [0, 217.4069, 173.3776], [0, 0, 1]]) + self.kernel = np.ones((9, 9), np.uint8) + + def __call__(self, eso_frame): + eso_frame = eso_frame[32:989, 323:1599, :] + # eliminate interlace + eso_frame = eso_frame[::2, ::2, :] + eso_frame = cv2.resize(eso_frame, (480, 352), interpolation=cv2.INTER_LINEAR) + eso_frame = cv2.undistort(eso_frame, self.intrinsics_scaled, self.dist_coeffs) + + # yuv_frame = cv2.cvtColor(eso_frame, cv2.COLOR_BGR2YUV) + # mask = cv2.dilate(np.where(200 < yuv_frame[:, :, 0], 255, 0).astype(yuv_frame.dtype), self.kernel) + # eso_frame = cv2.inpaint(eso_frame, mask, 5, cv2.INPAINT_TELEA) + return eso_frame diff --git a/README.md b/README.md index a5bc766..b268a1c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,17 @@ -ConvertOrd2NBICycleGAN -=============== +# ConvertOrd2NBICycleGAN -白色光下画像から狭帯域光画像に疑似的に変換する技術 \ No newline at end of file +学習済みモデルを使って白色光下画像を狭帯域光画像に変換する + +## 動作環境 + +python 3.8.12 +torch==1.10.0 +opencv-python==4.5.4.60 + +## 準備 + +- main.py 中の VideoCapture 内の引数を変換したい動画にする. 
+ +## 稼働方法 + +`python main.py` diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..1a8c88e --- /dev/null +++ b/dataset.py @@ -0,0 +1,51 @@ +import os +import sys +import glob +import random + +import torch +from torch.utils.data import Dataset +from PIL import Image +from torchvision import transforms + +available_datasets = ['horse2zebra', 'facades', "ord2bli_dset"] +base_url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/' + +class CycleGANDataset(Dataset): + """ + Class used to read images in + """ + def __init__(self, data_root=f"{__file__}/../data", dataset_name='facades', transform=None, unaligned=True, mode='train'): + + # Check whether the specified dataset is available + if dataset_name not in available_datasets: + print(f"{dataset_name}という名前のデータセットが存在しないようです") + sys.exit(1) + # Check whether the dataset is downloaded + base_dir = os.path.abspath(data_root) + dataset_dir = os.path.join(base_dir, dataset_name) + if not os.path.exists(dataset_dir): + if not os.path.exists(base_dir): + os.makedirs(base_dir, exist_ok=True) + print(f"{base_url}からデータセットをダウンロードし{base_dir}に解凍してください") + sys.exit(1) + + self.transform = transform + self.unaligned = unaligned + + self.files_A = sorted(glob.glob(os.path.join(dataset_dir, f"{mode}/A/*.*"))) + self.files_B = sorted(glob.glob(os.path.join(dataset_dir, f"{mode}/B/*.*"))) + + def __getitem__(self, index): + + item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)])) + + if self.unaligned: + item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B)-1)])) + else: + item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)])) + + return {'A': item_A, 'B': item_B} + + def __len__(self): + return max(len(self.files_A), len(self.files_B)) diff --git a/lr_helpers.py b/lr_helpers.py new file mode 100644 index 0000000..6dceba3 --- /dev/null +++ b/lr_helpers.py @@ -0,0 +1,4 @@ +def get_lambda_rule(opts): + def lambda_rule(epoch): + return 
1.0 - max(0, epoch + opts.start_epoch - opts.decay_epoch) / float(opts.epochs - opts.decay_epoch) + return lambda_rule diff --git a/main.py b/main.py new file mode 100644 index 0000000..fe2067f --- /dev/null +++ b/main.py @@ -0,0 +1,50 @@ +import cv2 +import numpy as np +import PIL.Image as image +import torch +import torchvision.transforms as transforms +import matplotlib.pyplot as plt + +from EsoMovieConverter import EsoMovieConverter +from models import CycleGenerator + +cap = cv2.VideoCapture(r'D:\Deep_Learning\MonoDepth2\esophagus\movies\trimed\0.mp4') +eso_movie_converter = EsoMovieConverter() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +fps = int(cap.get(cv2.CAP_PROP_FPS)) +w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) +h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) +fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') +video_writer = cv2.VideoWriter('bli.mp4', fourcc, fps, (480, 352)) + +ord2bli_generator = CycleGenerator(3, 3, 9).to(device) +ord2bli_generator.load_state_dict(torch.load(r'./model_weights/G_AB_199.pth', map_location=device)) + +transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + +while True: + ret, frame = cap.read() + if not ret: + break + + eso_frame_array = eso_movie_converter(frame) + ord_eso_array = eso_frame_array.copy() + pil_img = image.fromarray(cv2.cvtColor(eso_frame_array, cv2.COLOR_BGR2RGB)) + input_transform = torch.unsqueeze(transform(pil_img), 0).to(device) + out = ord2bli_generator(input_transform)[0] + ndarr = out.add_(1.0).mul(128).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + ndarr = cv2.cvtColor(ndarr, cv2.COLOR_RGB2BGR) + cv2.imshow("a", ndarr) + cv2.waitKey(1) + video_writer.write(ndarr) + + # bli_frame_array = bli_frame_tensor.cpu()[0].detach().numpy().transpose((1, 2, 0)) + # bli_frame_array = np.clip((255 * bli_frame_array).astype(np.uint8), 0, 255) + # cv2.imshow("bli", bli_frame_array) + # cv2.waitKey(1) + + 
+cap.release() +video_writer.release() diff --git a/make_bli_dataset/make_dataset.py b/make_bli_dataset/make_dataset.py new file mode 100644 index 0000000..28f24dd --- /dev/null +++ b/make_bli_dataset/make_dataset.py @@ -0,0 +1,37 @@ +import torch +import torchvision.transforms as transforms +from torchvision.utils import save_image +import os +import os.path as osp +import PIL.Image as Image +from glob import glob + + +from EsoMovieConverter import EsoMovieConverter +from models import CycleGenerator + +input_imgs_dir = r"D:\Deep_Learning\MonoDepth2\esophagus\imgs" +output_imgs_dir = r"D:\Deep_Learning\MonoDepth2\esophagus_bli\imgs" +model_weight_path = r"C:\Users\Planck\PycharmProjects\endo_ord2bli_CycleGAN\model_weights\G_AB_199.pth" + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +ord2bli_generator = CycleGenerator(3, 3, 9).to(device) +ord2bli_generator.load_state_dict(torch.load(model_weight_path, map_location=device)) +transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + +sub_dir_names = glob(osp.join(input_imgs_dir, '*')) +sub_dir_names = [osp.basename(path) for path in sub_dir_names] +for name in sub_dir_names: + os.makedirs(osp.join(output_imgs_dir, name), exist_ok=True) + +imgs_path = glob(osp.join(input_imgs_dir, '*', '*.*')) + +for img_path in imgs_path: + base_name = osp.basename(img_path) + subdir_name = osp.basename(osp.dirname(img_path)) + pil_img = Image.open(img_path) + input_transform = torch.unsqueeze(transform(pil_img), 0).to(device) + fake_bli_img = ord2bli_generator(input_transform)[0] + out_path = osp.join(output_imgs_dir, subdir_name, base_name) + save_image(fake_bli_img, out_path, nrow=1, normalize=True) diff --git a/make_bli_dataset/make_dataset2.py b/make_bli_dataset/make_dataset2.py new file mode 100644 index 0000000..389efe2 --- /dev/null +++ b/make_bli_dataset/make_dataset2.py @@ -0,0 +1,37 @@ +import torch +import torchvision.transforms as 
transforms +from torchvision.utils import save_image +import os +import os.path as osp +import PIL.Image as Image +from glob import glob + + +from EsoMovieConverter import EsoMovieConverter +from models import CycleGenerator + +input_imgs_dir = r"D:\Deep_Learning\Endo-SfMLearner\dataset3" +output_imgs_dir = r"D:\Deep_Learning\Endo-SfMLearner\dataset_bli" +model_weight_path = r"C:\Users\Planck\PycharmProjects\endo_ord2bli_CycleGAN\model_weights\G_AB_199.pth" + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +ord2bli_generator = CycleGenerator(3, 3, 9).to(device) +ord2bli_generator.load_state_dict(torch.load(model_weight_path, map_location=device)) +transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + +sub_dir_names = glob(osp.join(input_imgs_dir, '*')) +sub_dir_names = [osp.basename(path) for path in sub_dir_names] +for name in sub_dir_names: + os.makedirs(osp.join(output_imgs_dir, name), exist_ok=True) + +imgs_path = glob(osp.join(input_imgs_dir, '*', '*.*')) + +for img_path in imgs_path: + base_name = osp.basename(img_path) + subdir_name = osp.basename(osp.dirname(img_path)) + pil_img = Image.open(img_path) + input_transform = torch.unsqueeze(transform(pil_img), 0).to(device) + fake_bli_img = ord2bli_generator(input_transform)[0] + out_path = osp.join(output_imgs_dir, subdir_name, base_name) + save_image(fake_bli_img, out_path, nrow=1, normalize=True) diff --git a/model_weights/G_AB_199.pth b/model_weights/G_AB_199.pth new file mode 100644 index 0000000..659704a --- /dev/null +++ b/model_weights/G_AB_199.pth Binary files differ diff --git a/models.py b/models.py new file mode 100644 index 0000000..649a2ee --- /dev/null +++ b/models.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def conv(in_channels, out_channels, kernel_size=4, stride=2, padding=1, instance_norm=True, relu=True, relu_slope=None, init_zero_weights=False): + """ + 
畳み込み層を積み上げる。識別ネットワークや生成ネットワークの構成で使う + """ + layers = [] + conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=True) + if init_zero_weights: + conv_layer.weight.data = torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001 + else: + nn.init.normal_(conv_layer.weight.data, 0.0, 0.02) + layers.append(conv_layer) + + if instance_norm: + layers.append(nn.InstanceNorm2d(out_channels)) + + if relu: + if relu_slope: + relu_layer = nn.LeakyReLU(relu_slope, True) + else: + relu_layer = nn.ReLU(inplace=True) + layers.append(relu_layer) + return layers + +def deconv(in_channels, out_channels, kernel_size=4, stride=2, padding=1, output_padding=1, instance_norm=True, relu=True, relu_slope=None, init_zero_weights=False): + + layers = [] + deconv_layer = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=True) + if init_zero_weights: + deconv_layer.weight.data = torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.001 + else: + nn.init.normal_(deconv_layer.weight.data, 0.0, 0.02) + layers.append(deconv_layer) + + if instance_norm: + layers.append(nn.InstanceNorm2d(out_channels)) + + if relu: + if relu_slope: + relu_layer = nn.LeakyReLU(relu_slope, True) + else: + relu_layer = nn.ReLU(inplace=True) + layers.append(relu_layer) + return layers + +class ResidualBlock(nn.Module): + def __init__(self, input_features): + super(ResidualBlock, self).__init__() + + conv_layers = [ + nn.ReflectionPad2d(1), + *conv(input_features, input_features, kernel_size=3, stride=1, padding=0), + nn.ReflectionPad2d(1), + *conv(input_features, input_features, kernel_size=3, stride=1, padding=0, relu=False) + ] + self.model = nn.Sequential(*conv_layers) + + def forward(self, input_data): + return input_data + self.model(input_data) + +class CycleGenerator(nn.Module): + + def 
__init__(self, in_channels=3, out_channels=3, res_blocks=9): + super(CycleGenerator, self).__init__() + + # First 7 x 7 convolutional layer + layers = [ + nn.ReflectionPad2d(3), + *conv(in_channels, 64, 7, stride=1, padding=0) + ] + + # Two 3 x 3 convolutional layers + input_features = 64 + output_features = input_features * 2 + for _ in range(2): + layers += conv(input_features, output_features, 3) + input_features, output_features = output_features, output_features * 2 + + # Residual blocks + for _ in range(res_blocks): + layers += [ResidualBlock(input_features)] + + # Two 3 x 3 deconvolutional layers + output_features = input_features // 2 + for _ in range(2): + layers += deconv(input_features, output_features, 3) + input_features, output_features = output_features, output_features // 2 + + # Output layer + layers += [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_features, out_channels, 7), + nn.Tanh() + ] + self.model = nn.Sequential(*layers) + + def forward(self, real_image): + return self.model(real_image) + +class Discriminator(nn.Module): + + def __init__(self, in_channels=3, conv_dim=64): + super(Discriminator, self).__init__() + + C64 = conv(in_channels, conv_dim, instance_norm=False, relu_slope=0.2) + C128 = conv(conv_dim, conv_dim * 2, relu_slope=0.2) + C256 = conv(conv_dim * 2, conv_dim * 4, relu_slope=0.2) + C512 = conv(conv_dim * 4, conv_dim * 8, stride = 1, relu_slope=0.2) + C1 = conv(conv_dim * 8, 1, stride=1, instance_norm=False, relu=False) + + self.model = nn.Sequential( + *C64, + *C128, + *C256, + *C512, + *C1 + ) + + def forward(self, image): + return self.model(image) diff --git a/params/dist_coeffs.npy b/params/dist_coeffs.npy new file mode 100644 index 0000000..e76aaa1 --- /dev/null +++ b/params/dist_coeffs.npy Binary files differ diff --git a/params/intrinsics_scaled.npy b/params/intrinsics_scaled.npy new file mode 100644 index 0000000..3c4e6a2 --- /dev/null +++ b/params/intrinsics_scaled.npy Binary files differ diff --git 
a/params/intrinsics_scaled2.npy b/params/intrinsics_scaled2.npy new file mode 100644 index 0000000..4482ebe --- /dev/null +++ b/params/intrinsics_scaled2.npy Binary files differ diff --git a/predict_single_img.py b/predict_single_img.py new file mode 100644 index 0000000..3d67c8d --- /dev/null +++ b/predict_single_img.py @@ -0,0 +1,40 @@ +import cv2 +import numpy as np +import PIL.Image as image +import torch +import torchvision.transforms as transforms +import matplotlib.pyplot as plt +import os.path as osp + +from EsoMovieConverter import EsoMovieConverter +from models import CycleGenerator + +path = "single_img2/5_.png" +weight_path = r"./model_weights/G_AB_199.pth" + +eso_movie_converter = EsoMovieConverter() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +ord2bli_generator = CycleGenerator(3, 3, 9).to(device) +ord2bli_generator.load_state_dict(torch.load(weight_path, map_location=device)) + +transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + +eso_frame_array = eso_movie_converter(cv2.imread(path)) +ord_eso_array = eso_frame_array.copy() +pil_img = image.fromarray(cv2.cvtColor(eso_frame_array, cv2.COLOR_BGR2RGB)) +input_transform = torch.unsqueeze(transform(pil_img), 0).to(device) +out = ord2bli_generator(input_transform)[0] +ndarr = out.add_(1.0).mul(128).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() +ndarr = cv2.cvtColor(ndarr, cv2.COLOR_RGB2BGR) +cv2.imshow("a", ndarr) +cv2.waitKey(10) +cv2.imwrite(f"single_img2/{osp.basename(path).split('.')[0]}r.png", ndarr) + +''' +bli_frame_array = bli_frame_tensor.cpu()[0].detach().numpy().transpose((1, 2, 0)) +bli_frame_array = np.clip((255 * bli_frame_array).astype(np.uint8), 0, 255) +cv2.imshow("bli", bli_frame_array) +cv2.waitKey(1) +''' \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9d84e25 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,29 @@ 
+[tool.isort] +profile = "black" +line_length = 120 +skip_glob = "*/migrations/*.py" + +[tool.black] +line-length = 120 +include = '\.pyi?$' +extend-exclude = ''' +/( + | \.git + | templates + | migrations +)/ +''' + +[tool.flake8] +max-line-length = 120 +extend-ignore = "E203,W503" + + +[tool.mypy] +follow-imports = "normal" +ignore_missing_imports = true +show_column_numbers = true +pretty = false +disallow_untyped_calls = true +disallow_untyped_defs = true + diff --git a/train.py b/train.py new file mode 100644 index 0000000..4dcee66 --- /dev/null +++ b/train.py @@ -0,0 +1,209 @@ +import argparse +import os +import sys +import itertools +import math +import datetime +import time + +import torchvision.transforms as transforms +from torchvision.utils import save_image +from torchvision import datasets + +from torch.utils.data import DataLoader +import torch +import torch.nn as nn +from PIL import Image +from torch.autograd import Variable + +from models import CycleGenerator, Discriminator +from lr_helpers import get_lambda_rule +from dataset import CycleGANDataset + +torch.backends.cudnn.benchmark = True +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +load_flag = False + + +def train_loop(opts): + if opts.image_height == 128: + res_blocks = 6 + elif opts.image_height >= 256: + res_blocks = 9 + + # Create networks + G_AB = CycleGenerator(opts.a_channels, opts.b_channels, res_blocks).to(device) + G_BA = CycleGenerator(opts.b_channels, opts.a_channels, res_blocks).to(device) + D_A = Discriminator(opts.a_channels, opts.d_conv_dim).to(device) + D_B = Discriminator(opts.b_channels, opts.d_conv_dim).to(device) + + if load_flag: + G_AB.load_state_dict(torch.load(r".\checkpoints_cyclegan\ord2bli_dset\G_AB_160.pth")) + G_BA.load_state_dict(torch.load(r".\checkpoints_cyclegan\ord2bli_dset\G_BA_160.pth")) + D_A.load_state_dict(torch.load(r".\checkpoints_cyclegan\ord2bli_dset\D_A_160.pth")) + 
D_B.load_state_dict(torch.load(r".\checkpoints_cyclegan\ord2bli_dset\D_B_160.pth")) + + + criterion_gan = nn.MSELoss() + criterion_cycle = nn.L1Loss() + criterion_identity = nn.L1Loss() + + # Weights cycle loss and identity loss + lambda_cycle = 10 + lambda_id = 0.5 * lambda_cycle + + # Create optimizers + g_optimizer = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), + lr=opts.lr, betas=(opts.beta1, opts.beta2)) + d_a_optimizer = torch.optim.Adam(D_A.parameters(), lr=opts.lr, betas=(opts.beta1, opts.beta2)) + d_b_optimizer = torch.optim.Adam(D_B.parameters(), lr=opts.lr, betas=(opts.beta1, opts.beta2)) + + # Create learning rate update schedulers + LambdaLR = get_lambda_rule(opts) + g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(g_optimizer, lr_lambda=LambdaLR) + d_a_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(d_a_optimizer, lr_lambda=LambdaLR) + d_b_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(d_b_optimizer, lr_lambda=LambdaLR) + + # Image transformations + transform = transforms.Compose([transforms.Resize(int(opts.image_height*1.12), Image.BICUBIC), + transforms.RandomCrop((opts.image_height, opts.image_width)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]) + + train_dataloader = DataLoader(CycleGANDataset(opts.dataroot_dir, opts.dataset_name, transform), + batch_size=opts.batch_size, shuffle=True, num_workers=2, pin_memory=True) + test_dataloader = DataLoader(CycleGANDataset(opts.dataroot_dir, opts.dataset_name, transform, mode="train"), + batch_size=5, shuffle=False, num_workers=1, pin_memory=True) + + end_epoch = opts.epochs + opts.start_epoch + for epoch in range(opts.start_epoch, end_epoch): + for index, batch in enumerate(train_dataloader): + # Create adversarial target + real_A = Variable(batch['A'].to(device)) + real_B = Variable(batch['B'].to(device)) + fake_A, fake_B = G_BA(real_B), G_AB(real_A) + + # Train discriminator A + 
d_a_optimizer.zero_grad() + + patch_real = D_A(real_A) + loss_a_real = criterion_gan(patch_real, torch.tensor(1.0).expand_as(patch_real).to(device)) + patch_fake = D_A(fake_A) + loss_a_fake = criterion_gan(patch_fake, torch.tensor(0.0).expand_as(patch_fake).to(device)) + loss_d_a = (loss_a_real + loss_a_fake) / 2 + loss_d_a.backward() + d_a_optimizer.step() + + # Train discriminator B + d_b_optimizer.zero_grad() + + patch_real = D_B(real_B) + loss_b_real = criterion_gan(patch_real, torch.tensor(1.0).expand_as(patch_real).to(device)) + patch_fake = D_B(fake_B) + loss_b_fake = criterion_gan(patch_fake, torch.tensor(0.0).expand_as(patch_fake).to(device)) + loss_d_b = (loss_b_real + loss_b_fake) / 2 + loss_d_b.backward() + d_b_optimizer.step() + + # Train generator + + g_optimizer.zero_grad() + fake_A, fake_B = G_BA(real_B), G_AB(real_A) + reconstructed_A, reconstructed_B = G_BA(fake_B), G_AB(fake_A) + # GAN loss + patch_a = D_A(fake_A) + loss_gan_ba = criterion_gan(patch_a, torch.tensor(1.0).expand_as(patch_a).to(device)) + patch_b = D_B(fake_B) + loss_gan_ab = criterion_gan(patch_b, torch.tensor(1.0).expand_as(patch_b).to(device)) + loss_gan = (loss_gan_ab + loss_gan_ba) / 2 + + # Cycle loss + loss_cycle_a = criterion_cycle(reconstructed_A, real_A) + loss_cycle_b = criterion_cycle(reconstructed_B, real_B) + loss_cycle = (loss_cycle_a + loss_cycle_b) / 2 + + # Identity loss + loss_id_a = criterion_identity(G_BA(real_A), real_A) + loss_id_b = criterion_identity(G_AB(real_B), real_B) + loss_identity = (loss_id_a + loss_id_b) / 2 + + # Total loss + loss_g = loss_gan + lambda_cycle * loss_cycle + lambda_id * loss_identity + loss_g.backward() + g_optimizer.step() + + current_batch = epoch * len(train_dataloader) + index + sys.stdout.write(f"\r[Epoch {epoch+1}/{opts.epochs-opts.start_epoch}] [Index {index}/{len(train_dataloader)}] [D_A loss: {loss_d_a.item():.4f}] [D_B loss: {loss_d_b.item():.4f}] [G loss: adv: {loss_gan.item():.4f}, cycle: {loss_cycle.item():.4f}, 
identity: {loss_identity.item():.4f}]") + + if current_batch % opts.sample_every == 0: + save_sample(G_AB, G_BA, current_batch, opts, test_dataloader) + + # Update learning rate + g_lr_scheduler.step() + d_a_lr_scheduler.step() + d_b_lr_scheduler.step() + if epoch % opts.checkpoint_every == 0: + torch.save(G_AB.state_dict(), f'{opts.checkpoint_dir}/{opts.dataset_name}/G_AB_{epoch}.pth') + torch.save(G_BA.state_dict(), f'{opts.checkpoint_dir}/{opts.dataset_name}/G_BA_{epoch}.pth') + torch.save(D_A.state_dict(), f'{opts.checkpoint_dir}/{opts.dataset_name}/D_A_{epoch}.pth') + torch.save(D_B.state_dict(), f'{opts.checkpoint_dir}/{opts.dataset_name}/D_B_{epoch}.pth') + + +def save_sample(G_AB, G_BA, batch, opts, test_dataloader): + images = next(iter(test_dataloader)) + real_A = Variable(images['A'].to(device)) + real_B = Variable(images['B'].to(device)) + fake_A = G_BA(real_B) + fake_B = G_AB(real_A) + reconstructed_A = G_BA(fake_B) + reconstructed_B = G_AB(fake_A) + image_sample = torch.cat((real_A.data, fake_B.data, + real_B.data, fake_A.data, + reconstructed_A.data, reconstructed_B.data), 0) + save_image(image_sample, f"{opts.sample_dir}/{opts.dataset_name}/{batch}.png", nrow=5, normalize=True) + + +def create_parser(): + + parser = argparse.ArgumentParser() + + # モデル用ハイパーパラメータ + parser.add_argument('--image_height', type=int, default=300, help='画像の高さ.') + parser.add_argument('--image_width', type=int, default=400, help='画像の広さ.') + parser.add_argument('--a_channels', type=int, default=3, help='A類画像のChannels数.') + parser.add_argument('--b_channels', type=int, default=3, help='B類画像のChannels数.') + parser.add_argument('--d_conv_dim', type=int, default=64) + + # トレーニング用ハイパーパラメータ + parser.add_argument('--dataset_name', type=str, default='ord2bli_dset', help='使用するデータセット.') + parser.add_argument('--epochs', type=int, default=200, help='Epochの数.') + parser.add_argument('--start_epoch', type=int, default=0, help='実行開始のEpoch数.') + parser.add_argument('--decay_epoch', type=int, 
default=100, help='lr decayを実行し始めるEpoch数.') + parser.add_argument('--batch_size', type=int, default=1, help='一つのBatchに含まれる画像の数.') + parser.add_argument('--num_workers', type=int, default=0, help='Dataloaderに使われるスレッド数.') + parser.add_argument('--lr', type=float, default=0.0002, help='学習率(defaultは0.0002).') + parser.add_argument('--beta1', type=float, default=0.5, help='Adamオプチマイザーに使われるハイパーパラメータ.') + parser.add_argument('--beta2', type=float, default=0.999, help='Adamオプチマイザーに使われるハイパーパラメータ.') + parser.add_argument('--n_cpu', type=int, default=4, help='batchを生成するときに使用するスレッド数.') + parser.add_argument('--gpu_id', type=int, default=0, help='使用するGPUのID.') + + # サンプルやチェックポイントをとる頻度と場所 + parser.add_argument('--dataroot_dir', type=str, default='../data/') + parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_cyclegan') + parser.add_argument('--sample_dir', type=str, default='samples_cyclegan') + parser.add_argument('--load', type=str, default=None) + parser.add_argument('--log_step', type=int , default=20) + parser.add_argument('--sample_every', type=int , default=1000, help='サンプルをとる頻度、batch単位.') + parser.add_argument('--checkpoint_every', type=int , default=20, help='Check pointをとる頻度、epoch単位.') + return parser + + +if __name__ == '__main__': + + parser = create_parser() + opts = parser.parse_args() + os.makedirs(f"{opts.sample_dir}/{opts.dataset_name}", exist_ok=True) + os.makedirs(f"{opts.checkpoint_dir}/{opts.dataset_name}", exist_ok=True) + + train_loop(opts)