{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "211db4dd-7709-4bcc-a5ee-60fc64929069",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../src\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fb9b463b-66d9-4ead-9f39-82e2ea1408c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from glob import glob\n",
    "import os\n",
    "import os.path as osp\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from tensorboardX import SummaryWriter\n",
    "from sklearn.metrics import confusion_matrix\n",
    "import seaborn as sn\n",
    "from tqdm import tqdm\n",
    "import argparse\n",
    "from torchvision import transforms\n",
    "from datasets import WasteNetDataset\n",
    "from torch.utils.data import DataLoader\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86b21bc3-62b3-4a11-88fe-1c091a00f5b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-time preprocessing (kept for provenance): dump each sequence's test split to a pkl file.\n",
    "# import pickle\n",
    "# for i in range(9):\n",
    "#     data_table = pd.read_csv(\"../input/train.csv\", header=None).values.tolist()\n",
    "#     _, test_dset = train_test_split_by_sequence(data_table, i, True)\n",
    "#     with open(f\"test_dset{i}.pkl\", \"wb\") as f:\n",
    "#         pickle.dump(test_dset, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "59770128-4137-463b-af68-ca02463e97b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Use device: cuda\n"
     ]
    }
   ],
   "source": [
    "# Fixed settings: input image size, compute device, prediction-CSV output root.\n",
    "IMG_WIDTH = 384\n",
    "IMG_HEIGHT = 288\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Raw string keeps the Windows backslashes literal. The original was an\n",
    "# f-string with no placeholders and invalid escape sequences (\\D, \\W, \\l, \\p),\n",
    "# which only worked by accident (DeprecationWarning); r\"...\" matches the\n",
    "# style already used for base_path below. Value is unchanged.\n",
    "out_path = r\"D:\\Deep_Learning\\WasteNet\\logs\\pred_csv\"\n",
    "print(f\"Use device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "91c591eb-a592-436c-9289-e98f2841789d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# パラメータ系\n",
    "base_path = r\"D:\\Deep_Learning\\WasteNet\\logs\\undersample+bagging\\effnetv2_m_7\"\n",
    "bagging_num = 7\n",
    "model = load_effnetv2().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fb8d7305-3a0b-4902-b4ff-6cfbe857dbbd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|██████████████████████▎                                                                                                                                                                                  | 1/9 [00:12<01:39, 12.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[acc] 0.7290322580645161\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|████████████████████████████████████████████▋                                                                                                                                                            | 2/9 [00:56<03:35, 30.76s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[acc] 0.7259036144578314\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███████████████████████████████████████████████████████████████████                                                                                                                                      | 3/9 [01:19<02:43, 27.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[acc] 0.8863636363636364\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|█████████████████████████████████████████████████████████████████████████████████████████▎                                                                                                               | 4/9 [01:42<02:08, 25.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[acc] 0.7586805555555556\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                         | 5/9 [02:31<02:16, 34.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[acc] 0.7167414050822123\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 67%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                   | 6/9 [02:43<01:20, 26.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[acc] 0.7739130434782608\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                            | 7/9 [02:51<00:40, 20.40s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[acc] 0.7867647058823529\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 89%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                      | 8/9 [03:00<00:16, 16.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[acc] 0.8296703296703297\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [03:09<00:00, 21.00s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[acc] 0.7298850574712644\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-sequence evaluation: for each of the 9 sequence splits, load the matching\n",
    "# bagged model checkpoint, predict on the held-out test split, report accuracy,\n",
    "# and dump per-sample softmax probabilities to a CSV under out_path/bagging_num.\n",
    "\n",
    "# Loop-invariant setup hoisted out of the loop: the output directory and the\n",
    "# normalization transform are identical for every sequence.\n",
    "os.makedirs(osp.join(out_path, str(bagging_num)), exist_ok=True)\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "])\n",
    "\n",
    "for seq_num in tqdm(range(9)):\n",
    "    pth_path = glob(osp.join(base_path, f\"sequence{seq_num}\", \"*.pth\"))[0]\n",
    "    model.load_state_dict(torch.load(pth_path, map_location=device), strict=False)\n",
    "    model.eval()\n",
    "\n",
    "    with open(f\"test_dset{seq_num}.pkl\", \"rb\") as f:\n",
    "        test_dset = pickle.load(f)\n",
    "    test_datasets = WasteNetDataset(test_dset, IMG_WIDTH, IMG_HEIGHT, transform, True)\n",
    "    test_loader = DataLoader(test_datasets, batch_size=4, num_workers=0)\n",
    "\n",
    "    out_df = []\n",
    "    y_pred, y_true = [], []\n",
    "    correct_num = 0\n",
    "    pred_target_num = 0\n",
    "    # Inference only: torch.no_grad() skips autograd bookkeeping and avoids\n",
    "    # accumulating the computation graph across the whole test set.\n",
    "    with torch.no_grad():\n",
    "        for img, labels, path1, path2, path3 in test_loader:\n",
    "            img, labels = img.to(device), labels.to(device)\n",
    "            out = nn.functional.softmax(model(img), dim=1)\n",
    "            for batch_i in range(img.shape[0]):\n",
    "                out_df_row = [path1[batch_i], path2[batch_i], path3[batch_i]]\n",
    "                out_df_row.extend([float(out[batch_i][0]), float(out[batch_i][1]), int(labels[batch_i])])\n",
    "                out_df.append(out_df_row)\n",
    "            pred_target_num += out.shape[0]\n",
    "            pred_index = torch.argmax(out, dim=1)\n",
    "            correct_num += torch.sum(pred_index == labels).item()\n",
    "\n",
    "            y_pred.extend(pred_index.data.cpu().numpy())\n",
    "            y_true.extend(labels.data.cpu().numpy())\n",
    "    print(f\"[acc] {correct_num / pred_target_num}\")\n",
    "    out_df = pd.DataFrame(out_df)\n",
    "    # NOTE(review): \"pred2\" holds softmax column 1 (out[:, 1]); confirm the class\n",
    "    # it denotes really is labelled 2 in the dataset, otherwise rename to \"pred1\".\n",
    "    out_df.columns = [\"path1\", \"path2\", \"path3\", \"pred0\", \"pred2\", \"label\"]\n",
    "    out_df.to_csv(osp.join(out_path, str(bagging_num), f\"bagging{bagging_num}, sequence{seq_num}.csv\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}