import argparse
import itertools
import os
import sys

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from dataset import CycleGANDataset
from lr_helpers import get_lambda_rule
from models import CycleGenerator, Discriminator
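# Let cuDNN benchmark convolution algorithms; this helps when input sizes are fixed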
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
load_flag = False  # set to True to resume training from the hardcoded checkpoints in train_loop
def train_loop(opts):
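    """Build the CycleGAN generators and discriminators and run the full
    adversarial training loop, saving samples and checkpoints periodically."""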
    # Deeper generators for higher-resolution images, following the CycleGAN paper
    # (6 residual blocks for 128x128 inputs, 9 for 256x256 and above)
    if opts.image_height >= 256:
        res_blocks = 9
    else:
        res_blocks = 6
# Create networks
G_AB = CycleGenerator(opts.a_channels, opts.b_channels, res_blocks).to(device)
G_BA = CycleGenerator(opts.b_channels, opts.a_channels, res_blocks).to(device)
D_A = Discriminator(opts.a_channels, opts.d_conv_dim).to(device)
D_B = Discriminator(opts.b_channels, opts.d_conv_dim).to(device)
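    # D_A and D_B are assumed to be PatchGAN-style discriminators that output a
    # map of patch predictions, hence the per-patch targets in the losses below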
if load_flag:
G_AB.load_state_dict(torch.load(r".\checkpoints_cyclegan\ord2bli_dset\G_AB_160.pth"))
G_BA.load_state_dict(torch.load(r".\checkpoints_cyclegan\ord2bli_dset\G_BA_160.pth"))
D_A.load_state_dict(torch.load(r".\checkpoints_cyclegan\ord2bli_dset\D_A_160.pth"))
D_B.load_state_dict(torch.load(r".\checkpoints_cyclegan\ord2bli_dset\D_B_160.pth"))
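    # Least-squares GAN (LSGAN) adversarial objective; L1 for cycle-consistency and identity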
criterion_gan = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()
    # Weights for the cycle-consistency and identity losses
lambda_cycle = 10
lambda_id = 0.5 * lambda_cycle
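    # (0.5 * lambda_cycle matches the identity-loss weighting used in the CycleGAN paper)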
# Create optimizers
g_optimizer = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
lr=opts.lr, betas=(opts.beta1, opts.beta2))
d_a_optimizer = torch.optim.Adam(D_A.parameters(), lr=opts.lr, betas=(opts.beta1, opts.beta2))
d_b_optimizer = torch.optim.Adam(D_B.parameters(), lr=opts.lr, betas=(opts.beta1, opts.beta2))
# Create learning rate update schedulers
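    # get_lambda_rule is assumed to implement the usual CycleGAN schedule:
    # hold the lr constant until opts.decay_epoch, then decay linearly to zero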
    lambda_rule = get_lambda_rule(opts)
    g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(g_optimizer, lr_lambda=lambda_rule)
    d_a_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(d_a_optimizer, lr_lambda=lambda_rule)
    d_b_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(d_b_optimizer, lr_lambda=lambda_rule)
    # Image transformations: enlarge slightly, random-crop back to target size, flip,
    # then normalize to [-1, 1] (matching the tanh output the generators are assumed to use)
    transform = transforms.Compose([transforms.Resize(int(opts.image_height * 1.12), Image.BICUBIC),
                                    transforms.RandomCrop((opts.image_height, opts.image_width)),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
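    # Dataloaders over the unpaired A/B images for the train and test splits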
    train_dataloader = DataLoader(CycleGANDataset(opts.dataroot_dir, opts.dataset_name, transform),
                                  batch_size=opts.batch_size, shuffle=True,
                                  num_workers=opts.num_workers, pin_memory=True)
    test_dataloader = DataLoader(CycleGANDataset(opts.dataroot_dir, opts.dataset_name, transform, mode="test"),
                                 batch_size=5, shuffle=False, num_workers=1, pin_memory=True)
end_epoch = opts.epochs + opts.start_epoch
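    # Each iteration trains D_A and D_B on detached fakes, then updates both generators jointly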
for epoch in range(opts.start_epoch, end_epoch):
for index, batch in enumerate(train_dataloader):
            # Move the real images to the device and generate fakes for the discriminators
            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)
            fake_A, fake_B = G_BA(real_B), G_AB(real_A)
            # Train discriminator A (fakes are detached so no gradients flow into the generators)
            d_a_optimizer.zero_grad()
            patch_real = D_A(real_A)
            loss_a_real = criterion_gan(patch_real, torch.ones_like(patch_real))
            patch_fake = D_A(fake_A.detach())
            loss_a_fake = criterion_gan(patch_fake, torch.zeros_like(patch_fake))
            loss_d_a = (loss_a_real + loss_a_fake) / 2
            loss_d_a.backward()
            d_a_optimizer.step()
            # Train discriminator B
            d_b_optimizer.zero_grad()
            patch_real = D_B(real_B)
            loss_b_real = criterion_gan(patch_real, torch.ones_like(patch_real))
            patch_fake = D_B(fake_B.detach())
            loss_b_fake = criterion_gan(patch_fake, torch.zeros_like(patch_fake))
            loss_d_b = (loss_b_real + loss_b_fake) / 2
            loss_d_b.backward()
            d_b_optimizer.step()
            # Train both generators jointly; fake_A/fake_B from above are reused,
            # since only detached copies were fed to the discriminators
            g_optimizer.zero_grad()
            reconstructed_A, reconstructed_B = G_BA(fake_B), G_AB(fake_A)
            # GAN loss: each generator tries to make its discriminator predict "real"
            patch_a = D_A(fake_A)
            loss_gan_ba = criterion_gan(patch_a, torch.ones_like(patch_a))
            patch_b = D_B(fake_B)
            loss_gan_ab = criterion_gan(patch_b, torch.ones_like(patch_b))
            loss_gan = (loss_gan_ab + loss_gan_ba) / 2
# Cycle loss
loss_cycle_a = criterion_cycle(reconstructed_A, real_A)
loss_cycle_b = criterion_cycle(reconstructed_B, real_B)
loss_cycle = (loss_cycle_a + loss_cycle_b) / 2
            # Identity loss: feeding a generator an image already in its output domain should change nothing
loss_id_a = criterion_identity(G_BA(real_A), real_A)
loss_id_b = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_a + loss_id_b) / 2
# Total loss
loss_g = loss_gan + lambda_cycle * loss_cycle + lambda_id * loss_identity
loss_g.backward()
g_optimizer.step()
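            # Global batch counter used for periodic sampling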
current_batch = epoch * len(train_dataloader) + index
            sys.stdout.write(f"\r[Epoch {epoch+1}/{end_epoch}] [Index {index}/{len(train_dataloader)}] "
                             f"[D_A loss: {loss_d_a.item():.4f}] [D_B loss: {loss_d_b.item():.4f}] "
                             f"[G loss: adv: {loss_gan.item():.4f}, cycle: {loss_cycle.item():.4f}, "
                             f"identity: {loss_identity.item():.4f}]")
            sys.stdout.flush()
if current_batch % opts.sample_every == 0:
save_sample(G_AB, G_BA, current_batch, opts, test_dataloader)
        # Update learning rates once per epoch
g_lr_scheduler.step()
d_a_lr_scheduler.step()
d_b_lr_scheduler.step()
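        # Periodically checkpoint all four networks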
if epoch % opts.checkpoint_every == 0:
torch.save(G_AB.state_dict(), f'{opts.checkpoint_dir}/{opts.dataset_name}/G_AB_{epoch}.pth')
torch.save(G_BA.state_dict(), f'{opts.checkpoint_dir}/{opts.dataset_name}/G_BA_{epoch}.pth')
torch.save(D_A.state_dict(), f'{opts.checkpoint_dir}/{opts.dataset_name}/D_A_{epoch}.pth')
torch.save(D_B.state_dict(), f'{opts.checkpoint_dir}/{opts.dataset_name}/D_B_{epoch}.pth')
def save_sample(G_AB, G_BA, batch, opts, test_dataloader):
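    """Translate a fixed test batch in both directions and save a grid of
    real, translated, and reconstructed images."""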
    images = next(iter(test_dataloader))
    with torch.no_grad():
        real_A = images['A'].to(device)
        real_B = images['B'].to(device)
        fake_A = G_BA(real_B)
        fake_B = G_AB(real_A)
        reconstructed_A = G_BA(fake_B)
        reconstructed_B = G_AB(fake_A)
    # Rows of the grid: real_A, fake_B, real_B, fake_A, reconstructed_A, reconstructed_B
    image_sample = torch.cat((real_A, fake_B,
                              real_B, fake_A,
                              reconstructed_A, reconstructed_B), 0)
    save_image(image_sample, f"{opts.sample_dir}/{opts.dataset_name}/{batch}.png", nrow=5, normalize=True)
def create_parser():
parser = argparse.ArgumentParser()
    # Model hyperparameters
    parser.add_argument('--image_height', type=int, default=300, help='Image height.')
    parser.add_argument('--image_width', type=int, default=400, help='Image width.')
    parser.add_argument('--a_channels', type=int, default=3, help='Number of channels in domain-A images.')
    parser.add_argument('--b_channels', type=int, default=3, help='Number of channels in domain-B images.')
    parser.add_argument('--d_conv_dim', type=int, default=64)
    # Training hyperparameters
    parser.add_argument('--dataset_name', type=str, default='ord2bli', help='Dataset to use.')
    parser.add_argument('--epochs', type=int, default=200, help='Number of epochs.')
    parser.add_argument('--start_epoch', type=int, default=0, help='Epoch to start training from.')
    parser.add_argument('--decay_epoch', type=int, default=100, help='Epoch at which lr decay begins.')
    parser.add_argument('--batch_size', type=int, default=1, help='Number of images per batch.')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of worker threads for the DataLoader.')
    parser.add_argument('--lr', type=float, default=0.0002, help='Learning rate (default 0.0002).')
    parser.add_argument('--beta1', type=float, default=0.5, help='Hyperparameter for the Adam optimizer.')
    parser.add_argument('--beta2', type=float, default=0.999, help='Hyperparameter for the Adam optimizer.')
    parser.add_argument('--n_cpu', type=int, default=4, help='Number of threads used when generating batches.')
    parser.add_argument('--gpu_id', type=int, default=0, help='ID of the GPU to use.')
    # Where and how often to save samples and checkpoints
    parser.add_argument('--dataroot_dir', type=str, default='../data/')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_cyclegan')
    parser.add_argument('--sample_dir', type=str, default='samples_cyclegan')
    parser.add_argument('--load', type=str, default=None)
    parser.add_argument('--log_step', type=int, default=20)
    parser.add_argument('--sample_every', type=int, default=1000, help='Sampling frequency, in batches.')
    parser.add_argument('--checkpoint_every', type=int, default=20, help='Checkpoint frequency, in epochs.')
return parser
if __name__ == '__main__':
parser = create_parser()
opts = parser.parse_args()
os.makedirs(f"{opts.sample_dir}/{opts.dataset_name}", exist_ok=True)
os.makedirs(f"{opts.checkpoint_dir}/{opts.dataset_name}", exist_ok=True)
train_loop(opts)