Newer
Older
EBUS_Classify / ebus_movie_classify.py
import glob
import os.path as osp
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import csv
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms
import sys
import datetime

# 入力画像の前処理クラス
# 訓練時と推論時で処理を変える
class ImageTransform():
    """
    画像の前処理クラス.訓練時と推論時で処理が異なる.
    データ前処理:画像のリサイズ,色の標準化.
    訓練時データ拡張:RandomResezedCropとRandomHorizontalFlip
    
    Attributes
    ----------
    resize : int
        リサイズの大きさ
    mean : (R, G, B)
        各色チャネルの平均値
    std : (R, G, B)
        各色チャネルの標準偏差
    """

    def __init__(self, resize, mean, std):
        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(
                    resize, scale=(0.5, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(), # Torchテンソルに変換
                transforms.Normalize(mean, std) # 色の標準化
            ]),
            'val': transforms.Compose([
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(), # Torchテンソルに変換
                transforms.Normalize(mean, std) # 色の標準化
            ]),
        }
    
    def __call__(self, img, phase='train'):
        """
        Parameters
        ----------
        phase : 'train' or 'val'
            前処理のモード指定
        """    
        return self.data_transform[phase](img)


# 画像へのファイルパスのリストを作成
def make_datapath_list(phase="train", label="Benign"):
    """
    データのパスを格納したリストを作成

    Parameters
    ----------
    phase : 'train' or 'val'
        訓練データか,検証データかを指定
    
    Returns
    -------
    path_list : list
        データへのパスを格納したリスト
    """

    rootpath = "/data2/EBUS/EBUS動画20220124/ExtractedFrames/"
    target_path = osp.join(rootpath + phase + "/" + label + '/**/*.png')
    # print(target_path)
    # print(len(glob.glob(target_path, recursive=True)))

    # globを利用してサブディレクトリまでファイルパスを取得
    path_list = []
    for path in glob.glob(target_path, recursive=True):
        path_list.append(path)
    
    return path_list


# EBUS画像のDatasetを作成する
class EbusDataset(data.Dataset):
    """
    EBUS画像のDatasetクラス.PyTorchのDatasetクラスを継承

    Attributes
    ----------
    file_list : list
        画像のパスを格納したリスト
    transform : object
        前処理クラスのインスタンス
    phase : 'train' or 'val'
        訓練か検証かを設定する
    """

    def __init__(self, file_list, transform=None, phase='train'):
        self.file_list = file_list
        self.transform = transform
        self.phase = phase
    
    def __len__(self):
        '''画像の枚数を返す'''
        return len(self.file_list)
    
    def __getitem__(self, index):
        '''
        前処理した画像のTensor形式のデータとラベルを取得
        '''

        # index番目の画像をロード
        img_path = self.file_list[index]
        img = Image.open(img_path)
        # 画像の前処理を実施
        img_transformed = self.transform(img, self.phase)
        
        # パスからラベルを判定
        label = 0 if "Benign" in img_path else 1
        
        return img_transformed, label

# モデルを学習させる関数
def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):

    # 演算デバイス設定
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("使用デバイス:", device)
    net.to(device)
    # ネットワークがある程度固定なら高速化
    torch.backends.cudnn.benchmark = True
    training_result = []
    
    # epochのループ
    for epoch in range(num_epochs):
        print('')
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        epoch_result = [epoch + 1]

        # epochごとの学習と検証のループ
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()
            else:
                net.eval()
            
            epoch_loss = 0.0 # epochの損失和
            epoch_corrects = 0 # epochの正解数

            # 未学習時の性能を確かめるため epoch=0 の訓練は省略
            if (epoch == 0) and (phase == 'train'):
                epoch_result.extend([0, 0])
                continue

            # データローダーからミニバッチを取り出すループ
            for inputs, labels in tqdm(dataloaders_dict[phase]):

                # GPUが使えるならGPUへデータ転送
                inputs = inputs.to(device)
                labels = labels.to(device)

                # optimizerを初期化
                optimizer.zero_grad()

                # 順伝搬(forward)計算
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    # 訓練時はバックプロパゲーション
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                    # イテレーション結果の計算
                    # lossの合計を更新
                    epoch_loss += loss.item() * inputs.size(0)
                    epoch_corrects += torch.sum(preds == labels.data)
            
            # epochごとのlossと正解率を表示
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)
            
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            epoch_result.extend([epoch_loss, float(epoch_acc)])
        
        training_result.append(epoch_result)

    return training_result

if __name__ == "__main__":

    # 乱数のシードを設定
    torch.manual_seed(1235)
    np.random.seed(1235)
    random.seed(1235)

    # モデル入力画像の仕様
    size = 224
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    # データセットファイル名の取得
    train_benign_list = make_datapath_list(phase="train", label="Benign")
    train_malignant_list = make_datapath_list(phase="train", label="Malignant")
    val_benign_list = make_datapath_list(phase="val", label="Benign")
    val_malignant_list = make_datapath_list(phase="val", label="Malignant")
    print("Dataset(orig): train benign:%d malignant:%d  val benign:%d malignant:%d" %
        (len(train_benign_list), len(train_malignant_list), len(val_benign_list), len(val_malignant_list) ))
    num_traindata = 1000
    num_valdata = 100
    train_benign_list = random.sample(train_benign_list, num_traindata)
    train_malignant_list = random.sample(train_malignant_list, num_traindata)
    val_benign_list = random.sample(val_benign_list, num_valdata)
    val_malignant_list = random.sample(val_malignant_list, num_valdata)
    print("Dataset(arranged): train benign:%d malignant:%d  val benign:%d malignant:%d" %
        (len(train_benign_list), len(train_malignant_list), len(val_benign_list), len(val_malignant_list) ))

    # データセットの読み込み
    train_list = train_benign_list + train_malignant_list
    val_list = val_benign_list + val_malignant_list
    train_dataset = EbusDataset(
        file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')
    val_dataset = EbusDataset(
        file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')

    # ミニバッチのサイズを指定
    batch_size = 16

    # DataLoaderを作成
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False)

    # 辞書型変数にまとめる
    dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

    use_pretrained = True
    # ResNet18
    net = models.resnet18(pretrained=use_pretrained)
    model_name = "ResNet18"
    net.fc = nn.Linear(in_features=512, out_features=2, bias=True)
    # net = models.resnet101(pretrained=use_pretrained)
    # model_name = "ResNet101"
    # net.fc = nn.Linear(in_features=2048, out_features=2, bias=True)
    # print(net)
    # sys.exit()
    # 学習するレイヤーの指定
    update_param_names = [name for name, param in net.named_parameters() if "fc." in name]
    update_layer = "FC"
    # update_param_names = [name for name, param in net.named_parameters() if "fc." in name or "layer4" in name]
    # update_layer = "FC and Layer4"

    # VGG16
    # net = models.vgg16(pretrained=use_pretrained)
    # model_name = "VGG16"
    # net.classifier[6] = nn.Linear(in_features=4096, out_features=2)
    # update_param_names = ["classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias"]
    # update_param_names = ["classifier.6.weight", "classifier.6.bias"]

    # 訓練モードに設定
    net.train()

    print('ネットワーク設定完了:学習済みの重みをロードし,訓練モードに設定しました')

    # 損失関数の設定
    criterion = nn.CrossEntropyLoss()

    # 転移学習で学習させるパラメータを,変数params_to_updateに格納
    params_to_update = []

    # 学習させるパラメータ名

    # 学習させるパラメータ以外は勾配計算せず固定
    for name, param in net.named_parameters():
        # if True:
        # if name.startswith('classifier'):
        if name in update_param_names:
            param.requires_grad = True
            params_to_update.append(param)
            print("update: ", name)
        else:
            param.requires_grad = False
            print("freeze: ", name)

    # params_to_updateの中身を確認
    print("------------")
    # print(params_to_update)

    # 最適化手法の設定
    learning_rate = 0.001
    optimizer = optim.SGD(params=params_to_update, lr=learning_rate, momentum=0.9)

    # 学習・検証を実行する
    num_epochs = 10
    training_result = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs)

    # 学習カーブを解析
    result_max = np.amax(training_result, axis=0)[1:]
    result_min = np.amin(training_result, axis=0)[1:]
    result_median = np.median(training_result, axis=0)[1:]

    # ログ出力
    f = open("{0:result%Y%m%d_%H%M%S.csv}".format(datetime.datetime.now()), 'w')
    writer = csv.writer(f, lineterminator='\n')
    writer.writerow(['Model','Update Layer','#Training','#Validation','Batch Size','Learning Rate','#Epoch','ValAccMax','ValAccMedian'])
    writer.writerow([model_name, update_layer, num_traindata, num_valdata, batch_size, learning_rate, num_epochs, result_max[3], result_median[3]])
    writer.writerow([])
    writer.writerow(['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc'])
    np.savetxt(f, training_result, fmt=['%.0f', '%.4f', '%.4f', '%.4f', '%.4f'], delimiter=',', newline='\n')
    writer.writerow([])
    f.write('max,')
    np.savetxt(f, result_max.reshape((1, 4)), fmt='%.4f', delimiter=',', newline='\n')
    f.write('min,')
    np.savetxt(f, result_min.reshape((1, 4)), fmt='%.4f', delimiter=',', newline='\n')
    f.write('median,')
    np.savetxt(f, result_median.reshape((1, 4)), fmt='%.4f', delimiter=',', newline='\n')

    # 終了処理
    f.close()
    print('done.')