# WasteNet/src/train.py
import os
import os.path as osp
import random
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from sklearn.metrics import confusion_matrix
import seaborn as sn
from tqdm import tqdm
import argparse
from torchvision import transforms
from datasets import WasteNetDataset
from torch.utils.data import DataLoader

# from apex import amp, optimizers  # uncomment to enable the mixed-precision lines below
# Our own helper functions
from utils import *
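# utils is expected to provide the load_* model builders and load_dsets_for_bagging used below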

parser = argparse.ArgumentParser()

parser.add_argument("model_name", type=str) #ex) "efficient_net_b5_1"
parser.add_argument("split_num", type=int)

args = parser.parse_args()

MODEL_NAME = args.model_name
LOGGING_DIR = "undersample+bagging"
SPLIT_NUM = args.split_num
START_SEQUENCE = 0

BATCH_SIZE = 4
IMG_WIDTH = 384
IMG_HEIGHT = 288
SHOW_INFO_FREQ = 100
ALL_SEQUENCE_NUM = 9

# Training hyperparameters
EPOCH_NUM = 10

# apex (mixed precision) settings
# opt_level = 'O1'
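# To enable apex mixed precision, uncomment opt_level above, the apex import at the top,
# and the amp.initialize / amp.scale_loss lines inside the training loop.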

if __name__ == "__main__":
    for sequence_num in range(START_SEQUENCE, ALL_SEQUENCE_NUM):
        torch.backends.cudnn.benchmark = True

        # Set a random seed for reproducibility
        manual_seed = 999
        print("Random Seed: ", manual_seed)
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)
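        # Note: cudnn.benchmark = True above trades exact reproducibility for speed;
        # fuller determinism would also need torch.cuda.manual_seed_all(manual_seed)
        # and deterministic cuDNN kernels.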

        # Choose which device to run on
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Use device: {device}")

        # Data selection
        # data_table = pd.read_csv("../input/train.csv", header=None).values.tolist()
        # train_dset, test_dset = train_test_split_by_sequence(data_table, test_sequence_num=sequence_num, balanced_test=True)
        train_dset, test_dset = load_dsets_for_bagging(sequence_num, SPLIT_NUM, balanced_test=True)
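        # Assumption: load_dsets_for_bagging (from utils) holds out sequence
        # `sequence_num` as a class-balanced test set and returns the SPLIT_NUM-th
        # undersampled bagging split of the remaining data for training.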

        if MODEL_NAME.startswith("efficient"):
            model = load_efficientnet("efficientnet-b5", IMG_HEIGHT, IMG_WIDTH)
        elif MODEL_NAME.startswith("resnet"):
            model = load_resnet50()
        elif MODEL_NAME.startswith("densenet"):
            model = load_densenet161()
        elif MODEL_NAME.startswith("deit"):
            model = load_deit(IMG_HEIGHT, IMG_WIDTH)
        elif MODEL_NAME.startswith("coatnet"):
            model = load_CoAtNet()
        elif MODEL_NAME.startswith("effnetv2"):
            model = load_effnetv2()
        elif MODEL_NAME.startswith("mobile"):
            model = load_mobilenet_v2()
        else:
            raise NotImplementedError()

        model.to(device)
        optimizer = torch.optim.Adam(model.parameters())
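        # Adam with the PyTorch default learning rate (1e-3); no LR scheduler is used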
        # model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

        # Data loaders

        if MODEL_NAME.startswith("coatnet"):
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                transforms.Resize((224, 224))
            ])
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

        train_datasets = WasteNetDataset(train_dset, IMG_WIDTH, IMG_HEIGHT, transform)
        test_datasets = WasteNetDataset(test_dset, IMG_WIDTH, IMG_HEIGHT, transform)
        train_loader = DataLoader(train_datasets, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
        test_loader = DataLoader(test_datasets, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)

        # Loss function
        criterion = nn.CrossEntropyLoss()
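        # nn.CrossEntropyLoss applies log-softmax internally, so the models should output raw logits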

        # tensorboardX
        log_dir = osp.join("../logs", LOGGING_DIR, MODEL_NAME, f"sequence{sequence_num}")
        os.makedirs(log_dir, exist_ok=True)  # makedirs creates intermediate directories too
        writer = SummaryWriter(log_dir)
        classes = ("bad data", "good data")
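        # Class index 0 is "bad data", 1 is "good data"; assumed to match the labels produced by WasteNetDataset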

        for epoch in range(1, EPOCH_NUM + 1):

            # train phase
            model.train()
            running_loss = 0.0
            correct_num = 0
            pred_target_num = 0

            for i, (imgs, labels) in enumerate(tqdm(train_loader)):
                model.zero_grad(set_to_none=True)

                imgs = imgs.to(device)
                labels = labels.to(device)

                pred = model(imgs)
                loss = criterion(pred, labels)
                running_loss += loss.item()

                pred_target_num += pred.shape[0]
                correct_num += torch.sum(torch.argmax(pred, dim=1) == labels).item()

                loss.backward()
                # with amp.scale_loss(loss, optimizer) as scaled_loss:
                #     scaled_loss.backward()

                optimizer.step()

                if i % SHOW_INFO_FREQ == (SHOW_INFO_FREQ - 1):
                    # Report the mean loss over the last SHOW_INFO_FREQ batches
                    avg_loss = running_loss / SHOW_INFO_FREQ
                    global_step = len(train_loader) * (epoch - 1) + i
                    print("train")
                    print(f"[sequence{sequence_num} {epoch}, {i}] loss: {avg_loss:.4f}")
                    print(f"[acc] {correct_num / pred_target_num:.4f}")
                    print()

                    writer.add_scalar("train_loss", avg_loss, global_step)
                    writer.add_scalar("train_acc", correct_num / pred_target_num, global_step)

                    running_loss = 0.0
                    correct_num = 0
                    pred_target_num = 0

            # test phase
            with torch.no_grad():
                model.eval()
                running_loss = 0.0
                correct_num = 0
                pred_target_num = 0

                y_pred = []
                y_true = []
                for i, (imgs, labels) in enumerate(tqdm(test_loader)):
                    imgs = imgs.to(device)
                    labels = labels.to(device)

                    pred = model(imgs)
                    loss = criterion(pred, labels)
                    running_loss += loss.item()

                    pred_target_num += pred.shape[0]
                    pred_index = torch.argmax(pred, dim=1)
                    correct_num += torch.sum(pred_index == labels).item()

                    y_pred.extend(pred_index.data.cpu().numpy())
                    y_true.extend(labels.data.cpu().numpy())

            print("test")
            print(f"[sequence{sequence_num} {epoch}] {running_loss}")
            print(f"[avg] {correct_num / pred_target_num}")
            print()

            writer.add_scalar("test_loss", running_loss, train_loader.__len__() * (epoch - 1))
            writer.add_scalar("test_acc", correct_num / pred_target_num, train_loader.__len__() * (epoch - 1))
            torch.save(model.state_dict(), osp.join(log_dir, f"test{epoch}_acc{(correct_num / pred_target_num):.2f}_weights.pth"))

            # Save the confusion matrix
            cf_matrix = confusion_matrix(y_true, y_pred)
            df_cm = pd.DataFrame(cf_matrix, index=[f"{c}(label)" for c in classes], columns=[f"{c}(predicted)" for c in classes])
            plt.figure(figsize=(12, 7))
            show_fig = sn.heatmap(df_cm, annot=True).get_figure()
            writer.add_figure("Confusion matrix", show_fig, epoch)
            plt.close(show_fig)  # avoid accumulating open figures across epochs
        # Release the model and cached GPU memory before the next sequence
        writer.close()
        del model
        torch.cuda.empty_cache()