import os
import os.path as osp
import random
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from sklearn.metrics import confusion_matrix
import seaborn as sn
from tqdm import tqdm
import argparse
from torchvision import transforms
from datasets import WasteNetDataset
from torch.utils.data import DataLoader

# Project-local helper modules
from apex import amp, optimizers
from utils import *

# Command line: which model to train and which bagging split to use.
parser = argparse.ArgumentParser()
parser.add_argument("model_name", type=str)  # e.g. "efficient_net_b5_1"
parser.add_argument("split_num", type=int)
args = parser.parse_args()

# Run identification / data-split configuration.
MODEL_NAME = args.model_name
SPLIT_NUM = args.split_num
LOGGING_DIR = "undersample+bagging"
START_SEQUENCE = 0
ALL_SEQUENCE_NUM = 9

# Input and batching configuration.
BATCH_SIZE = 4
IMG_WIDTH = 384
IMG_HEIGHT = 288
SHOW_INFO_FREQ = 100

# Training hyper-parameters.
EPOCH_NUM = 10

if __name__ == "__main__":
    # Train one model per held-out sequence (leave-one-sequence-out bagging).
    for sequence_num in range(START_SEQUENCE, ALL_SEQUENCE_NUM):
        torch.backends.cudnn.benchmark = True

        # Fix random seeds so every sequence run is reproducible.
        manual_seed = 999
        print("Random Seed: ", manual_seed)
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)

        # Choose the execution device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Use device: {device}")

        # Data selection: one undersampled bagging split, holding sequence
        # `sequence_num` out as the (class-balanced) test set.
        train_dset, test_dset = load_dsets_for_bagging(sequence_num, SPLIT_NUM, balanced_test=True)

        # Model selection is keyed off the MODEL_NAME prefix.
        if MODEL_NAME.startswith("efficient"):
            model = load_efficientnet("efficientnet-b5", IMG_HEIGHT, IMG_WIDTH)
        elif MODEL_NAME.startswith("resnet"):
            model = load_resnet50()
        elif MODEL_NAME.startswith("densenet"):
            model = load_densenet161()
        elif MODEL_NAME.startswith("deit"):
            model = load_deit(IMG_HEIGHT, IMG_WIDTH)
        elif MODEL_NAME.startswith("coatnet"):
            model = load_CoAtNet()
        elif MODEL_NAME.startswith("effnetv2"):
            model = load_effnetv2()
        elif MODEL_NAME.startswith("mobile"):
            model = load_mobilenet_v2()
        else:
            raise NotImplementedError()

        model.to(device)
        optimizer = torch.optim.Adam(model.parameters())

        # Input pipeline: CoAtNet needs a fixed 224x224 input; the other
        # architectures consume the images at their loaded size.
        if MODEL_NAME.startswith("coatnet"):
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                transforms.Resize((224, 224))
            ])
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

        train_datasets = WasteNetDataset(train_dset, IMG_WIDTH, IMG_HEIGHT, transform)
        test_datasets = WasteNetDataset(test_dset, IMG_WIDTH, IMG_HEIGHT, transform)
        train_loader = DataLoader(train_datasets, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
        test_loader = DataLoader(test_datasets, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)

        # Loss function
        criterion = nn.CrossEntropyLoss()

        # TensorBoard logging directory. A single makedirs(exist_ok=True)
        # creates all intermediate directories, so one call suffices.
        log_dir = osp.join("../logs", LOGGING_DIR, MODEL_NAME, f"sequence{sequence_num}")
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)
        classes = ("bad data", "good data")

        for epoch in range(1, EPOCH_NUM + 1):

            # ----- train phase -----
            model.train()
            running_loss = 0.0
            correct_num = 0
            pred_target_num = 0

            for i, (imgs, labels) in enumerate(tqdm(train_loader)):
                # Clear gradients via the optimizer (exactly the parameters
                # being updated); set_to_none=True avoids a zero-fill pass.
                optimizer.zero_grad(set_to_none=True)

                imgs = imgs.to(device)
                labels = labels.to(device)

                pred = model(imgs)
                loss = criterion(pred, labels)
                running_loss += loss.item()

                pred_target_num += pred.shape[0]
                correct_num += torch.sum(torch.argmax(pred, dim=1) == labels).item()

                loss.backward()
                optimizer.step()

                # Periodically report/log windowed loss and accuracy, then
                # reset the window counters.
                if i % SHOW_INFO_FREQ == (SHOW_INFO_FREQ - 1):
                    print("train")
                    print(f"[sequence{sequence_num} {epoch}, {i}] {running_loss}")
                    print(f"[avg] {correct_num / pred_target_num}")
                    print()

                    global_step = len(train_loader) * (epoch - 1) + i
                    writer.add_scalar("train_loss", running_loss, global_step)
                    writer.add_scalar("train_acc", correct_num / pred_target_num, global_step)

                    running_loss = 0.0
                    correct_num = 0
                    pred_target_num = 0

            # ----- test phase -----
            with torch.no_grad():
                model.eval()
                running_loss = 0.0
                correct_num = 0
                pred_target_num = 0

                y_pred = []
                y_true = []
                for imgs, labels in tqdm(test_loader):
                    imgs = imgs.to(device)
                    labels = labels.to(device)

                    pred = model(imgs)
                    loss = criterion(pred, labels)
                    running_loss += loss.item()

                    pred_target_num += pred.shape[0]
                    pred_index = torch.argmax(pred, dim=1)
                    correct_num += torch.sum(pred_index == labels).item()

                    y_pred.extend(pred_index.cpu().numpy())
                    y_true.extend(labels.cpu().numpy())

            print("test")
            print(f"[sequence{sequence_num} {epoch}] {running_loss}")
            print(f"[avg] {correct_num / pred_target_num}")
            print()

            # Log test metrics at the step the epoch *finished* on; the old
            # code used len(train_loader) * (epoch - 1), which plotted every
            # test point one full epoch behind the training curve.
            test_step = len(train_loader) * epoch
            writer.add_scalar("test_loss", running_loss, test_step)
            writer.add_scalar("test_acc", correct_num / pred_target_num, test_step)
            torch.save(model.state_dict(), osp.join(log_dir, f"test{epoch}_acc{(correct_num / pred_target_num):.2f}_weights.pth"))

            # Persist the confusion matrix of this epoch's test predictions.
            cf_matrix = confusion_matrix(y_true, y_pred)
            df_cm = pd.DataFrame(cf_matrix, index=[f"{c}(label)" for c in classes], columns=[f"{c}(predicted)" for c in classes])
            fig = plt.figure(figsize=(12, 7))
            sn.heatmap(df_cm, annot=True)
            writer.add_figure("Confusion matrix", fig, epoch)
            # matplotlib keeps figures alive until closed explicitly; without
            # this, ~90 figures accumulate over a full run.
            plt.close(fig)

        # Bug fix: these two lines were previously OUTSIDE the sequence loop,
        # so GPU memory from every earlier model accumulated until the very
        # end of the run. Release per sequence, and flush the event file.
        writer.close()
        del model
        torch.cuda.empty_cache()