{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchmetrics\n",
    "import torchvision\n",
    "import cv2\n",
    "import numpy as np\n",
    "import albumentations as A\n",
    "from albumentations.pytorch import ToTensorV2\n",
    "from torch.utils.data import DataLoader\n",
    "from lightning.pytorch import seed_everything\n",
    "import lightning.pytorch.callbacks as callbk\n",
    "from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger\n",
    "from tqdm.notebook import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "import Loaders\n",
    "import defs\n",
    "import lightning as L\n",
    "import json\n",
    "import copy\n",
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "import van\n",
    "\n",
    "def setup_seed(seed):\n",
    "    # Seed every RNG used in this notebook (torch CPU/CUDA, numpy, Lightning)\n",
    "    # so runs are repeatable; cudnn.deterministic trades speed for determinism.\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    np.random.seed(seed)\n",
    "    seed_everything(seed, workers=True)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    \n",
    "def Denorlalize (img:torch.Tensor, std, mean):\n",
    "    # Undo normalization on a CHW tensor and return an HWC numpy image in [0, 1].\n",
    "    # NOTE(review): the flip on the last axis assumes the input channels are BGR\n",
    "    # (OpenCV order) -- confirm against the loader. The name keeps its original\n",
    "    # spelling ('Denorlalize') because later cells call it by that name.\n",
    "    ImgNumpy = img.numpy().transpose((1, 2, 0))  # CHW -> HWC\n",
    "    ImgNumpy = np.clip((std * ImgNumpy + mean) , 0, 1)\n",
    "    ImgNumpy = ImgNumpy[...,::-1].copy()  # BGR -> RGB for matplotlib\n",
    "    \n",
    "    return ImgNumpy\n",
    "\n",
    "def visualize_augmentations(dataset, idx=0, samples=5):\n",
    "    # Show `samples` augmented variants of dataset[idx] next to their masks.\n",
    "    # Works on a deep copy with Normalize/ToTensorV2 stripped from the pipeline\n",
    "    # so raw images can be displayed directly.\n",
    "    dataset = copy.deepcopy(dataset)\n",
    "    dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])\n",
    "    _, ax = plt.subplots(nrows=samples, ncols=2, figsize=(10, 24))\n",
    "    for i in range(samples):\n",
    "        image, mask = dataset[idx]\n",
    "        ax[i, 0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n",
    "        ax[i, 1].imshow(mask, interpolation=\"nearest\")\n",
    "        ax[i, 0].set_title(\"Augmented image\")\n",
    "        ax[i, 1].set_title(\"Augmented mask\")\n",
    "        ax[i, 0].set_axis_off()\n",
    "        ax[i, 1].set_axis_off()\n",
    "        \n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "    \n",
    "def remove_Black_Border_mask(image, ROI_mask:np.ndarray=None):\n",
    "        # Crop the black endoscope border from a BGR frame (and optionally the\n",
    "        # matching ROI mask) by keeping the bounding box of the largest bright\n",
    "        # region after a hue-based threshold and a morphological opening.\n",
    "        copyImg = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV)\n",
    "        h = copyImg[:,:,0]\n",
    "        mask = np.ones(h.shape, dtype=np.uint8) * 255\n",
    "        th = (25, 175)  # hue band treated as background and zeroed out\n",
    "        mask[(h > th[0]) & (h < th[1])] = 0\n",
    "        copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)\n",
    "        resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)\n",
    "            \n",
    "        image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)\n",
    "        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)\n",
    "        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))\n",
    "        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)\n",
    "        # cv2.findContours returns 2 values on OpenCV>=4 and 3 on OpenCV 3.x.\n",
    "        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
    "        contours = contours[0] if len(contours) == 2 else contours[1]\n",
    "        bigCont = max(contours, key=cv2.contourArea)\n",
    "        x, y, w, h = cv2.boundingRect(bigCont)\n",
    "        # The ternary binds to the second tuple element only, so callers always\n",
    "        # receive a (cropped_image, cropped_mask_or_None) pair.\n",
    "        return image[y : y + h, x : x + w], ROI_mask[y : y + h, x : x + w] if ROI_mask is not None else None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-channel statistics; magnitudes suggest BGR order on a 0-255 scale.\n",
    "# NOTE(review): A.Normalize by default scales mean/std by max_pixel_value=255,\n",
    "# which expects [0, 1]-scaled statistics -- confirm these values match how the\n",
    "# normalization was intended.\n",
    "mean, std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])\n",
    "    \n",
    "setup_seed(2023)\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "batchSize = 16\n",
    "numWorkers = 5\n",
    "rootFile = Path(\"../dataset/ROI/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training pipeline: geometric + photometric augmentation, then normalization.\n",
    "# NOTE(review): ShiftScaleRotate already rotates up to 30 degrees and A.Rotate\n",
    "# stacks a second rotation on top -- confirm the double rotation is intended.\n",
    "traintransform = A.Compose(\n",
    "    [\n",
    "        A.Resize(224, 224, interpolation=cv2.INTER_CUBIC), \n",
    "        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),\n",
    "        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),\n",
    "        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),\n",
    "        A.Rotate(limit=(-30,30), p=0.9),\n",
    "        A.RandomGamma(p=0.7),\n",
    "        A.RandomFog(0.1, p=0.75),\n",
    "        A.RandomToneCurve(p=0.75),\n",
    "        A.Normalize(mean, std),\n",
    "        ToTensorV2()\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Validation pipeline: deterministic resize + normalize only.\n",
    "valtransform = A.Compose(\n",
    "    [\n",
    "        A.Resize(224, 224, interpolation=cv2.INTER_CUBIC),\n",
    "        A.Normalize(mean, std),\n",
    "        ToTensorV2()\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder datasets over tiff frames. NOTE(review): the semantics of\n",
    "# create_mask=True live in the project-local Loaders module (not visible\n",
    "# here); presumably it builds the segmentation mask per image -- confirm.\n",
    "trainDataset = Loaders.RARP_DatasetFolder_RoiExtractor(\n",
    "    str(rootFile/\"train\"),\n",
    "    loader=defs.load_Img,\n",
    "    extensions=\"tiff\",\n",
    "    transform=traintransform,\n",
    "    create_mask=True\n",
    ")\n",
    "valDataset = Loaders.RARP_DatasetFolder_RoiExtractor(\n",
    "    str(rootFile/\"val\"),\n",
    "    loader=defs.load_Img,\n",
    "    extensions=\"tiff\",\n",
    "    transform=valtransform,\n",
    "    create_mask=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed so the displayed augmentations are reproducible, then preview them.\n",
    "setup_seed(2025)\n",
    "visualize_augmentations(trainDataset, idx=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loader shuffles every epoch; both loaders keep worker processes\n",
    "# alive between epochs (persistent_workers) and pin host memory for faster\n",
    "# host-to-GPU copies. Fix: pin_memory was previously set only on the\n",
    "# validation loader; enabled on both for consistency.\n",
    "Train_DataLoader = DataLoader(\n",
    "    trainDataset, \n",
    "    batch_size=batchSize, \n",
    "    num_workers=numWorkers, \n",
    "    shuffle=True, \n",
    "    pin_memory=True,\n",
    "    persistent_workers=numWorkers>0\n",
    ")\n",
    "\n",
    "Val_DataLoader = DataLoader(\n",
    "    valDataset, \n",
    "    batch_size=batchSize, \n",
    "    num_workers=numWorkers, \n",
    "    shuffle=False, \n",
    "    pin_memory=True,\n",
    "    persistent_workers=numWorkers>0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UNet_RN18(torch.nn.Module):\n",
    "    # U-Net whose encoder is an ImageNet-pretrained ResNet-18. Skip features\n",
    "    # are captured with forward hooks on selected encoder blocks instead of\n",
    "    # re-implementing the encoder forward pass.\n",
    "    def _conv_block(self, in_ch, out_ch):\n",
    "        # Two size-preserving 3x3 convs, each followed by SiLU.\n",
    "        return torch.nn.Sequential(\n",
    "            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n",
    "            torch.nn.SiLU(inplace=True),\n",
    "            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n",
    "            torch.nn.SiLU(inplace=True),\n",
    "        )\n",
    "    \n",
    "    def _hook_fn(self, module, input, output):\n",
    "        # Forward hook: stash each tapped block's output for the decoder.\n",
    "        self.feature_maps.append(output)    \n",
    "    \n",
    "    def _register_encoder_hooks(self):\n",
    "        # Hooks fire in execution order, so feature_maps ends up ordered from\n",
    "        # shallowest (conv1) to deepest (layer4) after one encoder pass.\n",
    "        for layer in self.list_blocks:\n",
    "            self.hooks.append(layer.register_forward_hook(self._hook_fn))\n",
    "    \n",
    "    def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        \n",
    "        self.hooks = []\n",
    "        self.feature_maps = []\n",
    "        \n",
    "        self.encoder_base = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
    "        self.encoder_base.fc = torch.nn.Identity()  # drop the classification head\n",
    "        \n",
    "        #for parms in self.encoder_base.parameters():\n",
    "        #    parms.requires_grad = False\n",
    "        \n",
    "        # Tapped blocks: conv1 output serves as the shallowest skip, then the\n",
    "        # four residual stages.\n",
    "        self.list_blocks = [\n",
    "            self.encoder_base.conv1,\n",
    "            self.encoder_base.layer1,\n",
    "            self.encoder_base.layer2,\n",
    "            self.encoder_base.layer3,\n",
    "            self.encoder_base.layer4\n",
    "        ]\n",
    "        \n",
    "        self._register_encoder_hooks()\n",
    "        \n",
    "        self.dropout = torch.nn.Dropout2d(0.4)\n",
    "        \n",
    "        self.upConv_0 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)\n",
    "        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n",
    "        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)\n",
    "        self.upConv_3 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)\n",
    "        \n",
    "        # Extra upsampling stage beyond the four skip levels.\n",
    "        self.upConv_extra = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)\n",
    "\n",
    "        self.decoder_0 = self._conv_block(512, 256)     #0\n",
    "        self.decoder_1 = self._conv_block(256, 128)     #1\n",
    "        self.decoder_2 = self._conv_block(128, 64)      #2\n",
    "        self.decoder_3 = self._conv_block(32 + 64, 32)  #3  32 upsampled + 64 from the conv1 skip\n",
    "        \n",
    "        self.decoder_extra = self._conv_block(16, 16)\n",
    "        \n",
    "        self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # NOTE(review): feature_maps is instance state mutated on every call,\n",
    "        # so forward is not reentrant/thread-safe.\n",
    "        self.feature_maps = []\n",
    "        \n",
    "        _ = self.encoder_base(x) # forward to encoder and call hooks \n",
    "        \n",
    "        # Unpacking relies on exactly five hooks firing in registration order.\n",
    "        encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck = self.feature_maps\n",
    "                \n",
    "        # Decoder: upsample, concat the matching skip, refine; dropout between stages.\n",
    "        decoder_l3 = self.upConv_0(btlneck) \n",
    "        decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1) \n",
    "        decoder_l3 = self.decoder_0(decoder_l3)    \n",
    "        \n",
    "        decoder_l3 = self.dropout(decoder_l3)\n",
    "        \n",
    "        decoder_l2 = self.upConv_1(decoder_l3) \n",
    "        decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1) \n",
    "        decoder_l2 = self.decoder_1(decoder_l2)  \n",
    "        \n",
    "        decoder_l2 = self.dropout(decoder_l2)\n",
    "        \n",
    "        decoder_l1 = self.upConv_2(decoder_l2) \n",
    "        decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1) \n",
    "        decoder_l1 = self.decoder_2(decoder_l1)  \n",
    "        \n",
    "        decoder_l1 = self.dropout(decoder_l1)\n",
    "        \n",
    "        decoder_l0 = self.upConv_3(decoder_l1) \n",
    "        decoder_l0 = torch.cat((decoder_l0, encoder_l0), dim=1) \n",
    "        decoder_l0 = self.decoder_3(decoder_l0)\n",
    "        \n",
    "        decoder_last = self.upConv_extra(decoder_l0)\n",
    "        decoder_last = self.decoder_extra(decoder_last)\n",
    "                \n",
    "        return self.last_conv(decoder_last)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UNet_VAN_1(torch.nn.Module):\n",
    "    # U-Net-style decoder on a pretrained VAN-B1 encoder (project-local 'van'\n",
    "    # module). Skips are captured with forward hooks on the last block of each\n",
    "    # VAN stage; logits are upsampled back to the input resolution at the end.\n",
    "    def _conv_block(self, in_ch, out_ch):\n",
    "        # Two size-preserving 3x3 convs, each followed by ReLU.\n",
    "        return torch.nn.Sequential(\n",
    "            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n",
    "            torch.nn.ReLU(inplace=True),\n",
    "            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n",
    "            torch.nn.ReLU(inplace=True)\n",
    "        )\n",
    "    \n",
    "    def _hook_fn(self, module, input, output):\n",
    "        # Forward hook: stash each tapped block's output for the decoder.\n",
    "        self.feature_maps.append(output)    \n",
    "    \n",
    "    def _register_encoder_hooks(self):\n",
    "        # Hooks fire in execution order: feature_maps goes shallow -> deep.\n",
    "        for layer in self.list_blocks:\n",
    "            self.hooks.append(layer.register_forward_hook(self._hook_fn))\n",
    "    \n",
    "    def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        \n",
    "        \n",
    "        self.hooks = []\n",
    "        self.feature_maps = []\n",
    "        \n",
    "        self.encoder_base = van.van_b1(pretrained = True, num_classes = 0)\n",
    "        \n",
    "        # Tap the last block of each stage. The channel math below implies the\n",
    "        # deeper stages output 128/320/512 channels -- confirm against 'van'.\n",
    "        self.list_blocks = [\n",
    "            self.encoder_base.block1[-1],\n",
    "            self.encoder_base.block2[-1],\n",
    "            self.encoder_base.block3[-1],\n",
    "            self.encoder_base.block4[-1],\n",
    "        ]\n",
    "        \n",
    "        self._register_encoder_hooks()\n",
    "        \n",
    "        self.upConv_0 = torch.nn.ConvTranspose2d(512, 320, kernel_size=2, stride=2)\n",
    "        self.decoder_0 = self._conv_block(640, 256)  # 320 upsampled + 320 skip\n",
    "        \n",
    "        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n",
    "        self.decoder_1 = self._conv_block(256, 128)     #1\n",
    "        \n",
    "        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)\n",
    "        self.decoder_2 = self._conv_block(128, 64)      #2\n",
    "                \n",
    "        self.last_conv = torch.nn.Conv2d(64, out_channels, kernel_size=1)\n",
    "                \n",
    "    def forward(self, x):\n",
    "        # NOTE(review): feature_maps is instance state mutated on every call,\n",
    "        # so forward is not reentrant/thread-safe.\n",
    "        self.feature_maps = []\n",
    "        \n",
    "        _, _, h, w = x.shape  # remember input size for the final upsample\n",
    "        \n",
    "        _ = self.encoder_base(x) # forward to encoder and call hooks \n",
    "        \n",
    "        # Unpacking relies on exactly four hooks firing in registration order.\n",
    "        encoder_l1, encoder_l2, encoder_l3, btlneck = self.feature_maps\n",
    "                \n",
    "        decoder_l3 = self.upConv_0(btlneck) \n",
    "        decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1) \n",
    "        decoder_l3 = self.decoder_0(decoder_l3)    \n",
    "        \n",
    "        decoder_l2 = self.upConv_1(decoder_l3) \n",
    "        decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1) \n",
    "        decoder_l2 = self.decoder_1(decoder_l2)  \n",
    "        \n",
    "        decoder_l1 = self.upConv_2(decoder_l2) \n",
    "        decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1) \n",
    "        decoder_l1 = self.decoder_2(decoder_l1)  \n",
    "                                \n",
    "        # Nearest-neighbor upsample of the logits back to the input (h, w).\n",
    "        return self.last_conv(torch.nn.functional.interpolate(decoder_l1, size=(h, w), mode=\"nearest\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UNet(torch.nn.Module):\n",
    "    # Vanilla U-Net built from scratch: four down-sampling stages, a\n",
    "    # 1024-channel bottleneck, four up-sampling stages with skip connections,\n",
    "    # and a final 1x1 conv producing the logits.\n",
    "    def _conv_block(self, in_ch, out_ch):\n",
    "        # Two size-preserving 3x3 convolutions, each followed by ReLU.\n",
    "        layers = [\n",
    "            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n",
    "            torch.nn.ReLU(inplace=True),\n",
    "            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n",
    "            torch.nn.ReLU(inplace=True),\n",
    "        ]\n",
    "        return torch.nn.Sequential(*layers)\n",
    "    \n",
    "    def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        \n",
    "        # Contracting path: channel width doubles at every stage.\n",
    "        self.encoder_0 = self._conv_block(in_channels, 64)\n",
    "        self.encoder_1 = self._conv_block(64, 128)\n",
    "        self.encoder_2 = self._conv_block(128, 256)\n",
    "        self.encoder_3 = self._conv_block(256, 512)\n",
    "        \n",
    "        self.pooling = torch.nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        \n",
    "        self.bottleneck = self._conv_block(512, 1024)\n",
    "        \n",
    "        # Expanding path: transposed convs halve channels and double spatial size.\n",
    "        self.upConv_0 = torch.nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)\n",
    "        self.upConv_1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)\n",
    "        self.upConv_2 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n",
    "        self.upConv_3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)\n",
    "\n",
    "        # Each decoder block halves the doubled (concat) channel count again.\n",
    "        self.decoder_0 = self._conv_block(1024, 512)\n",
    "        self.decoder_1 = self._conv_block(512, 256)\n",
    "        self.decoder_2 = self._conv_block(256, 128)\n",
    "        self.decoder_3 = self._conv_block(128, 64)\n",
    "        \n",
    "        self.last_conv = torch.nn.Conv2d(64, out_channels, kernel_size=1)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # Contracting path: keep each stage output for the skip connections.\n",
    "        skip_1 = self.encoder_0(x)\n",
    "        skip_2 = self.encoder_1(self.pooling(skip_1))\n",
    "        skip_3 = self.encoder_2(self.pooling(skip_2))\n",
    "        skip_4 = self.encoder_3(self.pooling(skip_3))\n",
    "        \n",
    "        features = self.bottleneck(self.pooling(skip_4))\n",
    "        \n",
    "        # Expanding path: upsample, concatenate the matching skip, refine.\n",
    "        features = self.decoder_0(torch.cat((self.upConv_0(features), skip_4), dim=1))\n",
    "        features = self.decoder_1(torch.cat((self.upConv_1(features), skip_3), dim=1))\n",
    "        features = self.decoder_2(torch.cat((self.upConv_2(features), skip_2), dim=1))\n",
    "        features = self.decoder_3(torch.cat((self.upConv_3(features), skip_1), dim=1))\n",
    "        \n",
    "        return self.last_conv(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchmetrics.classification\n",
    "\n",
    "\n",
    "class RARP_NVB_Model(L.LightningModule):\n",
    "    # LightningModule wrapping the ResNet18 U-Net with BCE-with-logits loss\n",
    "    # and binary IoU (Jaccard) metrics for train/val. Set Lambda_L1 to a float\n",
    "    # to add an L1 penalty on the decoder/upConv weights during training.\n",
    "    def __init__(self,*args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        \n",
    "        self.model = UNet_RN18(in_channels=3, out_channels=1)\n",
    "               \n",
    "        self.lr = 1E-4\n",
    "        self.Lambda_L1 = None  # L1 weight; None disables the penalty\n",
    "        self.lossFN = torch.nn.BCEWithLogitsLoss()\n",
    "        \n",
    "        self.train_IoU = torchmetrics.classification.BinaryJaccardIndex()\n",
    "        self.val_IoU = torchmetrics.classification.BinaryJaccardIndex()\n",
    "                \n",
    "    def forward(self, data):\n",
    "        data = data.float()\n",
    "        pred = self.model(data)\n",
    "        return pred\n",
    "    \n",
    "    def _shared_step(self, batch, val_step:bool = True):\n",
    "        # Returns (loss, target_mask, sigmoid_probabilities). The L1 penalty is\n",
    "        # applied only on training steps so validation loss stays comparable.\n",
    "        img, mask = batch\n",
    "        \n",
    "        mask = mask.float()\n",
    "        mask = mask.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W) to match the logits\n",
    "        prediction = self(img)\n",
    "                \n",
    "        loss = self.lossFN(prediction, mask)\n",
    "        \n",
    "        predicted_labels = torch.sigmoid(prediction)\n",
    "        \n",
    "        if not val_step:\n",
    "            if self.Lambda_L1 is not None:\n",
    "                loss_l1 = 0\n",
    "                for name, params in self.model.named_parameters():\n",
    "                    if \"decoder\" in name or \"upConv\" in name: \n",
    "                        loss_l1 += torch.norm(params, p=1)\n",
    "                    \n",
    "                loss += self.Lambda_L1 * loss_l1\n",
    "            \n",
    "        return loss, mask, predicted_labels\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        loss, true_labels, predicted_labels = self._shared_step(batch, False)\n",
    "\n",
    "        self.train_IoU.update(predicted_labels, true_labels)\n",
    "        \n",
    "        self.log(\"train_loss\", loss, on_epoch=True)\n",
    "        self.log(\"train_acc_IoU\", self.train_IoU, on_epoch=True, on_step=False)\n",
    "\n",
    "        return loss\n",
    "    \n",
    "    def on_after_backward(self):\n",
    "        # Log the global L2 norm of all gradients to spot vanishing/exploding\n",
    "        # gradients in TensorBoard.\n",
    "        total_norm = 0.0\n",
    "        for p in self.parameters():\n",
    "            if p.grad is not None:\n",
    "                param_norm = p.grad.data.norm(2)\n",
    "                total_norm += param_norm.item() ** 2\n",
    "        total_norm = total_norm ** 0.5\n",
    "        \n",
    "        self.log(\"grad_norm\", total_norm)\n",
    "        \n",
    "        # Fix: self.log only accepts numeric/tensor values; logging the string\n",
    "        # 'Vanishing gradient suspected!' raised at runtime. Emit a numeric\n",
    "        # flag under the same key instead.\n",
    "        if total_norm < 1e-8:\n",
    "            self.log(\"grad_warning\", 1.0)\n",
    "    \n",
    "    def on_train_epoch_start(self):\n",
    "        pass\n",
    "        #for parms in self.model.encoder_base.parameters():\n",
    "        #    parms.requires_grad = (self.current_epoch % 2 == 0)\n",
    "            \n",
    "    \n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        loss, true_labels, predicted_labels = self._shared_step(batch)\n",
    "        \n",
    "        self.val_IoU.update(predicted_labels, true_labels)\n",
    "        \n",
    "        self.log(\"val_loss\", loss, on_epoch=True, on_step=False, prog_bar=True)\n",
    "        self.log(\"val_acc_IoU\", self.val_IoU, on_epoch=True, on_step=False)\n",
    "        \n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) \n",
    "                \n",
    "        return [optimizer]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Model = RARP_NVB_Model()\n",
    "# NOTE(review): setup_seed runs after the model is built, so the initial\n",
    "# weights depend on the RNG state left by earlier cells -- reseed before\n",
    "# construction if exact weight reproducibility matters.\n",
    "setup_seed(2023)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpointing keeps the 5 best epochs by validation IoU (mode='max');\n",
    "# note the filename embeds val_loss, so the kept files are the best-IoU\n",
    "# epochs, not necessarily the lowest-loss ones.\n",
    "trainer = L.Trainer(\n",
    "    deterministic=True,\n",
    "    accelerator='gpu', \n",
    "    devices=1, \n",
    "    logger=TensorBoardLogger(save_dir=\"./logs_debug\"),\n",
    "    log_every_n_steps=5,   \n",
    "    callbacks=[callbk.ModelCheckpoint(monitor=\"val_acc_IoU\", filename=\"RARP-{epoch}-{val_loss:.4f}\", save_top_k=5, mode='max')],\n",
    "    max_epochs=100,\n",
    ")\n",
    "print(\"Train Phase\")\n",
    "trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a checkpoint and visualize the prediction for a single frame.\n",
    "# NOTE(review): the two cells below are copies of this one with different\n",
    "# checkpoint ids -- consider factoring into a helper function.\n",
    "trainLog = [13, 97, 0.2795]\n",
    "#trainLog = [6, 36, 0.02]\n",
    "# Fix: format val_loss with :.4f like the sibling cells so values with\n",
    "# trailing zeros (e.g. 0.2790) still resolve to the checkpoint filename.\n",
    "pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
    "Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
    "\n",
    "Model.to(device)\n",
    "Model.eval()\n",
    "\n",
    "pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/321.tiff\")\n",
    "frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
    "frameToFind, _ = remove_Black_Border_mask(frameToFind)\n",
    "frameToFind = valtransform(image=frameToFind)\n",
    "frameToFind = frameToFind[\"image\"]\n",
    "frameToFind = frameToFind.repeat(1, 1, 1, 1)  # add a batch dimension\n",
    "frameToFind = frameToFind.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    pred = Model(frameToFind)\n",
    "    pred = torch.sigmoid(pred)\n",
    "\n",
    "sample = Denorlalize(frameToFind[0].cpu(), std, mean)\n",
    "rgb_mask = np.zeros_like(sample)\n",
    "pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
    "rgb_mask[pred[:, :, 0] > 0.5] = [0, 1, 0]\n",
    "\n",
    "fig, ax = plt.subplots(1, 3, figsize=(15, 20))\n",
    "\n",
    "crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
    "\n",
    "ax[0].imshow(sample)\n",
    "ax[0].axis(\"off\")\n",
    "\n",
    "ax[1].imshow(rgb_mask)\n",
    "ax[1].axis(\"off\")\n",
    "\n",
    "ax[2].imshow(crop_imge_masked)\n",
    "ax[2].axis(\"off\")\n",
    "\n",
    "print(frameToFind.shape, pred.shape)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): duplicate of the inference cell above with a different\n",
    "# checkpoint id -- consider factoring this into a helper function.\n",
    "trainLog = [23, 64, 0.3028]\n",
    "#trainLog = [6, 36, 0.02]\n",
    "pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
    "Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
    "\n",
    "Model.to(device)\n",
    "Model.eval()\n",
    "\n",
    "pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/321.tiff\")\n",
    "frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
    "frameToFind, _ = remove_Black_Border_mask(frameToFind)\n",
    "frameToFind = valtransform(image=frameToFind)\n",
    "frameToFind = frameToFind[\"image\"]\n",
    "frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
    "frameToFind = frameToFind.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    pred = Model(frameToFind)\n",
    "    pred = torch.sigmoid(pred)\n",
    "\n",
    "sample = Denorlalize(frameToFind[0].cpu(), std, mean)\n",
    "rgb_mask = np.zeros_like(sample)    \n",
    "pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
    "rgb_mask[pred[:, :, 0] > 0.5] = [0, 1, 0]\n",
    "    \n",
    "fig, ax = plt.subplots(1, 3, figsize=(15, 20))\n",
    "\n",
    "crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
    "\n",
    "ax[0].imshow(sample)\n",
    "ax[0].axis(\"off\")\n",
    "\n",
    "ax[1].imshow(rgb_mask)\n",
    "ax[1].axis(\"off\")\n",
    "\n",
    "ax[2].imshow(crop_imge_masked)\n",
    "ax[2].axis(\"off\")\n",
    "\n",
    "print(frameToFind.shape, pred.shape)\n",
    "\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): duplicate of the inference cells above with a different\n",
    "# checkpoint id -- consider factoring this into a helper function.\n",
    "trainLog = [28, 93, 0.3467]\n",
    "#trainLog = [6, 36, 0.02]\n",
    "pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
    "Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
    "\n",
    "Model.to(device)\n",
    "Model.eval()\n",
    "\n",
    "pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/321.tiff\")\n",
    "frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
    "frameToFind, _ = remove_Black_Border_mask(frameToFind)\n",
    "frameToFind = valtransform(image=frameToFind)\n",
    "frameToFind = frameToFind[\"image\"]\n",
    "frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
    "frameToFind = frameToFind.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    pred = Model(frameToFind)\n",
    "    pred = torch.sigmoid(pred)\n",
    "\n",
    "sample = Denorlalize(frameToFind[0].cpu(), std, mean)\n",
    "rgb_mask = np.zeros_like(sample)    \n",
    "pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
    "rgb_mask[pred[:, :, 0] > 0.5] = [0, 1, 0]\n",
    "    \n",
    "fig, ax = plt.subplots(1, 3, figsize=(15, 20))\n",
    "\n",
    "crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
    "\n",
    "ax[0].imshow(sample)\n",
    "ax[0].axis(\"off\")\n",
    "\n",
    "ax[1].imshow(rgb_mask)\n",
    "ax[1].axis(\"off\")\n",
    "\n",
    "ax[2].imshow(crop_imge_masked)\n",
    "ax[2].axis(\"off\")\n",
    "\n",
    "print(frameToFind.shape, pred.shape)\n",
    "\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def catmull_rom_spline(P0, P1, P2, P3, n_points=20):\n",
    "        # Interpolate n_points along the Catmull-Rom segment between P1 and P2;\n",
    "        # P0 and P3 act as tangent controls. Points are (x, y) pairs.\n",
    "        points = []\n",
    "        for t in np.linspace(0, 1, n_points):\n",
    "            # Catmull-Rom formula\n",
    "            t2 = t * t\n",
    "            t3 = t2 * t\n",
    "            x = 0.5 * ((2 * P1[0]) +\n",
    "                    (-P0[0] + P2[0]) * t +\n",
    "                    (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t2 +\n",
    "                    (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t3)\n",
    "            \n",
    "            y = 0.5 * ((2 * P1[1]) +\n",
    "                    (-P0[1] + P2[1]) * t +\n",
    "                    (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t2 +\n",
    "                    (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t3)\n",
    "            \n",
    "            points.append((x, y))\n",
    "            \n",
    "        return points\n",
    "\n",
    "def catmull_rom_closed_loop(points, n_points=20):\n",
    "    # Run the spline over every consecutive quadruple of control points,\n",
    "    # wrapping indices around so the resulting curve is closed.\n",
    "    spline_points = []\n",
    "    n = len(points)\n",
    "    \n",
    "    for i in range(n):\n",
    "        P0 = points[(i - 1) % n]\n",
    "        P1 = points[i]\n",
    "        P2 = points[(i + 1) % n]\n",
    "        P3 = points[(i + 2) % n]\n",
    "        spline_points += catmull_rom_spline(P0, P1, P2, P3, n_points)\n",
    "        \n",
    "    return np.array(spline_points)\n",
    "\n",
    "def create_mask_from_contour(spline_points: np.ndarray, mask_size):\n",
    "    # Rasterize the closed curve into a binary (0/1) uint8 mask of mask_size.\n",
    "    smooth_curve_int = np.round(spline_points).astype(np.int32)\n",
    "    mask = np.zeros(mask_size, dtype=np.uint8)\n",
    "    \n",
    "    return cv2.fillPoly(mask, [smooth_curve_int], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import jaccard_score\n",
    "\n",
    "#trainLog = [30, 87, 0.3375]\n",
    "trainLog = [30, 87, 0.3375]\n",
    "pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
    "\n",
    "Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
    "\n",
    "Model.to(device)\n",
    "Model.eval()\n",
    "\n",
    "pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/300.tiff\")\n",
    "pathJson = Path(str(pathImg.absolute()).replace(\".tiff\", \".json\"))\n",
    "frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
    "\n",
    "archorPts = json.load(open(pathJson))\n",
    "archorPts = np.array(archorPts[\"shapes\"][0][\"points\"])\n",
    "\n",
    "h, w, _ = frameToFind.shape\n",
    "smood_perimeter = catmull_rom_closed_loop(archorPts, n_points=15)\n",
    "roi_mask = create_mask_from_contour(smood_perimeter, (h, w))\n",
    "\n",
    "frameToFind, roi_mask = remove_Black_Border_mask(frameToFind, roi_mask)\n",
    "\n",
    "sample = frameToFind.copy()\n",
    "sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)\n",
    "input_h, input_w, _ = sample.shape  \n",
    "\n",
    "frameToFind = valtransform(image=frameToFind)\n",
    "frameToFind = frameToFind[\"image\"]\n",
    "frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
    "frameToFind = frameToFind.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    pred = Model(frameToFind)\n",
    "    pred = torch.sigmoid(pred)\n",
    "\n",
    "_, _, h, w= frameToFind.shape\n",
    "\n",
    "pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
    "rgb_mask = np.zeros((h, w, 3)).astype(np.uint8)\n",
    "rgb_mask[pred[:, :, 0] > 0.5] = [0, 255, 0]\n",
    "\n",
    "pred = cv2.resize(pred, (input_w, input_h), interpolation=cv2.INTER_CUBIC)\n",
    "rgb_mask = cv2.resize(rgb_mask, (input_w, input_h), interpolation=cv2.INTER_CUBIC)\n",
    "crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
    "\n",
    "roi_maskRGB = np.zeros_like(rgb_mask)\n",
    "roi_maskRGB[roi_mask[:, :] == 1] = [40, 0, 255]\n",
    "#crop_imge_masked_gt = cv2.addWeighted(roi_maskRGB, 1, rgb_mask, 0.4, 0)\n",
    "\n",
    "crop_imge_masked_gt = cv2.addWeighted(crop_imge_masked, 1, roi_maskRGB, 0.4, 0)\n",
    "\n",
    "ytrue = roi_maskRGB[..., 2]//255\n",
    "ypred = rgb_mask[..., 1]//255\n",
    "print(np.max(ytrue), ytrue.shape)\n",
    "print(np.max(ypred), ypred.shape)\n",
    "\n",
    "iou = jaccard_score(ytrue, ypred, average=\"micro\")\n",
    "\n",
    "fig, ax = plt.subplots(4, 2, figsize=(15, 25))\n",
    "\n",
    "ax[0, 0].set_title(\"Input image\")\n",
    "ax[0, 0].imshow(sample)\n",
    "ax[0, 0].axis(\"off\")\n",
    "\n",
    "ax[0, 1].set_title(\"Output Prediction mask\")\n",
    "maskDiplay = ax[0, 1].imshow(pred, cmap=\"inferno\",)\n",
    "fig.colorbar(maskDiplay, ax=ax[0, 1], label=\"Confidence\", fraction=0.0387, pad=0.01)\n",
    "ax[0, 1].axis(\"off\")\n",
    "\n",
    "ax[1, 0].set_title(\"Thresholded mask P(x) > 0.5\")\n",
    "ax[1, 0].imshow(rgb_mask)\n",
    "ax[1, 0].axis(\"off\")\n",
    "\n",
    "ax[1, 1].set_title(\"Result\")\n",
    "ax[1, 1].imshow(crop_imge_masked)\n",
    "ax[1, 1].axis(\"off\")\n",
    "\n",
    "ax[3, 0].set_title(\"GT mask\")\n",
    "ax[3, 0].imshow(roi_maskRGB)\n",
    "ax[3, 0].axis(\"off\")\n",
    "\n",
    "#crop_imge_masked_gt\n",
    "\n",
    "ax[3, 1].set_title(\"Result + GT orverlay\")\n",
    "ax[3, 1].imshow(crop_imge_masked_gt)\n",
    "ax[3, 1].axis(\"off\")\n",
    "ax[3, 1].text(0, 1400, f\"Acc.:{iou:.4}\", fontsize = 22, color = 'w',bbox = dict(facecolor = 'red', alpha = 0.5))\n",
    "\n",
    "_, thresh = cv2.threshold(rgb_mask[..., 1], 0, 255, cv2.THRESH_BINARY)\n",
    "kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))\n",
    "morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)\n",
    "\n",
    "morph = cv2.GaussianBlur(morph, (5, 5), 0)\n",
    "\n",
    "contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
    "contours = contours[0] if len(contours) == 2 else contours[1]\n",
    "\n",
    "roi = max(contours, key=cv2.contourArea)\n",
    "x, y, w, h = cv2.boundingRect(roi)\n",
    "\n",
    "ax[2, 0].set_title(\"Crop Thresholded mask P(x) > 0.5\")\n",
    "ax[2, 0].imshow(morph[y : y + h, x : x + w], cmap = \"gray\")\n",
    "ax[2, 0].axis(\"off\")\n",
    "\n",
    "ax[2, 1].set_title(\"Crop Result\")\n",
    "ax[2, 1].imshow(cv2.bitwise_and(sample, sample, mask=morph)[y : y + h, x : x + w])\n",
    "ax[2, 1].axis(\"off\")\n",
    "\n",
    "print(frameToFind.shape, pred.shape)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dice_score(mask1: np.ndarray, mask2: np.ndarray) -> float:\n",
    "    \"\"\"\n",
    "    Calculate the DICE (Dice Similarity Coefficient) between two binary masks.\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    mask1 : np.ndarray\n",
    "        First binary mask (values must be 0 or 1).\n",
    "    mask2 : np.ndarray\n",
    "        Second binary mask (values must be 0 or 1).\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    float\n",
    "        Dice score between mask1 and mask2.\n",
    "    \"\"\"\n",
    "    # Ensure the two masks are of the same shape\n",
    "    assert mask1.shape == mask2.shape, \"Masks must have the same shape\"\n",
    "    \n",
    "    # Convert to boolean if necessary (True where mask > 0)\n",
    "    mask1_bool = mask1.astype(bool)\n",
    "    mask2_bool = mask2.astype(bool)\n",
    "\n",
    "    # Calculate intersection: number of elements where both masks are 1\n",
    "    intersection = np.logical_and(mask1_bool, mask2_bool).sum()\n",
    "\n",
    "    # Calculate the sum of each mask\n",
    "    mask1_sum = mask1_bool.sum()\n",
    "    mask2_sum = mask2_bool.sum()\n",
    "\n",
    "    # Handle edge case to avoid division by zero\n",
    "    if mask1_sum + mask2_sum == 0:\n",
    "        # If both masks are entirely zero, define Dice as 1 or 0 depending on convention\n",
    "        return 1.0 if (mask1_sum == 0 and mask2_sum == 0) else 0.0\n",
    "\n",
    "    # Calculate the DICE score\n",
    "    dice = 2.0 * intersection / (mask1_sum + mask2_sum)\n",
    "    return dice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import jaccard_score\n",
    "\n",
    "trainLog = [32, 86, 0.3223]\n",
    "pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
    "Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
    "\n",
    "Model.to(device)\n",
    "Model.eval()\n",
    "avg_IoU = 0.0\n",
    "avg_DICE = 0.0\n",
    "\n",
    "for num_case in range (291, 316):\n",
    "    pathImg = Path(f\"D:/Users/user/Downloads/dataset/RARP291-340/{num_case}.tiff\")\n",
    "    pathJson = Path(str(pathImg.absolute()).replace(\".tiff\", \".json\"))\n",
    "    frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
    "\n",
    "    archorPts = json.load(open(pathJson))\n",
    "    archorPts = np.array(archorPts[\"shapes\"][0][\"points\"])\n",
    "\n",
    "    h, w, _ = frameToFind.shape\n",
    "    smood_perimeter = catmull_rom_closed_loop(archorPts, n_points=15)\n",
    "    roi_mask = create_mask_from_contour(smood_perimeter, (h, w))\n",
    "\n",
    "    frameToFind, roi_mask = remove_Black_Border_mask(frameToFind, roi_mask)\n",
    "\n",
    "    input_h, input_w, _ = frameToFind.shape  \n",
    "\n",
    "    frameToFind = valtransform(image=frameToFind)\n",
    "    frameToFind = frameToFind[\"image\"]\n",
    "    frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
    "    frameToFind = frameToFind.to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        pred = Model(frameToFind)\n",
    "        pred = torch.sigmoid(pred)\n",
    "        \n",
    "    _, _, h, w= frameToFind.shape\n",
    "\n",
    "    pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
    "    rgb_mask = np.zeros((h, w, 3)).astype(np.uint8)\n",
    "    rgb_mask[pred[:, :, 0] > 0.5] = [0, 255, 0]\n",
    "    rgb_mask = cv2.resize(rgb_mask, (input_w, input_h), interpolation=cv2.INTER_CUBIC)\n",
    "\n",
    "    roi_maskRGB = np.zeros_like(rgb_mask)\n",
    "    roi_maskRGB[roi_mask[:, :] == 1] = [0, 0, 255]\n",
    "\n",
    "    ytrue = roi_maskRGB[..., 2]//255\n",
    "    ypred = rgb_mask[..., 1]//255\n",
    "\n",
    "    iou = jaccard_score(ytrue, ypred, average=\"micro\")\n",
    "    dice = dice_score(ytrue, ypred)\n",
    "    print(f\"{num_case} {iou:.4} {dice:.4}\")\n",
    "    avg_IoU += iou\n",
    "    avg_DICE += dice\n",
    "    \n",
    "print(f\"Avg {avg_IoU/25:.4} {avg_DICE/25:.4}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import jaccard_score\n",
    "\n",
    "trainLog = [30, 97, 0.2978]\n",
    "pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
    "Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
    "\n",
    "Model.to(device)\n",
    "Model.eval()\n",
    "\n",
    "pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/303.tiff\")\n",
    "pathJson = Path(str(pathImg.absolute()).replace(\".tiff\", \".json\"))\n",
    "frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
    "\n",
    "archorPts = json.load(open(pathJson))\n",
    "archorPts = np.array(archorPts[\"shapes\"][0][\"points\"])\n",
    "\n",
    "h, w, _ = frameToFind.shape\n",
    "smood_perimeter = catmull_rom_closed_loop(archorPts, n_points=15)\n",
    "roi_mask = create_mask_from_contour(smood_perimeter, (h, w))\n",
    "\n",
    "frameToFind, roi_mask = remove_Black_Border_mask(frameToFind, roi_mask)\n",
    "\n",
    "sample = frameToFind.copy()\n",
    "sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)\n",
    "input_h, input_w, _ = sample.shape  \n",
    "\n",
    "frameToFind = valtransform(image=frameToFind)\n",
    "frameToFind = frameToFind[\"image\"]\n",
    "frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
    "frameToFind = frameToFind.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    pred = Model(frameToFind)\n",
    "    pred = torch.sigmoid(pred)\n",
    "\n",
    "_, _, h, w= frameToFind.shape\n",
    "\n",
    "pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
    "rgb_mask = np.zeros((h, w, 3)).astype(np.uint8)\n",
    "rgb_mask[pred[:, :, 0] > 0.5] = [0, 255, 0]\n",
    "rgb_mask = cv2.resize(rgb_mask, (input_w, input_h), interpolation=cv2.INTER_CUBIC)\n",
    "\n",
    "roi_maskRGB = np.zeros_like(rgb_mask)\n",
    "roi_maskRGB[roi_mask[:, :] == 1] = [0, 0, 255]\n",
    "\n",
    "crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
    "crop_imge_masked_gt = cv2.addWeighted(crop_imge_masked, 1, roi_maskRGB, 0.4, 0)\n",
    "\n",
    "#crop_imge_masked_gt = cv2.addWeighted(roi_maskRGB, 1, rgb_mask, 0.4, 0)\n",
    "\n",
    "ytrue = roi_maskRGB[..., 2]//255\n",
    "ypred = rgb_mask[..., 1]//255\n",
    "\n",
    "iou = jaccard_score(ytrue, ypred, average=\"micro\")\n",
    "dice = dice_score(ytrue, ypred)\n",
    "print(f\"IoU: {iou:.4}\")\n",
    "print(f\"DICE: {dice:.4}\")\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, 15))\n",
    "\n",
    "ax.set_title(f\"Result {pathImg.name}\")\n",
    "ax.imshow(crop_imge_masked_gt)\n",
    "ax.axis(\"off\")\n",
    "ax.text(0, 1400, f\"IoU:{iou:.4}; DICE:{dice:.4}\", fontsize = 22, color = 'w',bbox = dict(facecolor = 'red', alpha = 0.5))\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pyRARP",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}