diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a445398 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +res* diff --git a/README.md b/README.md index 9f3b7ba..8d0bcca 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,23 @@ # EBUS_Classify +## 実行条件 +- 要pytorch gpu + ## ebus_movie_classify.py -動画から切り出したフレーム画像に対して学習 +- 動画から切り出したフレーム画像に対して学習 +### 実行方法 +```bash +# デフォルト条件で実行 +$ python3 ebus_movie_classify.py +# ヘルプ表示 +$ python3 ebus_movie_classify.py -h +``` ## ebus_classify_multi.py -初期の学習コードを並列実行できるようにしたもの +- 初期の学習コードを並列実行できるようにしたもの ## ebus_classify.ipynb -初期の学習コードのJupyter Note +- 初期の学習コードのJupyter Note ## ebus_classify.py -初期の学習コード +- 初期の学習コード diff --git a/ebus_movie_classify.py b/ebus_movie_classify.py index a86e22c..a7852a5 100644 --- a/ebus_movie_classify.py +++ b/ebus_movie_classify.py @@ -14,6 +14,7 @@ from torchvision import models, transforms import sys import datetime +import argparse # 入力画像の前処理クラス # 訓練時と推論時で処理を変える @@ -200,6 +201,26 @@ if __name__ == "__main__": + # 実行時引数の処理 + parser = argparse.ArgumentParser(description='EBUS分類プログラム') + parser.add_argument('-m', '--model', help='学習モデル名(ResNet18, ResNet101, VGG16)', default='ResNet18') + parser.add_argument('-u', '--update_layer', help='学習する層(FC, All, L4-FC)', default='FC') + parser.add_argument('-e', '--epochs', help='エポック数', type=int, default=50) + parser.add_argument('-b', '--batch_size', help='ミニバッチサイズ', type=int, default=16) + parser.add_argument('-t', '--num_traindata', help='学習データ数', type=int, default=1000) + parser.add_argument('-v', '--num_valdata', help='検証データ数', type=int, default=100) + parser.add_argument('-l', '--learning_rate', help='学習率', type=float, default=0.001) + parser.add_argument('-o', '--output', help='出力ファイル名', default="{0:result%Y%m%d_%H%M%S.csv}".format(datetime.datetime.now())) + args = parser.parse_args() + model_name = args.model + update_layer = args.update_layer + num_epochs = args.epochs + batch_size = args.batch_size + num_traindata = args.num_traindata + 
num_valdata = args.num_valdata + learning_rate = args.learning_rate + output_filename = args.output + # 乱数のシードを設定 torch.manual_seed(1235) np.random.seed(1235) @@ -217,8 +238,6 @@ val_malignant_list = make_datapath_list(phase="val", label="Malignant") print("Dataset(orig): train benign:%d malignant:%d val benign:%d malignant:%d" % (len(train_benign_list), len(train_malignant_list), len(val_benign_list), len(val_malignant_list) )) - num_traindata = 1000 - num_valdata = 100 train_benign_list = random.sample(train_benign_list, num_traindata) train_malignant_list = random.sample(train_malignant_list, num_traindata) val_benign_list = random.sample(val_benign_list, num_valdata) @@ -234,9 +253,6 @@ val_dataset = EbusDataset( file_list=val_list, transform=ImageTransform(size, mean, std), phase='val') - # ミニバッチのサイズを指定 - batch_size = 16 - # DataLoaderを作成 train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True) @@ -246,32 +262,38 @@ # 辞書型変数にまとめる dataloaders_dict = {"train": train_dataloader, "val": val_dataloader} - use_pretrained = True - # ResNet18 - net = models.resnet18(pretrained=use_pretrained) - model_name = "ResNet18" - net.fc = nn.Linear(in_features=512, out_features=2, bias=True) - # net = models.resnet101(pretrained=use_pretrained) - # model_name = "ResNet101" - # net.fc = nn.Linear(in_features=2048, out_features=2, bias=True) - # print(net) - # sys.exit() - # 学習するレイヤーの指定 - update_param_names = [name for name, param in net.named_parameters() if "fc." in name] - update_layer = "FC" - # update_param_names = [name for name, param in net.named_parameters() if "fc." 
in name or "layer4" in name] - # update_layer = "FC and Layer4" - # VGG16 - # net = models.vgg16(pretrained=use_pretrained) - # model_name = "VGG16" - # net.classifier[6] = nn.Linear(in_features=4096, out_features=2) - # update_param_names = ["classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias"] - # update_param_names = ["classifier.6.weight", "classifier.6.bias"] + # モデル設定 + use_pretrained = True + if model_name.lower() == "resnet18": + net = models.resnet18(pretrained=use_pretrained) + net.fc = nn.Linear(in_features=512, out_features=2, bias=True) + elif model_name.lower() == "resnet101": + net = models.resnet101(pretrained=use_pretrained) + net.fc = nn.Linear(in_features=2048, out_features=2, bias=True) + elif model_name.lower() == "vgg16": + net = models.vgg16(pretrained=use_pretrained) + net.classifier[6] = nn.Linear(in_features=4096, out_features=2) + else: + print('unknown model : ' + model_name) + sys.exit() + + # 学習するレイヤーの指定 + if update_layer.lower() == "fc": + if model_name.lower() == "vgg16": + update_param_names = ["classifier.6.weight", "classifier.6.bias"] + else: + update_param_names = [name for name, param in net.named_parameters() if "fc." in name] + elif update_layer.lower() == "l4-fc": + update_param_names = [name for name, param in net.named_parameters() if "fc." 
in name or "layer4" in name] + elif update_layer.lower() == "all": + update_param_names = [name for name, param in net.named_parameters()] + else: + print('unknown update layer setting : ' + update_layer) + sys.exit() # 訓練モードに設定 net.train() - print('ネットワーク設定完了:学習済みの重みをロードし,訓練モードに設定しました') # 損失関数の設定 @@ -299,11 +321,9 @@ # print(params_to_update) # 最適化手法の設定 - learning_rate = 0.001 optimizer = optim.SGD(params=params_to_update, lr=learning_rate, momentum=0.9) # 学習・検証を実行する - num_epochs = 10 training_result = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs) # 学習カーブを解析 @@ -312,7 +332,7 @@ result_median = np.median(training_result, axis=0)[1:] # ログ出力 - f = open("{0:result%Y%m%d_%H%M%S.csv}".format(datetime.datetime.now()), 'w') + f = open(output_filename, 'w') writer = csv.writer(f, lineterminator='\n') writer.writerow(['Model','Update Layer','#Training','#Validation','Batch Size','Learning Rate','#Epoch','ValAccMax','ValAccMedian']) writer.writerow([model_name, update_layer, num_traindata, num_valdata, batch_size, learning_rate, num_epochs, result_max[3], result_median[3]]) @@ -329,4 +349,5 @@ # 終了処理 f.close() - print('done.') + print('') + print('Done. Result output to ' + output_filename)