{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import Models as M\n",
    "import Loaders\n",
    "import defs\n",
    "import lightning as L\n",
    "import torchmetrics\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint\n",
    "from lightning.pytorch.loggers import TensorBoardLogger\n",
    "from tqdm.notebook import tqdm\n",
    "import warnings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_seed(seed: int) -> None:\n",
    "    \"\"\"Seed every RNG in use so runs are reproducible.\n",
    "\n",
    "    Seeds Python's `random`, numpy, and torch (CPU + all CUDA devices),\n",
    "    and forces deterministic cuDNN kernels with autotuning disabled.\n",
    "    \"\"\"\n",
    "    import random  # stdlib RNG; seeded too so the whole run is reproducible\n",
    "    random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    # benchmark mode selects kernels non-deterministically; turn it off too\n",
    "    torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility + runtime configuration used by the cells below.\n",
    "setup_seed(2023)\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "Fold = 4  # which of the 5 CV folds (created below) to train/evaluate on\n",
    "numWorkers = 0  # DataLoader workers; 0 = load in the main process\n",
    "batchSize = 17"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the cross-validation folds on disk and build the fold dataloaders.\n",
    "Dataset = Loaders.RARP_DatasetCreator(\n",
    "    \"./DataSet_Crop1\",\n",
    "    FoldSeed=505,\n",
    "    createFile=True,\n",
    "    SavePath=\"./DataSetCrop1\",\n",
    "    Fold=5,\n",
    "    removeBlackBar=False\n",
    ")\n",
    "cropSize = 256\n",
    "\n",
    "Dataset.CreateFolds()\n",
    "\n",
    "# Keep the 10 checkpoints with the best validation accuracy.\n",
    "# NOTE(review): this instance is reused by the trainer cell at the bottom;\n",
    "# re-running that cell without re-running this one keeps the callback's\n",
    "# previous best-model state.\n",
    "checkPtCallback = ModelCheckpoint(monitor='val_acc', save_top_k=10, mode='max')\n",
    "    \n",
    "# GPU-side training augmentation pipeline (runs as an nn.Sequential).\n",
    "# NOTE(review): fill=5 pads RandomAffine borders with pixel value 5 — confirm\n",
    "# this is intentional and not a typo.\n",
    "traintransform = torch.nn.Sequential(\n",
    "    transforms.Normalize(Dataset.mean, Dataset.std),\n",
    "    transforms.Resize(cropSize, antialias=True),\n",
    "    transforms.RandomHorizontalFlip(0.6),\n",
    "    transforms.RandomAffine(\n",
    "        degrees=(-5, 5), translate=(0, 0.05), scale=(0.9, 1.1), \n",
    "        fill=5\n",
    "    ),\n",
    "    transforms.RandomResizedCrop((224, 224), scale=(0.35, 1), antialias=True),\n",
    ").to(device)\n",
    "\n",
    "# Deterministic val/test pipelines: resize, centre-crop, normalise.\n",
    "valtransform = torch.nn.Sequential(\n",
    "    transforms.Resize(256, antialias=True),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.Normalize(Dataset.mean, Dataset.std)\n",
    ").to(device)\n",
    "\n",
    "testtransform =  torch.nn.Sequential(\n",
    "    transforms.Resize(256, antialias=True),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.Normalize(Dataset.mean, Dataset.std)\n",
    ").to(device)\n",
    "\n",
    "# Directory of the chosen fold; samples are stored as .npy files and loaded\n",
    "# as tensors by defs.load_file_tensor.\n",
    "rootFile = Dataset.CVS_File.parent.parent/f\"fold_{Fold}\"\n",
    "\n",
    "trainDataset = torchvision.datasets.DatasetFolder(\n",
    "    str (rootFile/\"train\"),\n",
    "    loader=defs.load_file_tensor,\n",
    "    extensions=\"npy\",\n",
    "    transform=traintransform\n",
    ")\n",
    "\n",
    "valDataset = torchvision.datasets.DatasetFolder(\n",
    "    str (rootFile/\"val\"),\n",
    "    loader=defs.load_file_tensor,\n",
    "    extensions=\"npy\",\n",
    "    transform=valtransform\n",
    ")\n",
    "\n",
    "testDataset = torchvision.datasets.DatasetFolder(\n",
    "    str (rootFile/\"test\"),\n",
    "    loader=defs.load_file_tensor,\n",
    "    extensions=\"npy\",\n",
    "    transform=testtransform\n",
    ")\n",
    "\n",
    "# Only the training loader shuffles; batch size / workers from the config cell.\n",
    "Train_DataLoader = DataLoader(\n",
    "    trainDataset, \n",
    "    batch_size=batchSize, \n",
    "    num_workers=numWorkers, \n",
    "    shuffle=True, \n",
    "    pin_memory=True)\n",
    "Val_DataLoader = DataLoader(\n",
    "    valDataset, \n",
    "    batch_size=batchSize, \n",
    "    num_workers=numWorkers, \n",
    "    shuffle=False, \n",
    "    pin_memory=True)\n",
    "Test_DataLoader = DataLoader(\n",
    "    testDataset, \n",
    "    batch_size=batchSize, \n",
    "    num_workers=numWorkers, \n",
    "    shuffle=False, \n",
    "    pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull one training batch for smoke-testing the models in the cells below.\n",
    "img, label = next(iter(Train_DataLoader))\n",
    "img = img.float().to(device)\n",
    "label = label.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a pretrained ResNet18 checkpoint and wrap it in the U-Net variant.\n",
    "# NOTE(review): strict=False silently drops mismatched checkpoint keys —\n",
    "# confirm the intended weights actually load.\n",
    "modelPath = Path(\"./log_restnet18_X6/lightning_logs/version_14/checkpoints/epoch=22-step=92.ckpt\")\n",
    "\n",
    "model = M.ResNet18_UNet(M.RARP_NVB_ResNet18.load_from_checkpoint(modelPath, strict=False)).to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward the sample batch; presumably p = classification logits and\n",
    "# u = the U-Net decoder output — confirm against Models.ResNet18_UNet.\n",
    "p, u = model(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show class probabilities next to the second model output.\n",
    "torch.sigmoid(p), u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Visualise the first sample of the second model output.\n",
    "# Use a dedicated name instead of clobbering `img` (the input batch tensor):\n",
    "# later cells still feed `img` to the ensemble members, and overwriting it\n",
    "# with a numpy HWC array would break them on re-run.\n",
    "_, axis = plt.subplots(2, 2, figsize=(9, 9))\n",
    "\n",
    "img_vis = u[0].detach().cpu().numpy().transpose((1, 2, 0))\n",
    "# Undo Normalize(mean, std), then rescale to [0, 1] for imshow.\n",
    "img_vis = np.clip((Dataset.std * img_vis + Dataset.mean) / 255, 0, 1)\n",
    "\n",
    "axis[0][0].imshow(img_vis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints for the three ensemble members (ResNet50, different runs).\n",
    "# NOTE(review): strict=False drops mismatched checkpoint keys silently.\n",
    "listCkpt = [\n",
    "    Path(\"./log_ResNet50_X2/lightning_logs/version_4/checkpoints/epoch=13-step=28.ckpt\"),\n",
    "    Path(\"./log_ResNet50_X2/lightning_logs/version_6/checkpoints/epoch=9-step=20.ckpt\"),\n",
    "    Path(\"./log_ResNet50_X2/lightning_logs/version_11/checkpoints/epoch=8-step=18.ckpt\")\n",
    "]\n",
    "\n",
    "models = [M.RARP_NVB_ResNet50.load_from_checkpoint(ckpt, strict=False) for ckpt in listCkpt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the members (eval mode: no dropout, fixed batch-norm stats) on GPU.\n",
    "models = [m.eval().to(device) for m in models]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One probability per member for the sample batch, shown against the labels.\n",
    "p = [torch.sigmoid(m(img)) for m in models]\n",
    "torch.cat(p, dim=1), label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack member probabilities into shape (batch, n_members).\n",
    "# NOTE(review): `test` is never used afterwards, and both names are\n",
    "# reassigned to unrelated random tensors in the k-NN demo cell below.\n",
    "data = torch.cat(p, dim=1)\n",
    "test = label.unsqueeze(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple ensembling baseline: average the member probabilities.\n",
    "data1 = data.mean(dim=1).detach()\n",
    "label1 = label.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of the averaged-probability ensemble on this single batch only.\n",
    "torchmetrics.Accuracy('binary').to(device)(data1, label1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standalone k-NN distance demo on random data.\n",
    "# Fresh names so we do not clobber `data`/`test` from the ensemble cells\n",
    "# above (the original reuse made re-running earlier cells order-dependent).\n",
    "knn_points = torch.randn(100, 10)\n",
    "knn_query = torch.randn(1, 10)\n",
    "\n",
    "# Euclidean distance from the query to every point; keep the 2 nearest.\n",
    "dist = torch.norm(knn_points - knn_query, dim=1, p=None)\n",
    "knn = dist.topk(2, largest=False)\n",
    "knn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "\n",
    "\n",
    "class RARP_Ensemble_ResNet50(L.LightningModule):\n",
    "    \"\"\"Stacking ensemble over frozen ResNet50 members.\n",
    "\n",
    "    Each member produces one logit; the sigmoids of those logits are fed to\n",
    "    a small trainable MLP head that outputs the final single logit.\n",
    "\n",
    "    The members are held in a plain Python list rather than an nn.ModuleList,\n",
    "    so Lightning does not register, checkpoint, move, or switch them back to\n",
    "    train mode — presumably intentional for frozen members, but the caller\n",
    "    must move them to the right device itself (TODO confirm).\n",
    "    \"\"\"\n",
    "    def __init__(self, models:list):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.ListModels = models        \n",
    "        input_p = len(self.ListModels)\n",
    "        \n",
    "        # Tiny MLP head over the member probabilities: n -> n+1 -> n -> 1.\n",
    "        self.classifier = torch.nn.Sequential(\n",
    "            torch.nn.Linear(input_p, input_p + 1),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(input_p + 1, input_p),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(input_p, 1)            \n",
    "        )\n",
    "        \n",
    "        #self.classifier = torch.nn.Linear(input_p, 1)\n",
    "        \n",
    "        #self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) \n",
    "        # NOTE(review): the MSE loss sees the raw head output (a logit) while\n",
    "        # the metrics below see sigmoid(logit) — confirm this is intended.\n",
    "        self.lossFN = torch.nn.MSELoss()\n",
    "        # Separate metric objects per stage so their running states never mix.\n",
    "        self.train_acc = torchmetrics.Accuracy('binary')\n",
    "        self.val_acc = torchmetrics.Accuracy('binary')\n",
    "        self.test_acc = torchmetrics.Accuracy('binary')\n",
    "        self.f1Score = torchmetrics.F1Score('binary')\n",
    "        self.f1ScoreTest = torchmetrics.F1Score('binary')\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Return the head's logit for a batch of images.\"\"\"\n",
    "        data = data.float()\n",
    "        # Members are frozen: no gradients flow into them, only into the head.\n",
    "        with torch.no_grad():\n",
    "            p = [m(data) for m in self.ListModels]\n",
    "            p = torch.sigmoid(torch.cat(p, dim=1))\n",
    "        x = self.classifier(p)\n",
    "        return x\n",
    "    \n",
    "    def _shared_step(self, batch):\n",
    "        \"\"\"Common train/val/test logic: (loss, labels, sigmoid probabilities).\"\"\"\n",
    "        img, label = batch\n",
    "        label = label.float()\n",
    "        pred = self.forward(img).flatten()\n",
    "        loss = self.lossFN(pred, label)\n",
    "        \n",
    "        predicted_labels = torch.sigmoid(pred)\n",
    "        \n",
    "        return loss, label, predicted_labels\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        loss, true_labels, predicted_labels = self._shared_step(batch)\n",
    "\n",
    "        self.log(\"train_loss\", loss)\n",
    "        self.train_acc.update(predicted_labels, true_labels)\n",
    "        self.log(\"train_acc\", self.train_acc, on_epoch=True, on_step=False)\n",
    "\n",
    "        return loss\n",
    "    \n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        loss, true_labels, predicted_labels = self._shared_step(batch)\n",
    "        self.log(\"val_loss\", loss)\n",
    "        self.val_acc.update(predicted_labels, true_labels)\n",
    "        self.f1Score.update(predicted_labels, true_labels)\n",
    "        self.log(\"val_acc\", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)\n",
    "        self.log(\"val_f1\", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        loss, true_labels, predicted_labels = self._shared_step(batch)\n",
    "        self.test_acc.update(predicted_labels, true_labels)\n",
    "        self.f1ScoreTest.update(predicted_labels, true_labels)\n",
    "        self.log(\"test_acc\", self.test_acc, on_epoch=True, on_step=False)\n",
    "        self.log(\"test_f1\", self.f1ScoreTest, on_epoch=True, on_step=False)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        # Only the trainable head is optimized; frozen members are excluded.\n",
    "        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=1e-4) \n",
    "        return [optimizer]\n",
    "        \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the trainable ensemble head on top of the frozen members.\n",
    "Model = RARP_Ensemble_ResNet50(models)\n",
    "LogFileName = \"log_debug\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `parameters` without parentheses only displays the bound-method repr.\n",
    "# Call named_parameters() and summarise shapes so the output is informative.\n",
    "[(name, tuple(p.shape)) for name, p in Model.classifier.named_parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Blanket warning suppression keeps the output clean but also hides\n",
    "# genuinely useful warnings.\n",
    "warnings.simplefilter(\"ignore\")\n",
    "\n",
    "trainer = L.Trainer(\n",
    "    accelerator='gpu', \n",
    "    devices=1, \n",
    "    logger=TensorBoardLogger(save_dir=f\"./{LogFileName}\"),\n",
    "    log_every_n_steps=1, \n",
    "    # NOTE(review): checkPtCallback was created in the data-setup cell;\n",
    "    # re-running only this cell reuses its previous best-model state.\n",
    "    callbacks=checkPtCallback,\n",
    "    max_epochs=50,\n",
    ")\n",
    "\n",
    "trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)\n",
    "        #trainer.callbacks\n",
    "# Evaluate the best checkpoint (by val_acc) on the held-out test fold.\n",
    "trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path=\"best\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "TestRARP",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}