#pix2pixの学習を行なうプログラム
#!/usr/bin/env python
# Example: python train_facade.py -g 0 -s <input image dir> -d <ground-truth image dir> --out result_facade --snapshot_interval 10000
from __future__ import print_function
import argparse
import os
import chainer
from chainer import training
from chainer.training import extensions
from chainer import serializers
from net import Discriminator
from net import Encoder
from net import Decoder
from updater import PicUpdater
from img_dataset import ImgDataset
from pic_visualizer import out_image
# Default locations used by the training script.
# DATASET_SRC: directory holding the real (photographed) tongue images.
# DATASET_DST: directory holding the binarized tongue images (ground truth).
DATASET_SRC, DATASET_DST = "D:/test13/Sample16", "D:/test13/Label16"
# SAVE_DIR: directory where learned parameters / snapshots are written.
SAVE_DIR = './pix2pix_param2'
def _build_arg_parser():
    """Return the argparse parser holding every training option.

    Each option has a default, so options can be changed either on the
    command line (e.g. ``--batchsize 8``) or by editing the defaults below.
    """
    parser = argparse.ArgumentParser(
        description='chainer implementation of pix2pix')
    # (flag names, add_argument keyword arguments), in --help order.
    option_table = (
        # Mini-batch size.
        (('--batchsize', '-b'),
         dict(type=int, default=4,
              help='Number of images in each mini-batch')),
        # Number of training epochs.
        (('--epoch', '-e'),
         dict(type=int, default=4000,
              help='Number of sweeps over the dataset to train')),
        # GPU number; a negative value runs on the CPU (far too slow, not
        # recommended).
        (('--gpu', '-g'),
         dict(type=int, default=0,
              help='GPU ID (negative value indicates CPU)')),
        # Directory of the real tongue images (defaults to DATASET_SRC).
        (('--data_src', '-s'),
         dict(default=DATASET_SRC,
              help='Directory of image files.')),
        # Directory of the binarized tongue images (defaults to DATASET_DST).
        (('--data_dst', '-d'),
         dict(default=DATASET_DST,
              help='Directory of ground truth image files.')),
        # Trainer output directory (defaults to SAVE_DIR); also handed to
        # out_image.
        (('--out', '-o'),
         dict(default=SAVE_DIR,
              help='Directory to output the result')),
        # Path of a trainer snapshot to resume from; empty means start fresh.
        (('--resume', '-r'),
         dict(default='',
              help='Resume the training from snapshot')),
        # Only forwarded to out_image; the training RNG is not seeded here.
        (('--seed',),
         dict(type=int, default=0,
              help='Random seed')),
        # Iterations between snapshots / parameter dumps / out_image calls.
        (('--snapshot_interval',),
         dict(type=int, default=5000,
              help='Interval of snapshot')),
        # Iterations between LogReport / PrintReport outputs.
        (('--display_interval',),
         dict(type=int, default=5000,
              help='Interval of displaying log to console')),
    )
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    return parser


def _make_optimizer(model, alpha=0.0002, beta1=0.5):
    """Return an Adam optimizer set up on *model* with weight decay 1e-5.

    The weight-decay hook is registered under the name 'hook_dec'.
    """
    adam = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
    adam.setup(model)
    adam.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
    return adam


def main():
    """Parse options, build the pix2pix networks and run the Chainer trainer.

    Trains Encoder/Decoder/Discriminator on the image pairs found in
    ``--data_src`` / ``--data_dst``, writing periodic snapshots, per-model
    parameter files and out_image output under ``--out``.
    """
    # Input/output image size (width, height), forwarded to out_image.
    w_img, h_img = 256, 256

    args = _build_arg_parser().parse_args()
    for line in ('GPU: {}'.format(args.gpu),
                 '# Minibatch-size: {}'.format(args.batchsize),
                 '# epoch: {}'.format(args.epoch),
                 ''):
        print(line)

    # Networks are created in the order enc, dec, dis (kept so that weight
    # initialisation draws from the RNG in the same sequence).
    enc = Encoder(in_ch=3)
    dec = Decoder(out_ch=3)
    dis = Discriminator(in_ch=3, out_ch=3)
    models = (enc, dec, dis)

    if args.gpu >= 0:
        # Make the requested GPU current, then copy every model onto it.
        chainer.cuda.get_device(args.gpu).use()
        for model in models:
            model.to_gpu()

    opt_enc, opt_dec, opt_dis = (_make_optimizer(m) for m in models)

    # data_range presumably selects a fraction of the image list:
    # 0-0.9 for training and 0.9-1 for evaluation (TODO confirm in
    # ImgDataset).
    train_d = ImgDataset(args.data_src, args.data_dst, data_range=(0,0.9))
    test_d = ImgDataset(args.data_src, args.data_dst, data_range=(0.9,1))
    # chainer.iterators.MultiprocessIterator(..., n_processes=4) is a
    # drop-in alternative to SerialIterator here.
    train_iter = chainer.iterators.SerialIterator(train_d, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test_d, args.batchsize)

    updater = PicUpdater(
        models=models,
        iterator={'main': train_iter, 'test': test_iter},
        optimizer={'enc': opt_enc, 'dec': opt_dec, 'dis': opt_dis},
        device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    display_interval = (args.display_interval, 'iteration')
    final_epoch = (args.epoch, 'epoch')

    # NOTE: the registration order below is significant - Chainer names
    # repeated extensions by registration order, and those names are used
    # when a snapshot is resumed.
    trainer.extend(
        extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
        trigger=snapshot_interval)
    # Per-iteration dumps of each model's parameters.
    for model, filename in ((enc, 'enc_iter_{.updater.iteration}.npz'),
                            (dec, 'dec_iter_{.updater.iteration}.npz'),
                            (dis, 'dis_iter_{.updater.iteration}.npz')):
        trainer.extend(extensions.snapshot_object(model, filename),
                       trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(
        extensions.PrintReport(
            ['epoch', 'iteration', 'enc/loss', 'dec/loss', 'dis/loss']),
        trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(
        out_image(updater, enc, dec, 1, 1, args.seed, args.out, args.gpu,
                  w_img, h_img),
        trigger=snapshot_interval)
    # Final learned parameters, written once the last epoch is reached.
    for model, filename in ((enc, 'enc_epoch_{.updater.epoch}.npz'),
                            (dec, 'dec_epoch_{.updater.epoch}.npz'),
                            (dis, 'dis_epoch_{.updater.epoch}.npz')):
        trainer.extend(extensions.snapshot_object(model, filename),
                       trigger=final_epoch)

    if args.resume:
        # Restore the full trainer state from a previous snapshot.
        serializers.load_npz(args.resume, trainer)

    trainer.run()


if __name__ == '__main__':
    main()