{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import warnings\n",
"from pathlib import Path\n",
"\n",
"import lightning as L\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torchmetrics\n",
"import torchvision\n",
"from lightning.pytorch.callbacks import ModelCheckpoint\n",
"from lightning.pytorch.loggers import TensorBoardLogger\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import transforms\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import defs\n",
"import Loaders\n",
"import Models as M"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def setup_seed(seed):\n",
"    \"\"\"Seed every RNG this notebook relies on and force deterministic cuDNN.\n",
"\n",
"    Seeds Python's `random`, NumPy's global RNG and torch (CPU and all CUDA\n",
"    devices). cuDNN autotuning (`benchmark`) is disabled too, because it may\n",
"    select different kernels from run to run even with `deterministic=True`.\n",
"\n",
"    Args:\n",
"        seed: integer seed applied to every generator.\n",
"    \"\"\"\n",
"    random.seed(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)\n",
"    torch.backends.cudnn.deterministic = True\n",
"    torch.backends.cudnn.benchmark = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run configuration: cross-validation fold to use and DataLoader settings.\n",
"Fold = 4\n",
"batchSize = 17\n",
"numWorkers = 0\n",
"\n",
"# Seed all RNGs before anything stochastic runs.\n",
"setup_seed(2023)\n",
"\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda:0\")\n",
"else:\n",
"    device = torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Dataset = Loaders.RARP_DatasetCreator(\n",
"    \"./DataSet_Crop1\",\n",
"    FoldSeed=505,\n",
"    createFile=True,\n",
"    SavePath=\"./DataSetCrop1\",\n",
"    Fold=5,\n",
"    removeBlackBar=False\n",
")\n",
"cropSize = 256   # images are resized to this size before cropping\n",
"inputSize = 224  # final spatial size fed to the networks\n",
"\n",
"Dataset.CreateFolds()\n",
"\n",
"# Keep the 10 checkpoints with the highest validation accuracy.\n",
"checkPtCallback = ModelCheckpoint(monitor='val_acc', save_top_k=10, mode='max')\n",
"\n",
"traintransform = torch.nn.Sequential(\n",
"    transforms.Normalize(Dataset.mean, Dataset.std),\n",
"    transforms.Resize(cropSize, antialias=True),\n",
"    transforms.RandomHorizontalFlip(0.6),\n",
"    transforms.RandomAffine(\n",
"        degrees=(-5, 5), translate=(0, 0.05), scale=(0.9, 1.1),\n",
"        # NOTE(review): applied after Normalize, so fill=5 is in normalized units -- confirm intended.\n",
"        fill=5\n",
"    ),\n",
"    transforms.RandomResizedCrop((inputSize, inputSize), scale=(0.35, 1), antialias=True),\n",
").to(device)\n",
"\n",
"\n",
"def _make_eval_transform():\n",
"    \"\"\"Deterministic resize -> center crop -> normalize pipeline shared by val and test.\"\"\"\n",
"    return torch.nn.Sequential(\n",
"        transforms.Resize(cropSize, antialias=True),\n",
"        transforms.CenterCrop(inputSize),\n",
"        transforms.Normalize(Dataset.mean, Dataset.std)\n",
"    ).to(device)\n",
"\n",
"\n",
"def _make_split_dataset(split, transform):\n",
"    \"\"\"Index the .npy samples under `rootFile/split` (one sub-folder per class).\"\"\"\n",
"    return torchvision.datasets.DatasetFolder(\n",
"        str(rootFile / split),\n",
"        loader=defs.load_file_tensor,\n",
"        extensions=\"npy\",\n",
"        transform=transform\n",
"    )\n",
"\n",
"\n",
"def _make_loader(dataset, shuffle):\n",
"    \"\"\"Wrap `dataset` in a DataLoader with the notebook-wide batch/worker settings.\"\"\"\n",
"    return DataLoader(\n",
"        dataset,\n",
"        batch_size=batchSize,\n",
"        num_workers=numWorkers,\n",
"        shuffle=shuffle,\n",
"        # Pinned memory only speeds up host-to-GPU copies; skip it on CPU-only runs.\n",
"        pin_memory=device.type == \"cuda\")\n",
"\n",
"\n",
"valtransform = _make_eval_transform()\n",
"testtransform = _make_eval_transform()\n",
"\n",
"rootFile = Dataset.CVS_File.parent.parent / f\"fold_{Fold}\"\n",
"\n",
"trainDataset = _make_split_dataset(\"train\", traintransform)\n",
"valDataset = _make_split_dataset(\"val\", valtransform)\n",
"testDataset = _make_split_dataset(\"test\", testtransform)\n",
"\n",
"Train_DataLoader = _make_loader(trainDataset, shuffle=True)\n",
"Val_DataLoader = _make_loader(valDataset, shuffle=False)\n",
"Test_DataLoader = _make_loader(testDataset, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pull a single training batch to probe the models interactively below.\n",
"img, label = next(iter(Train_DataLoader))\n",
"img, label = img.float().to(device), label.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"modelPath = Path(\"./log_restnet18_X6/lightning_logs/version_14/checkpoints/epoch=22-step=92.ckpt\")\n",
"\n",
"# Restore the trained ResNet18 model (strict=False tolerates state-dict key mismatches)\n",
"# and hand it to the ResNet18_UNet wrapper.\n",
"backbone = M.RARP_NVB_ResNet18.load_from_checkpoint(modelPath, strict=False)\n",
"model = M.ResNet18_UNet(backbone).to(device)\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inference only: don't record an autograd graph for this forward pass.\n",
"with torch.no_grad():\n",
"    p, u = model(img)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# First output squashed through a sigmoid (treated as logits), second output as-is.\n",
"(p.sigmoid(), u)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axis = plt.subplots(2, 2, figsize=(9, 9))\n",
"\n",
"# New names on purpose: `img` must keep holding the input batch, which the ensemble cells below feed to the models.\n",
"out_img = u[0].detach().cpu().numpy().transpose((1, 2, 0))  # assumes u[0] is CHW; imshow wants HWC\n",
"# Undo Normalize(Dataset.mean, Dataset.std); the /255 assumes those stats are on a 0-255 scale -- TODO confirm.\n",
"out_rgb = np.clip((Dataset.std * out_img + Dataset.mean) / 255, 0, 1)\n",
"\n",
"axis[0][0].imshow(out_rgb)\n",
"axis[0][0].set_title(\"model output u[0]\")\n",
"axis[0][0].axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Checkpoints of the three ResNet50 runs combined by the ensemble below.\n",
"listCkpt = list(map(Path, [\n",
"    \"./log_ResNet50_X2/lightning_logs/version_4/checkpoints/epoch=13-step=28.ckpt\",\n",
"    \"./log_ResNet50_X2/lightning_logs/version_6/checkpoints/epoch=9-step=20.ckpt\",\n",
"    \"./log_ResNet50_X2/lightning_logs/version_11/checkpoints/epoch=8-step=18.ckpt\",\n",
"]))\n",
"\n",
"models = []\n",
"for ckpt in listCkpt:\n",
"    # strict=False tolerates missing/unexpected keys in the checkpoint state dict.\n",
"    models.append(M.RARP_NVB_ResNet50.load_from_checkpoint(ckpt, strict=False))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Switch every ensemble member to inference mode and move it to `device`\n",
"# (Module.eval and Module.to both modify the module in place).\n",
"for m in models:\n",
"    m.eval()\n",
"    m.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inference only: collect each member's probability without building autograd graphs.\n",
"with torch.no_grad():\n",
"    p = [torch.sigmoid(m(img)) for m in models]\n",
"torch.cat(p, dim=1), label"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = torch.cat(p, 1)  # one column per member, assuming each outputs (batch, 1)\n",
"test = label[:, None]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Soft voting: average the members' probabilities for each sample.\n",
"label1 = label.to(torch.float32)\n",
"data1 = torch.mean(data, dim=1).detach()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"binaryAcc = torchmetrics.Accuracy(task=\"binary\").to(device)  # default threshold 0.5 on probabilities\n",
"binaryAcc(data1, label1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy k-nearest-neighbour lookup on random points, independent of the ensemble analysis.\n",
"# Distinct names so `data` / `test` from the ensemble cells are not clobbered.\n",
"demo_points = torch.randn(100, 10)\n",
"demo_query = torch.randn(1, 10)\n",
"\n",
"# Euclidean distance from the query to every point (torch.norm is deprecated).\n",
"dist = torch.linalg.vector_norm(demo_points - demo_query, dim=1)\n",
"knn = dist.topk(2, largest=False)\n",
"knn"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class RARP_Ensemble_ResNet50(L.LightningModule):\n",
"    \"\"\"Stacking ensemble: a small MLP learns to combine frozen base classifiers.\n",
"\n",
"    The base models run under ``torch.no_grad()``. Their outputs are\n",
"    concatenated along dim 1, passed through a sigmoid and fed to\n",
"    ``self.classifier``, which returns one logit per sample. Only\n",
"    ``self.classifier`` is optimized.\n",
"\n",
"    Args:\n",
"        models: trained base models. ``len(models)`` is the classifier's input\n",
"            width, so each member is assumed to output one logit per sample,\n",
"            shape (batch, 1) -- TODO confirm when adding members.\n",
"            They are kept in a plain list on purpose, so they are NOT\n",
"            registered as submodules: Lightning will not switch them to train\n",
"            mode, move them to a device or save them in checkpoints. Put them\n",
"            in eval mode and on the training device before fitting.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, models: list):\n",
"        super().__init__()\n",
"\n",
"        self.ListModels = models\n",
"        input_p = len(self.ListModels)\n",
"\n",
"        # Small MLP over the stacked member probabilities -> a single logit.\n",
"        self.classifier = torch.nn.Sequential(\n",
"            torch.nn.Linear(input_p, input_p + 1),\n",
"            torch.nn.ReLU(),\n",
"            torch.nn.Linear(input_p + 1, input_p),\n",
"            torch.nn.ReLU(),\n",
"            torch.nn.Linear(input_p, 1)\n",
"        )\n",
"\n",
"        # The classifier emits a raw logit and the metrics threshold\n",
"        # sigmoid(logit) at 0.5, so the loss must treat the output as a logit\n",
"        # as well. MSE on the raw logit would pull negatives towards 0, which\n",
"        # is exactly the decision boundary.\n",
"        self.lossFN = torch.nn.BCEWithLogitsLoss()\n",
"        self.train_acc = torchmetrics.Accuracy('binary')\n",
"        self.val_acc = torchmetrics.Accuracy('binary')\n",
"        self.test_acc = torchmetrics.Accuracy('binary')\n",
"        self.f1Score = torchmetrics.F1Score('binary')\n",
"        self.f1ScoreTest = torchmetrics.F1Score('binary')\n",
"\n",
"    def forward(self, data):\n",
"        \"\"\"Return the ensemble logit for a batch of images, shape (batch, 1).\"\"\"\n",
"        data = data.float()\n",
"        with torch.no_grad():  # members stay frozen; only the classifier learns\n",
"            member_outputs = [m(data) for m in self.ListModels]\n",
"            member_probs = torch.sigmoid(torch.cat(member_outputs, dim=1))\n",
"        return self.classifier(member_probs)\n",
"\n",
"    def _shared_step(self, batch):\n",
"        \"\"\"Return (loss, float labels, positive-class probabilities) for a batch.\"\"\"\n",
"        img, label = batch\n",
"        label = label.float()\n",
"        logits = self.forward(img).flatten()\n",
"        loss = self.lossFN(logits, label)\n",
"        predicted_labels = torch.sigmoid(logits)\n",
"        return loss, label, predicted_labels\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        loss, true_labels, predicted_labels = self._shared_step(batch)\n",
"\n",
"        self.log(\"train_loss\", loss)\n",
"        self.train_acc.update(predicted_labels, true_labels)\n",
"        self.log(\"train_acc\", self.train_acc, on_epoch=True, on_step=False)\n",
"\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        loss, true_labels, predicted_labels = self._shared_step(batch)\n",
"        self.log(\"val_loss\", loss)\n",
"        self.val_acc.update(predicted_labels, true_labels)\n",
"        self.f1Score.update(predicted_labels, true_labels)\n",
"        self.log(\"val_acc\", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)\n",
"        self.log(\"val_f1\", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)\n",
"\n",
"        return loss\n",
"\n",
"    def test_step(self, batch, batch_idx):\n",
"        loss, true_labels, predicted_labels = self._shared_step(batch)\n",
"        self.log(\"test_loss\", loss)\n",
"        self.test_acc.update(predicted_labels, true_labels)\n",
"        self.f1ScoreTest.update(predicted_labels, true_labels)\n",
"        self.log(\"test_acc\", self.test_acc, on_epoch=True, on_step=False)\n",
"        self.log(\"test_f1\", self.f1ScoreTest, on_epoch=True, on_step=False)\n",
"\n",
"    def configure_optimizers(self):\n",
"        # Only the classifier head is trained; the members are frozen.\n",
"        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=1e-4)\n",
"        return [optimizer]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"LogFileName = \"log_debug\"  # TensorBoard save_dir for the trainer\n",
"Model = RARP_Ensemble_ResNet50(models)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `.parameters` without a call only displays the bound method; list each\n",
"# parameter tensor of the classifier head with its shape instead.\n",
"{name: tuple(param.shape) for name, param in Model.classifier.named_parameters()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Silence warnings only for this training/testing run, not for the whole kernel session.\n",
"with warnings.catch_warnings():\n",
"    warnings.simplefilter(\"ignore\")\n",
"\n",
"    trainer = L.Trainer(\n",
"        # Follow the `device` chosen in the config cell rather than hard-coding \"gpu\",\n",
"        # so the notebook also runs on CPU-only machines.\n",
"        accelerator=\"gpu\" if device.type == \"cuda\" else \"cpu\",\n",
"        devices=1,\n",
"        logger=TensorBoardLogger(save_dir=f\"./{LogFileName}\"),\n",
"        log_every_n_steps=1,\n",
"        callbacks=[checkPtCallback],\n",
"        max_epochs=50,\n",
"    )\n",
"\n",
"    trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)\n",
"    # \"best\" = the checkpoint with the highest val_acc tracked by checkPtCallback.\n",
"    test_results = trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path=\"best\")\n",
"\n",
"test_results"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "TestRARP",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}