{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "35be6f81-5857-411d-b142-2a0155da87ee",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"# Make project modules under ../src importable (presumably where the local\n",
"# `datasets.WasteNetDataset` lives). The path is relative to the kernel's\n",
"# working directory. Guarded so re-running this cell does not append duplicates.\n",
"SRC_DIR = \"../src\"\n",
"if SRC_DIR not in sys.path:\n",
"    sys.path.append(SRC_DIR)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0b3d8d4d-3e99-452c-bb77-973bd4f6b392",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import os.path as osp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"from tensorboardX import SummaryWriter\n",
"from sklearn.metrics import confusion_matrix\n",
"import seaborn as sn\n",
"from tqdm import tqdm\n",
"from torchvision import transforms\n",
"from datasets import WasteNetDataset\n",
"from sklearn.model_selection import train_test_split\n",
"from torch.utils.data import DataLoader\n",
"from efficientnet_pytorch import EfficientNet\n",
"from efficientnet_pytorch.utils import Conv2dStaticSamePadding\n",
"\n",
"# Input image size in pixels (presumably the resize target for the dataset).\n",
"IMG_WIDTH, IMG_HEIGHT = 384, 288\n",
"\n",
"# Index N of the `sequenceN` directory held out as the test split.\n",
"VALID_SEQUENCE_NUM = 0\n",
"\n",
"# Training parameters\n",
"BATCH_SIZE = 8\n",
"EPOCH_NUM = 10\n",
"# Progress-reporting interval (presumably in iterations).\n",
"SHOW_INFO_FREQ = 100"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "27806384-7a64-405a-8201-32124538d176",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Random Seed: 999\n"
]
}
],
"source": [
"# Seed the Python, NumPy and PyTorch RNGs so runs are reproducible.\n",
"manual_seed = 999\n",
"print(\"Random Seed: \", manual_seed)\n",
"random.seed(manual_seed)\n",
"np.random.seed(manual_seed)\n",
"torch.manual_seed(manual_seed)\n",
"\n",
"# Load the sample table (the CSV has no header row). Each row becomes a list whose\n",
"# first element is a sample file path and whose last element is the label\n",
"# (presumably boolean / 0-1 -- later cells only test its truthiness).\n",
"data_table = pd.read_csv(\"../input/train.csv\", header=None).values.tolist()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d5975228-1443-4041-aec7-22532417a96e",
"metadata": {},
"outputs": [],
"source": [
"valid_sequence_num = VALID_SEQUENCE_NUM"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4e2c6727-7801-432a-95a9-6c702e50eb14",
"metadata": {},
"outputs": [],
"source": [
"# Hold out one whole sequence as the test split; every other row is training data.\n",
"held_out_name = f\"sequence{valid_sequence_num}\"\n",
"train_sets, test_sets = [], []\n",
"for row in data_table:\n",
"    sample_file_path = row[0]\n",
"    # A sample's sequence is the name of its parent directory.\n",
"    sequence_name = osp.basename(osp.dirname(sample_file_path))\n",
"    if sequence_name == held_out_name:\n",
"        test_sets.append(row)\n",
"    else:\n",
"        train_sets.append(row)\n",
"\n",
"# Fail loudly if the configured sequence matched nothing (empty test split).\n",
"assert test_sets, f\"no rows found for held-out {held_out_name!r}\""
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "6ae69ff1-9c7d-43a7-be73-557aa55f9eef",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"31821"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_sets)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "0c28cd05-fa33-4876-b5cb-ef5d10ef6a0e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"552"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(test_sets)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "e4ae69ca-05c9-4c7b-b6b1-63aa71145927",
"metadata": {},
"outputs": [],
"source": [
"# Class counts in the test split, based on the truthiness of the last column\n",
"# (the label). A generator avoids building a throwaway list, and the falsy count\n",
"# is the complement of the truthy one, so the split is scanned only once.\n",
"true_label_num = sum(1 for row in test_sets if row[-1])\n",
"false_label_num = len(test_sets) - true_label_num"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "ec747707-582c-47a9-9551-f8e7b979c298",
"metadata": {},
"outputs": [],
"source": [
"# Partition the held-out rows by the truthiness of their label (last column),\n",
"# keeping the original row order within each class.\n",
"true_test_sets = list(filter(lambda row: row[-1], test_sets))\n",
"false_test_sets = list(filter(lambda row: not row[-1], test_sets))"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "ff41beaf-e378-4ecc-8438-18ad08d7c15f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"155"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(true_test_sets)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "1e5b2719-72a6-4994-a4cf-b46fd00bd640",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"397"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"false_label_num"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "cd28b3fa-6779-4d49-9183-654779f0b933",
"metadata": {},
"outputs": [],
"source": [
"# Size of the smaller label class in the test split (presumably the per-class\n",
"# count for a balanced test set).\n",
"min_data_num = min(map(len, (true_test_sets, false_test_sets)))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "838b0106-e58d-4ac2-8188-de23a738a7ee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"155"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NOTE(review): this cell called `random.shuffle()` with no argument, which raises\n",
"# TypeError. The preceding `min_data_num` cell suggests the goal is a class-balanced\n",
"# test set, so draw `min_data_num` rows from each class (without mutating the\n",
"# per-class lists) and shuffle the combined result -- confirm this intent.\n",
"balanced_test_sets = (\n",
"    random.sample(true_test_sets, min_data_num)\n",
"    + random.sample(false_test_sets, min_data_num)\n",
")\n",
"random.shuffle(balanced_test_sets)\n",
"len(balanced_test_sets)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}