{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "211db4dd-7709-4bcc-a5ee-60fc64929069",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"# Make modules in ../src importable. The membership check keeps this cell\n",
"# idempotent: re-running it no longer appends duplicate sys.path entries.\n",
"SRC_DIR = \"../src\"\n",
"if SRC_DIR not in sys.path:\n",
"    sys.path.append(SRC_DIR)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fb9b463b-66d9-4ead-9f39-82e2ea1408c6",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"from glob import glob\n",
"import os\n",
"import os.path as osp\n",
"import random\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"from tensorboardX import SummaryWriter\n",
"from sklearn.metrics import confusion_matrix\n",
"import seaborn as sn\n",
"from tqdm import tqdm\n",
"import argparse\n",
"from torchvision import transforms\n",
"from datasets import WasteNetDataset\n",
"from torch.utils.data import DataLoader\n",
"from utils import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86b21bc3-62b3-4a11-88fe-1c091a00f5b8",
"metadata": {},
"outputs": [],
"source": [
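"# One-off script, kept for provenance: it produced the test_dset{i}.pkl files that the\n",
"# inference cell reads. Uncomment only to regenerate them.\n",
"# import pickle\n",
"# for i in range(9):\n",
"# data_table = pd.read_csv(\"../input/train.csv\", header=None).values.tolist()\n",
"# _, test_dset = train_test_split_by_sequence(data_table, i, True)\n",
"# with open(f\"test_dset{i}.pkl\", \"wb\") as f:\n",
"# pickle.dump(test_dset, f)"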
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "59770128-4137-463b-af68-ca02463e97b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Use device: cuda\n"
]
}
],
"source": [
"# Fixed constants\n",
"# Input image size passed to WasteNetDataset in the inference loop.\n",
"IMG_WIDTH = 384\n",
"IMG_HEIGHT = 288\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"# Raw string: the backslashes in this Windows path must be taken literally. The previous\n",
"# f-string relied on invalid escape sequences, which warn on newer Python versions.\n",
"out_path = r\"D:\\Deep_Learning\\WasteNet\\logs\\pred_csv\"\n",
"print(f\"Use device: {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "91c591eb-a592-436c-9289-e98f2841789d",
"metadata": {},
"outputs": [],
"source": [
"# Run parameters\n",
"bagging_num = 7\n",
"# Derive the checkpoint directory from bagging_num so the checkpoints that are loaded and\n",
"# the output folder (named after bagging_num) cannot drift apart.\n",
"base_path = rf\"D:\\Deep_Learning\\WasteNet\\logs\\undersample+bagging\\effnetv2_m_{bagging_num}\"\n",
"model = load_effnetv2().to(device)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "fb8d7305-3a0b-4902-b4ff-6cfbe857dbbd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 11%|██████████████████████▎ | 1/9 [00:12<01:39, 12.41s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[acc] 0.7290322580645161\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 22%|████████████████████████████████████████████▋ | 2/9 [00:56<03:35, 30.76s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[acc] 0.7259036144578314\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 33%|███████████████████████████████████████████████████████████████████ | 3/9 [01:19<02:43, 27.23s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[acc] 0.8863636363636364\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 44%|█████████████████████████████████████████████████████████████████████████████████████████▎ | 4/9 [01:42<02:08, 25.68s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[acc] 0.7586805555555556\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 56%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 5/9 [02:31<02:16, 34.09s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[acc] 0.7167414050822123\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 67%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 6/9 [02:43<01:20, 26.73s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[acc] 0.7739130434782608\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 7/9 [02:51<00:40, 20.40s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[acc] 0.7867647058823529\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 89%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 8/9 [03:00<00:16, 16.72s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[acc] 0.8296703296703297\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [03:09<00:00, 21.00s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[acc] 0.7298850574712644\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Evaluate each held-out sequence with its own checkpoint, print its accuracy, and\n",
"# save per-sample softmax scores as CSV under <out_path>/<bagging_num>/.\n",
"num_sequences = 9\n",
"\n",
"# Loop-invariant setup: preprocessing and output folder are the same for every sequence.\n",
"transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
"])\n",
"pred_dir = osp.join(out_path, str(bagging_num))\n",
"os.makedirs(pred_dir, exist_ok=True)\n",
"\n",
"for seq_num in tqdm(range(num_sequences)):\n",
"    # Sort for a deterministic pick (glob order is OS-dependent) and fail with a clear\n",
"    # message instead of a bare IndexError when the checkpoint is missing.\n",
"    ckpt_dir = osp.join(base_path, f\"sequence{seq_num}\")\n",
"    pth_paths = sorted(glob(osp.join(ckpt_dir, \"*.pth\")))\n",
"    if not pth_paths:\n",
"        raise FileNotFoundError(f\"no .pth checkpoint found in {ckpt_dir}\")\n",
"    if len(pth_paths) > 1:\n",
"        print(f\"[warn] {len(pth_paths)} checkpoints in {ckpt_dir}; using {pth_paths[0]}\")\n",
"    pth_path = pth_paths[0]\n",
"\n",
"    # strict=False tolerates key mismatches, so report them: the same model object is reused\n",
"    # across iterations, so missing keys would silently keep the previous checkpoint's weights.\n",
"    incompatible = model.load_state_dict(torch.load(pth_path, map_location=device), strict=False)\n",
"    if incompatible.missing_keys or incompatible.unexpected_keys:\n",
"        print(f\"[warn] {pth_path}: {len(incompatible.missing_keys)} missing / \"\n",
"              f\"{len(incompatible.unexpected_keys)} unexpected state_dict keys\")\n",
"    model.eval()\n",
"\n",
"    # pickle.load can execute arbitrary code; only load these locally generated files.\n",
"    with open(f\"test_dset{seq_num}.pkl\", \"rb\") as f:\n",
"        test_dset = pickle.load(f)\n",
"    test_datasets = WasteNetDataset(test_dset, IMG_WIDTH, IMG_HEIGHT, transform, True)\n",
"    test_loader = DataLoader(test_datasets, batch_size=4, num_workers=0)\n",
"\n",
"    rows = []\n",
"    # y_pred / y_true are not used in this cell (presumably kept for confusion_matrix).\n",
"    y_pred, y_true = [], []\n",
"    correct_num = 0\n",
"    pred_target_num = 0\n",
"    # Inference only: skip building the autograd graph to save memory and time.\n",
"    with torch.no_grad():\n",
"        for img, labels, path1, path2, path3 in test_loader:\n",
"            img, labels = img.to(device), labels.to(device)\n",
"            out = nn.functional.softmax(model(img), dim=1)\n",
"            pred_index = torch.argmax(out, dim=1)\n",
"            correct_num += torch.sum(pred_index == labels).item()\n",
"            pred_target_num += out.shape[0]\n",
"\n",
"            # Copy the batch to the host once rather than syncing per element.\n",
"            probs = out.cpu().tolist()\n",
"            for p1, p2, p3, prob, label in zip(path1, path2, path3, probs, labels.cpu().tolist()):\n",
"                rows.append([p1, p2, p3, float(prob[0]), float(prob[1]), int(label)])\n",
"\n",
"            y_pred.extend(pred_index.cpu().numpy())\n",
"            y_true.extend(labels.cpu().numpy())\n",
"\n",
"    if pred_target_num == 0:\n",
"        print(f\"[warn] sequence{seq_num}: empty test set, skipped\")\n",
"        continue\n",
"    print(f\"[acc] {correct_num / pred_target_num}\")\n",
"    # NOTE(review): \"pred2\" holds column 1 of the softmax output (class-1 probability).\n",
"    # The name is kept for existing consumers of these CSVs; consider renaming to \"pred1\".\n",
"    out_df = pd.DataFrame(rows, columns=[\"path1\", \"path2\", \"path3\", \"pred0\", \"pred2\", \"label\"])\n",
"    out_df.to_csv(osp.join(pred_dir, f\"bagging{bagging_num}, sequence{seq_num}.csv\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}