import os
import os.path as osp
import random
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from sklearn.metrics import confusion_matrix
import seaborn as sn
from tqdm import tqdm
import argparse
from torchvision import transforms
from datasets import WasteNetDataset
from torch.utils.data import DataLoader
# 自前の関数群
from apex import amp, optimizers
from utils import *
# CLI: the model-name prefix selects the architecture; split_num picks the bagging split.
parser = argparse.ArgumentParser()
parser.add_argument("model_name", type=str)  # e.g. "efficient_net_b5_1"; only the prefix matters
parser.add_argument("split_num", type=int)   # index of the bagging split to train on
args = parser.parse_args()
MODEL_NAME = args.model_name
LOGGING_DIR = "undersample+bagging"  # experiment subdirectory under ../logs
SPLIT_NUM = args.split_num
START_SEQUENCE = 0        # first sequence index to process
BATCH_SIZE = 4
IMG_WIDTH = 384
IMG_HEIGHT = 288
SHOW_INFO_FREQ = 100      # print/log train metrics every N iterations
ALL_SEQUENCE_NUM = 9      # total number of sequences (presumably one held-out fold each — confirm in utils)
# Training hyperparameters
EPOCH_NUM = 10
# apex (mixed precision) — currently disabled
# opt_level = 'O1'
if __name__ == "__main__":
    # Train one model per held-out sequence, logging to tensorboardX and saving
    # per-epoch weight checkpoints plus a confusion matrix.
    for sequence_num in range(START_SEQUENCE, ALL_SEQUENCE_NUM):
        torch.backends.cudnn.benchmark = True
        # Fix random seeds for reproducibility.
        manual_seed = 999
        print("Random Seed: ", manual_seed)
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)
        if torch.cuda.is_available():
            # torch.manual_seed alone does not seed every CUDA device's RNG.
            torch.cuda.manual_seed_all(manual_seed)
        # Pick the execution device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Use device: {device}")
        # Data selection: one bagging split with this sequence held out for testing.
        train_dset, test_dset = load_dsets_for_bagging(sequence_num, SPLIT_NUM, balanced_test=True)
        # Architecture is chosen by the model-name prefix.
        if MODEL_NAME.startswith("efficient"):
            model = load_efficientnet("efficientnet-b5", IMG_HEIGHT, IMG_WIDTH)
        elif MODEL_NAME.startswith("resnet"):
            model = load_resnet50()
        elif MODEL_NAME.startswith("densenet"):
            model = load_densenet161()
        elif MODEL_NAME.startswith("deit"):
            model = load_deit(IMG_HEIGHT, IMG_WIDTH)
        elif MODEL_NAME.startswith("coatnet"):
            model = load_CoAtNet()
        elif MODEL_NAME.startswith("effnetv2"):
            model = load_effnetv2()
        elif MODEL_NAME.startswith("mobile"):
            model = load_mobilenet_v2()
        else:
            raise NotImplementedError()
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters())
        # Data loaders. CoAtNet additionally needs a fixed 224x224 input resize.
        if MODEL_NAME.startswith("coatnet"):
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                transforms.Resize((224, 224))
            ])
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        train_datasets = WasteNetDataset(train_dset, IMG_WIDTH, IMG_HEIGHT, transform)
        test_datasets = WasteNetDataset(test_dset, IMG_WIDTH, IMG_HEIGHT, transform)
        train_loader = DataLoader(train_datasets, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
        test_loader = DataLoader(test_datasets, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)
        # Loss function.
        criterion = nn.CrossEntropyLoss()
        # tensorboardX: makedirs creates intermediate directories, so one call suffices.
        log_dir = osp.join("../logs", LOGGING_DIR, MODEL_NAME, f"sequence{sequence_num}")
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)
        classes = ("bad data", "good data")
        for epoch in range(1, EPOCH_NUM + 1):
            # ---- train phase ----
            model.train()
            running_loss = 0.0
            correct_num = 0
            pred_target_num = 0
            for i, (imgs, labels) in enumerate(tqdm(train_loader)):
                model.zero_grad(set_to_none=True)
                imgs = imgs.to(device)
                labels = labels.to(device)
                pred = model(imgs)
                loss = criterion(pred, labels)
                running_loss += loss.item()
                pred_target_num += pred.shape[0]
                correct_num += torch.sum(torch.argmax(pred, dim=1) == labels).item()
                loss.backward()
                optimizer.step()
                # Periodically report and reset the running metrics.
                if i % SHOW_INFO_FREQ == (SHOW_INFO_FREQ - 1):
                    print("train")
                    print(f"[sequence{sequence_num} {epoch}, {i}] {running_loss}")
                    print(f"[avg] {correct_num / pred_target_num}")
                    print()
                    global_step = len(train_loader) * (epoch - 1) + i
                    writer.add_scalar("train_loss", running_loss, global_step)
                    writer.add_scalar("train_acc", correct_num / pred_target_num, global_step)
                    running_loss = 0.0
                    correct_num = 0
                    pred_target_num = 0
            # ---- test phase ----
            with torch.no_grad():
                model.eval()
                running_loss = 0.0
                correct_num = 0
                pred_target_num = 0
                y_pred = []
                y_true = []
                for i, (imgs, labels) in enumerate(tqdm(test_loader)):
                    imgs = imgs.to(device)
                    labels = labels.to(device)
                    pred = model(imgs)
                    loss = criterion(pred, labels)
                    running_loss += loss.item()
                    pred_target_num += pred.shape[0]
                    pred_index = torch.argmax(pred, dim=1)
                    correct_num += torch.sum(pred_index == labels).item()
                    y_pred.extend(pred_index.data.cpu().numpy())
                    y_true.extend(labels.data.cpu().numpy())
                print("test")
                print(f"[sequence{sequence_num} {epoch}] {running_loss}")
                print(f"[avg] {correct_num / pred_target_num}")
                print()
                # NOTE(review): test scalars reuse the train-loader step scale so they
                # share an x-axis with the train scalars; kept as-is for continuity.
                step = len(train_loader) * (epoch - 1)
                writer.add_scalar("test_loss", running_loss, step)
                writer.add_scalar("test_acc", correct_num / pred_target_num, step)
                torch.save(model.state_dict(), osp.join("../logs", LOGGING_DIR, MODEL_NAME, f"sequence{sequence_num}", f"test{epoch}_acc{(correct_num / pred_target_num):.2f}_weights.pth"))
                # Confusion matrix: pass explicit labels so the matrix is always 2x2,
                # even when one class is absent from this epoch's predictions/targets
                # (otherwise the 2-label DataFrame construction below would fail).
                cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
                df_cm = pd.DataFrame(cf_matrix, index=[f"{i}(label)" for i in classes], columns=[f"{i}(predicted)" for i in classes])
                plt.figure(figsize=(12, 7))
                show_fig = sn.heatmap(df_cm, annot=True).get_figure()
                writer.add_figure("Confusion matrix", show_fig, epoch)
                # Release the pyplot figure so one does not accumulate per epoch
                # (add_figure may already close it; plt.close is then a no-op).
                plt.close(show_fig)
        # Flush pending events and release the event-file handle for this sequence.
        writer.close()
        # Drop both the model and the optimizer (which holds parameter references)
        # before freeing cached GPU memory for the next sequence.
        del model, optimizer
        torch.cuda.empty_cache()