#!/usr/bin/env python
# Program that trains a pix2pix model.

# python train_facade.py -g 0 -i ./facade/base --out result_facade --snapshot_interval 10000

from __future__ import print_function
import argparse
import os

import chainer
from chainer import training
from chainer.training import extensions
from chainer import serializers

from net import Discriminator
from net import Encoder
from net import Decoder
from updater import PicUpdater

from img_dataset import ImgDataset
from pic_visualizer import out_image

# dataset paths
# Locations of the images used for training.
# Directory containing the real tongue images (network input).
DATASET_SRC = "D:/test13/Sample16"

# Directory containing the binarized tongue images (ground truth).
DATASET_DST = "D:/test13/Label16"

# Directory where trained parameters are saved.
SAVE_DIR = './pix2pix_param2'
def _parse_args():
    """Parse command-line options for training.

    Defaults may also be changed by editing the module-level constants
    (DATASET_SRC, DATASET_DST, SAVE_DIR) above, or overridden on the
    command line, e.g. ``--batchsize 8``.
    """
    parser = argparse.ArgumentParser(description='chainer implementation of pix2pix')
    # 1: mini-batch size
    parser.add_argument('--batchsize', '-b', type=int, default=4,
                        help='Number of images in each mini-batch')
    # 2: number of epochs
    parser.add_argument('--epoch', '-e', type=int, default=4000,
                        help='Number of sweeps over the dataset to train')
    # 3: GPU id (a negative value selects the CPU; not recommended — far too slow)
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    # 4: directory with the real tongue images (see DATASET_SRC)
    parser.add_argument('--data_src', '-s', default=DATASET_SRC,
                        help='Directory of image files.')
    # 5: directory with the binarized tongue images (see DATASET_DST)
    parser.add_argument('--data_dst', '-d', default=DATASET_DST,
                        help='Directory of ground truth image files.')
    # 6: directory where trained parameters are written
    parser.add_argument('--out', '-o', default=SAVE_DIR,
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed')
    # 7: how often (in iterations) to save parameters
    parser.add_argument('--snapshot_interval', type=int, default=5000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval', type=int, default=5000,
                        help='Interval of displaying log to console')
    return parser.parse_args()


def _setup_models(gpu):
    """Build the encoder/decoder generator and the discriminator.

    Moves all three networks to the given GPU when ``gpu >= 0``.
    Returns the tuple ``(enc, dec, dis)``.
    """
    enc = Encoder(in_ch=3)
    dec = Decoder(out_ch=3)
    dis = Discriminator(in_ch=3, out_ch=3)
    if gpu >= 0:
        chainer.cuda.get_device(gpu).use()  # Make the specified GPU current
        enc.to_gpu()  # Copy the model to the GPU
        dec.to_gpu()
        dis.to_gpu()
    return enc, dec, dis


def _make_optimizer(model, alpha=0.0002, beta1=0.5):
    """Create an Adam optimizer with a small weight decay for *model*."""
    optimizer = chainer.optimizers.Adam(alpha=alpha, beta1=beta1)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
    return optimizer


def main():
    """Entry point: configure and run pix2pix training."""
    # Input/output image size in pixels.
    w_img = 256
    h_img = 256

    args = _parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Networks and their optimizers.
    enc, dec, dis = _setup_models(args.gpu)
    opt_enc = _make_optimizer(enc)
    opt_dec = _make_optimizer(dec)
    opt_dis = _make_optimizer(dis)

    # 90%/10% train/validation split over the same image directories.
    train_d = ImgDataset(args.data_src, args.data_dst, data_range=(0, 0.9))
    test_d = ImgDataset(args.data_src, args.data_dst, data_range=(0.9, 1))
    train_iter = chainer.iterators.SerialIterator(train_d, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test_d, args.batchsize)

    # Set up the trainer.
    updater = PicUpdater(
        models=(enc, dec, dis),
        iterator={
            'main': train_iter,
            'test': test_iter},
        optimizer={
            'enc': opt_enc, 'dec': opt_dec,
            'dis': opt_dis},
        device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    snapshot_interval = (args.snapshot_interval, 'iteration')
    display_interval = (args.display_interval, 'iteration')
    trainer.extend(extensions.snapshot(
        filename='snapshot_iter_{.updater.iteration}.npz'),
                   trigger=snapshot_interval)
    # Save each network's parameters every snapshot interval.
    trainer.extend(extensions.snapshot_object(
        enc, 'enc_iter_{.updater.iteration}.npz'), trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dec, 'dec_iter_{.updater.iteration}.npz'), trigger=snapshot_interval)
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_iter_{.updater.iteration}.npz'), trigger=snapshot_interval)
    trainer.extend(extensions.LogReport(trigger=display_interval))
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'enc/loss', 'dec/loss', 'dis/loss',
    ]), trigger=display_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    # Write example generated images alongside the snapshots.
    trainer.extend(
        out_image(
            updater, enc, dec,
            1, 1, args.seed, args.out, args.gpu,
            w_img, h_img),
        trigger=snapshot_interval)
    # Save the final parameters at the end of training.
    trainer.extend(extensions.snapshot_object(
        enc, 'enc_epoch_{.updater.epoch}.npz'), trigger=(args.epoch, 'epoch'))
    trainer.extend(extensions.snapshot_object(
        dec, 'dec_epoch_{.updater.epoch}.npz'), trigger=(args.epoch, 'epoch'))
    trainer.extend(extensions.snapshot_object(
        dis, 'dis_epoch_{.updater.epoch}.npz'), trigger=(args.epoch, 'epoch'))

    if args.resume:
        # Resume the full trainer state (models, optimizers, iteration count).
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()


if __name__ == '__main__':
    main()
