{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchmetrics\n",
"import torchvision\n",
"import cv2\n",
"import numpy as np\n",
"import albumentations as A\n",
"from albumentations.pytorch import ToTensorV2\n",
"from torch.utils.data import DataLoader\n",
"from lightning.pytorch import seed_everything\n",
"import lightning.pytorch.callbacks as callbk\n",
"from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger\n",
"from tqdm.notebook import tqdm\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"import Loaders\n",
"import defs\n",
"import lightning as L\n",
"import json\n",
"import copy\n",
"from torch.optim.lr_scheduler import LambdaLR\n",
"import van\n",
"\n",
"def setup_seed(seed):\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed_all(seed)\n",
" np.random.seed(seed)\n",
" seed_everything(seed, workers=True)\n",
" torch.backends.cudnn.deterministic = True\n",
" \n",
"def Denorlalize (img:torch.Tensor, std, mean):\n",
" ImgNumpy = img.numpy().transpose((1, 2, 0))\n",
" ImgNumpy = np.clip((std * ImgNumpy + mean) , 0, 1)\n",
" ImgNumpy = ImgNumpy[...,::-1].copy()\n",
" \n",
" return ImgNumpy\n",
"\n",
"def visualize_augmentations(dataset, idx=0, samples=5):\n",
" dataset = copy.deepcopy(dataset)\n",
" dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])\n",
" _, ax = plt.subplots(nrows=samples, ncols=2, figsize=(10, 24))\n",
" for i in range(samples):\n",
" image, mask = dataset[idx]\n",
" ax[i, 0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n",
" ax[i, 1].imshow(mask, interpolation=\"nearest\")\n",
" ax[i, 0].set_title(\"Augmented image\")\n",
" ax[i, 1].set_title(\"Augmented mask\")\n",
" ax[i, 0].set_axis_off()\n",
" ax[i, 1].set_axis_off()\n",
" \n",
" plt.tight_layout()\n",
" plt.show()\n",
" \n",
"def remove_Black_Border_mask(image, ROI_mask: np.ndarray = None):\n",
"    \"\"\"Crop the dark frame border away from a BGR image.\n",
"\n",
"    A hue-threshold mask keeps border-toned pixels, the largest remaining\n",
"    contour is taken as the content area, and the image (and optionally a\n",
"    matching mask) is cropped to its bounding box.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    image : np.ndarray\n",
"        BGR image (H, W, 3) as returned by cv2.imread.\n",
"    ROI_mask : np.ndarray, optional\n",
"        Mask with the same H/W as `image`; cropped with the same box.\n",
"\n",
"    Returns\n",
"    -------\n",
"    tuple\n",
"        (cropped_image, cropped_mask); cropped_mask is None when no\n",
"        ROI_mask was supplied.\n",
"    \"\"\"\n",
"    hsv = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV)\n",
"    # Fix: hue channel previously lived in `h`, which was later shadowed\n",
"    # by the bounding-box height below -- renamed for clarity.\n",
"    hue = hsv[:, :, 0]\n",
"    # Keep only pixels whose hue falls OUTSIDE (25, 175): border tones.\n",
"    keep = np.ones(hue.shape, dtype=np.uint8) * 255\n",
"    th = (25, 175)\n",
"    keep[(hue > th[0]) & (hue < th[1])] = 0\n",
"    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)\n",
"    resROI = cv2.bitwise_and(bgr, bgr, mask=keep)\n",
"\n",
"    image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)\n",
"    # Any non-zero gray value becomes foreground.\n",
"    _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)\n",
"    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))\n",
"    # Opening removes small speckles before the contour search.\n",
"    morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)\n",
"    contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
"    contours = contours[0] if len(contours) == 2 else contours[1]  # OpenCV 3/4 compat\n",
"    bigCont = max(contours, key=cv2.contourArea)\n",
"    x, y, w, h = cv2.boundingRect(bigCont)\n",
"    # Explicit variable makes the conditional's binding unambiguous.\n",
"    cropped_mask = ROI_mask[y : y + h, x : x + w] if ROI_mask is not None else None\n",
"    return image[y : y + h, x : x + w], cropped_mask"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mean, std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])\n",
" \n",
"setup_seed(2023)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"batchSize = 16\n",
"numWorkers = 5\n",
"rootFile = Path(\"../dataset/ROI/\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"traintransform = A.Compose(\n",
" [\n",
" A.Resize(224, 224, interpolation=cv2.INTER_CUBIC), \n",
" A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),\n",
" A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),\n",
" A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),\n",
" A.Rotate(limit=(-30,30), p=0.9),\n",
" A.RandomGamma(p=0.7),\n",
" A.RandomFog(0.1, p=0.75),\n",
" A.RandomToneCurve(p=0.75),\n",
" A.Normalize(mean, std),\n",
" ToTensorV2()\n",
" ]\n",
")\n",
"\n",
"valtransform = A.Compose(\n",
" [\n",
" A.Resize(224, 224, interpolation=cv2.INTER_CUBIC),\n",
" A.Normalize(mean, std),\n",
" ToTensorV2()\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainDataset = Loaders.RARP_DatasetFolder_RoiExtractor(\n",
" str(rootFile/\"train\"),\n",
" loader=defs.load_Img,\n",
" extensions=\"tiff\",\n",
" transform=traintransform,\n",
" create_mask=True\n",
")\n",
"valDataset = Loaders.RARP_DatasetFolder_RoiExtractor(\n",
" str(rootFile/\"val\"),\n",
" loader=defs.load_Img,\n",
" extensions=\"tiff\",\n",
" transform=valtransform,\n",
" create_mask=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"setup_seed(2025)\n",
"visualize_augmentations(trainDataset, idx=5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Training loader: shuffled. Fix: pin_memory=True added for faster\n",
"# host-to-GPU transfers, matching the validation loader below.\n",
"Train_DataLoader = DataLoader(\n",
"    trainDataset,\n",
"    batch_size=batchSize,\n",
"    num_workers=numWorkers,\n",
"    shuffle=True,\n",
"    pin_memory=True,\n",
"    persistent_workers=numWorkers>0\n",
")\n",
"\n",
"# Validation loader: fixed order so metrics are comparable across epochs.\n",
"Val_DataLoader = DataLoader(\n",
"    valDataset,\n",
"    batch_size=batchSize,\n",
"    num_workers=numWorkers,\n",
"    shuffle=False,\n",
"    pin_memory=True,\n",
"    persistent_workers=numWorkers>0\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class UNet_RN18(torch.nn.Module):\n",
" def _conv_block(self, in_ch, out_ch):\n",
" return torch.nn.Sequential(\n",
" torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n",
" torch.nn.SiLU(inplace=True),\n",
" torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n",
" torch.nn.SiLU(inplace=True),\n",
" )\n",
" \n",
" def _hook_fn(self, module, input, output):\n",
" self.feature_maps.append(output) \n",
" \n",
" def _register_encoder_hooks(self):\n",
" for layer in self.list_blocks:\n",
" self.hooks.append(layer.register_forward_hook(self._hook_fn))\n",
" \n",
" def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):\n",
" super().__init__(*args, **kwargs)\n",
" \n",
" self.hooks = []\n",
" self.feature_maps = []\n",
" \n",
" self.encoder_base = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
" self.encoder_base.fc = torch.nn.Identity()\n",
" \n",
" #for parms in self.encoder_base.parameters():\n",
" # parms.requires_grad = False\n",
" \n",
" self.list_blocks = [\n",
" self.encoder_base.conv1,\n",
" self.encoder_base.layer1,\n",
" self.encoder_base.layer2,\n",
" self.encoder_base.layer3,\n",
" self.encoder_base.layer4\n",
" ]\n",
" \n",
" self._register_encoder_hooks()\n",
" \n",
" self.dropout = torch.nn.Dropout2d(0.4)\n",
" \n",
" self.upConv_0 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)\n",
" self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n",
" self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)\n",
" self.upConv_3 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)\n",
" \n",
" self.upConv_extra = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)\n",
"\n",
" self.decoder_0 = self._conv_block(512, 256) #0\n",
" self.decoder_1 = self._conv_block(256, 128) #1\n",
" self.decoder_2 = self._conv_block(128, 64) #2\n",
" self.decoder_3 = self._conv_block(32 + 64, 32) #3\n",
" \n",
" self.decoder_extra = self._conv_block(16, 16)\n",
" \n",
" self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1)\n",
" \n",
" def forward(self, x):\n",
" self.feature_maps = []\n",
" \n",
" _ = self.encoder_base(x) # forward to encoder and call hooks \n",
" \n",
" encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck = self.feature_maps\n",
" \n",
" decoder_l3 = self.upConv_0(btlneck) \n",
" decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1) \n",
" decoder_l3 = self.decoder_0(decoder_l3) \n",
" \n",
" decoder_l3 = self.dropout(decoder_l3)\n",
" \n",
" decoder_l2 = self.upConv_1(decoder_l3) \n",
" decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1) \n",
" decoder_l2 = self.decoder_1(decoder_l2) \n",
" \n",
" decoder_l2 = self.dropout(decoder_l2)\n",
" \n",
" decoder_l1 = self.upConv_2(decoder_l2) \n",
" decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1) \n",
" decoder_l1 = self.decoder_2(decoder_l1) \n",
" \n",
" decoder_l1 = self.dropout(decoder_l1)\n",
" \n",
" decoder_l0 = self.upConv_3(decoder_l1) \n",
" decoder_l0 = torch.cat((decoder_l0, encoder_l0), dim=1) \n",
" decoder_l0 = self.decoder_3(decoder_l0)\n",
" \n",
" decoder_last = self.upConv_extra(decoder_l0)\n",
" decoder_last = self.decoder_extra(decoder_last)\n",
" \n",
" return self.last_conv(decoder_last)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class UNet_VAN_1(torch.nn.Module):\n",
" def _conv_block(self, in_ch, out_ch):\n",
" return torch.nn.Sequential(\n",
" torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n",
" torch.nn.ReLU(inplace=True),\n",
" torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n",
" torch.nn.ReLU(inplace=True)\n",
" )\n",
" \n",
" def _hook_fn(self, module, input, output):\n",
" self.feature_maps.append(output) \n",
" \n",
" def _register_encoder_hooks(self):\n",
" for layer in self.list_blocks:\n",
" self.hooks.append(layer.register_forward_hook(self._hook_fn))\n",
" \n",
" def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):\n",
" super().__init__(*args, **kwargs)\n",
" \n",
" \n",
" self.hooks = []\n",
" self.feature_maps = []\n",
" \n",
" self.encoder_base = van.van_b1(pretrained = True, num_classes = 0)\n",
" \n",
" self.list_blocks = [\n",
" self.encoder_base.block1[-1],\n",
" self.encoder_base.block2[-1],\n",
" self.encoder_base.block3[-1],\n",
" self.encoder_base.block4[-1],\n",
" ]\n",
" \n",
" self._register_encoder_hooks()\n",
" \n",
" self.upConv_0 = torch.nn.ConvTranspose2d(512, 320, kernel_size=2, stride=2)\n",
" self.decoder_0 = self._conv_block(640, 256)\n",
" \n",
" self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n",
" self.decoder_1 = self._conv_block(256, 128) #1\n",
" \n",
" self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)\n",
" self.decoder_2 = self._conv_block(128, 64) #2\n",
" \n",
" self.last_conv = torch.nn.Conv2d(64, out_channels, kernel_size=1)\n",
" \n",
" def forward(self, x):\n",
" self.feature_maps = []\n",
" \n",
" _, _, h, w = x.shape\n",
" \n",
" _ = self.encoder_base(x) # forward to encoder and call hooks \n",
" \n",
" encoder_l1, encoder_l2, encoder_l3, btlneck = self.feature_maps\n",
" \n",
" decoder_l3 = self.upConv_0(btlneck) \n",
" decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1) \n",
" decoder_l3 = self.decoder_0(decoder_l3) \n",
" \n",
" decoder_l2 = self.upConv_1(decoder_l3) \n",
" decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1) \n",
" decoder_l2 = self.decoder_1(decoder_l2) \n",
" \n",
" decoder_l1 = self.upConv_2(decoder_l2) \n",
" decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1) \n",
" decoder_l1 = self.decoder_2(decoder_l1) \n",
" \n",
" return self.last_conv(torch.nn.functional.interpolate(decoder_l1, size=(h, w), mode=\"nearest\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class UNet(torch.nn.Module):\n",
" def _conv_block(self, in_ch, out_ch):\n",
" return torch.nn.Sequential(\n",
" torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n",
" torch.nn.ReLU(inplace=True),\n",
" torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n",
" torch.nn.ReLU(inplace=True)\n",
" )\n",
" \n",
" def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):\n",
" super().__init__(*args, **kwargs)\n",
" \n",
" self.encoder_0 = self._conv_block(in_channels, 64) #0\n",
" self.encoder_1 = self._conv_block(64, 128) #1\n",
" self.encoder_2 = self._conv_block(128, 256) #2\n",
" self.encoder_3 = self._conv_block(256, 512) #3\n",
" \n",
" self.pooling = torch.nn.MaxPool2d(kernel_size=2, stride=2)\n",
" \n",
" self.bottleneck = self._conv_block(512, 1024)\n",
" \n",
" self.upConv_0 = torch.nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)\n",
" self.upConv_1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)\n",
" self.upConv_2 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)\n",
" self.upConv_3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)\n",
"\n",
" self.decoder_0 = self._conv_block(1024, 512) #0\n",
" self.decoder_1 = self._conv_block(512, 256) #1\n",
" self.decoder_2 = self._conv_block(256, 128) #2\n",
" self.decoder_3 = self._conv_block(128, 64) #3\n",
" \n",
" self.last_conv = torch.nn.Conv2d(64, out_channels, kernel_size=1)\n",
" \n",
" def forward(self, x):\n",
" encoder_l1 = self.encoder_0(x) #3 -> 64\n",
" encoder_l2 = self.encoder_1(self.pooling(encoder_l1)) #64 -> 128\n",
" encoder_l3 = self.encoder_2(self.pooling(encoder_l2)) #128 -> 256\n",
" encoder_l4 = self.encoder_3(self.pooling(encoder_l3)) #256 -> 512\n",
" \n",
" btlneck = self.bottleneck(self.pooling(encoder_l4)) #512 -> 1024\n",
" \n",
" decoder_l4 = self.upConv_0(btlneck) #1024 -> 512\n",
" decoder_l4 = torch.cat((decoder_l4, encoder_l4), dim=1) \n",
" decoder_l4 = self.decoder_0(decoder_l4) #(512 + 512) -> 512 \n",
" \n",
" decoder_l3 = self.upConv_1(decoder_l4) #512 -> 256\n",
" decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1) \n",
" decoder_l3 = self.decoder_1(decoder_l3) #(256 + 256) -> 256 \n",
" \n",
" decoder_l2 = self.upConv_2(decoder_l3) #256 -> 128\n",
" decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1) \n",
" decoder_l2 = self.decoder_2(decoder_l2) #(128 + 128) -> 128 \n",
" \n",
" decoder_l1 = self.upConv_3(decoder_l2) #128 -> 64\n",
" decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1) \n",
" decoder_l1 = self.decoder_3(decoder_l1) #(64 + 64) -> 64 \n",
" \n",
" return self.last_conv(decoder_l1)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torchmetrics.classification\n",
"\n",
"\n",
"class RARP_NVB_Model(L.LightningModule):\n",
"    \"\"\"LightningModule wrapping UNet_RN18 for 1-channel binary mask\n",
"    prediction with BCE-with-logits loss.\n",
"\n",
"    Logs train/val loss and binary IoU; an optional L1 penalty on the\n",
"    decoder/upConv weights is applied during training when Lambda_L1 is set.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,*args, **kwargs):\n",
"        super().__init__(*args, **kwargs)\n",
"\n",
"        self.model = UNet_RN18(in_channels=3, out_channels=1)\n",
"\n",
"        self.lr = 1E-4\n",
"        self.Lambda_L1 = None  # L1 regularization weight; None disables it\n",
"        self.lossFN = torch.nn.BCEWithLogitsLoss()\n",
"\n",
"        # Separate metric objects so train/val states never mix.\n",
"        self.train_IoU = torchmetrics.classification.BinaryJaccardIndex()\n",
"        self.val_IoU = torchmetrics.classification.BinaryJaccardIndex()\n",
"\n",
"    def forward(self, data):\n",
"        \"\"\"Return raw logits (no sigmoid) for a float-cast input batch.\"\"\"\n",
"        data = data.float()\n",
"        pred = self.model(data)\n",
"        return pred\n",
"\n",
"    def _shared_step(self, batch, val_step:bool = True):\n",
"        \"\"\"Compute (loss, target mask, sigmoid predictions) for one batch.\n",
"\n",
"        When val_step is False and Lambda_L1 is set, an L1 penalty over\n",
"        decoder/upConv parameters is added (training only).\n",
"        \"\"\"\n",
"        img, mask = batch\n",
"\n",
"        mask = mask.float()\n",
"        mask = mask.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W) to match logits\n",
"        prediction = self(img)\n",
"\n",
"        loss = self.lossFN(prediction, mask)\n",
"\n",
"        predicted_labels = torch.sigmoid(prediction)\n",
"\n",
"        if not val_step:\n",
"            if self.Lambda_L1 is not None:\n",
"                loss_l1 = 0\n",
"                for name, params in self.model.named_parameters():\n",
"                    if \"decoder\" in name or \"upConv\" in name:\n",
"                        loss_l1 += torch.norm(params, p=1)\n",
"\n",
"                loss += self.Lambda_L1 * loss_l1\n",
"\n",
"        return loss, mask, predicted_labels\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        loss, true_labels, predicted_labels = self._shared_step(batch, False)\n",
"\n",
"        self.train_IoU.update(predicted_labels, true_labels)\n",
"\n",
"        self.log(\"train_loss\", loss, on_epoch=True)\n",
"        self.log(\"train_acc_IoU\", self.train_IoU, on_epoch=True, on_step=False)\n",
"\n",
"        return loss\n",
"\n",
"    def on_after_backward(self):\n",
"        \"\"\"Log the global L2 gradient norm after each backward pass.\"\"\"\n",
"        total_norm = 0.0\n",
"        for p in self.parameters():\n",
"            if p.grad is not None:\n",
"                param_norm = p.grad.data.norm(2)\n",
"                total_norm += param_norm.item() ** 2\n",
"        total_norm = total_norm ** 0.5\n",
"\n",
"        self.log(\"grad_norm\", total_norm)\n",
"\n",
"        if total_norm < 1e-8:\n",
"            # Fix: self.log only accepts numeric values -- logging the old\n",
"            # string message raised at runtime. A 1.0 flag marks the steps\n",
"            # where a vanishing gradient is suspected.\n",
"            self.log(\"grad_warning\", 1.0)\n",
"\n",
"    def on_train_epoch_start(self):\n",
"        pass\n",
"        # for parms in self.model.encoder_base.parameters():\n",
"        #     parms.requires_grad = (self.current_epoch % 2 == 0)\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        loss, true_labels, predicted_labels = self._shared_step(batch)\n",
"\n",
"        self.val_IoU.update(predicted_labels, true_labels)\n",
"\n",
"        self.log(\"val_loss\", loss, on_epoch=True, on_step=False, prog_bar=True)\n",
"        self.log(\"val_acc_IoU\", self.val_IoU, on_epoch=True, on_step=False)\n",
"\n",
"    def configure_optimizers(self):\n",
"        # NOTE(review): only self.model parameters are optimized; fine here\n",
"        # since the metric objects hold no trainable weights.\n",
"        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)\n",
"\n",
"        return [optimizer]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Model = RARP_NVB_Model()\n",
"setup_seed(2023)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = L.Trainer(\n",
" deterministic=True,\n",
" accelerator='gpu', \n",
" devices=1, \n",
" logger=TensorBoardLogger(save_dir=\"./logs_debug\"),\n",
" log_every_n_steps=5, \n",
" callbacks=[callbk.ModelCheckpoint(monitor=\"val_acc_IoU\", filename=\"RARP-{epoch}-{val_loss:.4f}\", save_top_k=5, mode='max')],\n",
" max_epochs=100,\n",
")\n",
"print(\"Train Phase\")\n",
"trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainLog = [13, 97, 0.2795]\n",
"#trainLog = [6, 36, 0.02]\n",
"pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]}.ckpt\")\n",
"Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
"\n",
"Model.to(device)\n",
"Model.eval()\n",
"\n",
"pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/321.tiff\")\n",
"frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
"frameToFind, _ = remove_Black_Border_mask(frameToFind)\n",
"frameToFind = valtransform(image=frameToFind)\n",
"frameToFind = frameToFind[\"image\"]\n",
"frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
"frameToFind = frameToFind.to(device)\n",
"\n",
"with torch.no_grad():\n",
" pred = Model(frameToFind)\n",
" pred = torch.sigmoid(pred)\n",
"\n",
"sample = Denorlalize(frameToFind[0].cpu(), std, mean)\n",
"rgb_mask = np.zeros_like(sample) \n",
"pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
"rgb_mask[pred[:, :, 0] > 0.5] = [0, 1, 0]\n",
" \n",
"fig, ax = plt.subplots(1, 3, figsize=(15, 20))\n",
"\n",
"crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
"\n",
"ax[0].imshow(sample)\n",
"ax[0].axis(\"off\")\n",
"\n",
"ax[1].imshow(rgb_mask)\n",
"ax[1].axis(\"off\")\n",
"\n",
"ax[2].imshow(crop_imge_masked)\n",
"ax[2].axis(\"off\")\n",
"\n",
"print(frameToFind.shape, pred.shape)\n",
"\n",
"\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainLog = [23, 64, 0.3028]\n",
"#trainLog = [6, 36, 0.02]\n",
"pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
"Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
"\n",
"Model.to(device)\n",
"Model.eval()\n",
"\n",
"pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/321.tiff\")\n",
"frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
"frameToFind, _ = remove_Black_Border_mask(frameToFind)\n",
"frameToFind = valtransform(image=frameToFind)\n",
"frameToFind = frameToFind[\"image\"]\n",
"frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
"frameToFind = frameToFind.to(device)\n",
"\n",
"with torch.no_grad():\n",
" pred = Model(frameToFind)\n",
" pred = torch.sigmoid(pred)\n",
"\n",
"sample = Denorlalize(frameToFind[0].cpu(), std, mean)\n",
"rgb_mask = np.zeros_like(sample) \n",
"pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
"rgb_mask[pred[:, :, 0] > 0.5] = [0, 1, 0]\n",
" \n",
"fig, ax = plt.subplots(1, 3, figsize=(15, 20))\n",
"\n",
"crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
"\n",
"ax[0].imshow(sample)\n",
"ax[0].axis(\"off\")\n",
"\n",
"ax[1].imshow(rgb_mask)\n",
"ax[1].axis(\"off\")\n",
"\n",
"ax[2].imshow(crop_imge_masked)\n",
"ax[2].axis(\"off\")\n",
"\n",
"print(frameToFind.shape, pred.shape)\n",
"\n",
"\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainLog = [28, 93, 0.3467]\n",
"#trainLog = [6, 36, 0.02]\n",
"pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
"Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
"\n",
"Model.to(device)\n",
"Model.eval()\n",
"\n",
"pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/321.tiff\")\n",
"frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
"frameToFind, _ = remove_Black_Border_mask(frameToFind)\n",
"frameToFind = valtransform(image=frameToFind)\n",
"frameToFind = frameToFind[\"image\"]\n",
"frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
"frameToFind = frameToFind.to(device)\n",
"\n",
"with torch.no_grad():\n",
" pred = Model(frameToFind)\n",
" pred = torch.sigmoid(pred)\n",
"\n",
"sample = Denorlalize(frameToFind[0].cpu(), std, mean)\n",
"rgb_mask = np.zeros_like(sample) \n",
"pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
"rgb_mask[pred[:, :, 0] > 0.5] = [0, 1, 0]\n",
" \n",
"fig, ax = plt.subplots(1, 3, figsize=(15, 20))\n",
"\n",
"crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
"\n",
"ax[0].imshow(sample)\n",
"ax[0].axis(\"off\")\n",
"\n",
"ax[1].imshow(rgb_mask)\n",
"ax[1].axis(\"off\")\n",
"\n",
"ax[2].imshow(crop_imge_masked)\n",
"ax[2].axis(\"off\")\n",
"\n",
"print(frameToFind.shape, pred.shape)\n",
"\n",
"\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def catmull_rom_spline(P0, P1, P2, P3, n_points=20):\n",
" points = []\n",
" for t in np.linspace(0, 1, n_points):\n",
" # Catmull-Rom formula\n",
" t2 = t * t\n",
" t3 = t2 * t\n",
" x = 0.5 * ((2 * P1[0]) +\n",
" (-P0[0] + P2[0]) * t +\n",
" (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t2 +\n",
" (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t3)\n",
" \n",
" y = 0.5 * ((2 * P1[1]) +\n",
" (-P0[1] + P2[1]) * t +\n",
" (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t2 +\n",
" (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t3)\n",
" \n",
" points.append((x, y))\n",
" \n",
" return points\n",
"\n",
"def catmull_rom_closed_loop(points, n_points=20):\n",
"    \"\"\"Sample a closed Catmull-Rom curve through every control point.\n",
"\n",
"    Indices wrap around, so the curve is periodic. Returns an array of\n",
"    n_points samples per segment, one segment per control point.\n",
"    \"\"\"\n",
"    n = len(points)\n",
"    samples = []\n",
"\n",
"    for i in range(n):\n",
"        prev_pt = points[(i - 1) % n]\n",
"        curr_pt = points[i]\n",
"        next_pt = points[(i + 1) % n]\n",
"        after_pt = points[(i + 2) % n]\n",
"        samples.extend(catmull_rom_spline(prev_pt, curr_pt, next_pt, after_pt, n_points))\n",
"\n",
"    return np.array(samples)\n",
"\n",
"def create_mask_from_contour(spline_points: np.ndarray, mask_size):\n",
"    \"\"\"Rasterize a closed contour into a binary uint8 mask.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    spline_points : np.ndarray\n",
"        (N, 2) array of (x, y) contour points; rounded to integer pixels.\n",
"    mask_size : tuple\n",
"        (height, width) of the output mask.\n",
"\n",
"    Returns\n",
"    -------\n",
"    np.ndarray\n",
"        uint8 mask with 1 inside the polygon and 0 outside.\n",
"    \"\"\"\n",
"    # Fix: the annotation was the bare module `np`, not a type.\n",
"    smooth_curve_int = np.round(spline_points).astype(np.int32)\n",
"    mask = np.zeros(mask_size, dtype=np.uint8)\n",
"\n",
"    return cv2.fillPoly(mask, [smooth_curve_int], 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import jaccard_score\n",
"\n",
"#trainLog = [30, 87, 0.3375]\n",
"trainLog = [30, 87, 0.3375]\n",
"pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
"\n",
"Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
"\n",
"Model.to(device)\n",
"Model.eval()\n",
"\n",
"pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/300.tiff\")\n",
"pathJson = Path(str(pathImg.absolute()).replace(\".tiff\", \".json\"))\n",
"frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
"\n",
"archorPts = json.load(open(pathJson))\n",
"archorPts = np.array(archorPts[\"shapes\"][0][\"points\"])\n",
"\n",
"h, w, _ = frameToFind.shape\n",
"smood_perimeter = catmull_rom_closed_loop(archorPts, n_points=15)\n",
"roi_mask = create_mask_from_contour(smood_perimeter, (h, w))\n",
"\n",
"frameToFind, roi_mask = remove_Black_Border_mask(frameToFind, roi_mask)\n",
"\n",
"sample = frameToFind.copy()\n",
"sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)\n",
"input_h, input_w, _ = sample.shape \n",
"\n",
"frameToFind = valtransform(image=frameToFind)\n",
"frameToFind = frameToFind[\"image\"]\n",
"frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
"frameToFind = frameToFind.to(device)\n",
"\n",
"with torch.no_grad():\n",
" pred = Model(frameToFind)\n",
" pred = torch.sigmoid(pred)\n",
"\n",
"_, _, h, w= frameToFind.shape\n",
"\n",
"pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
"rgb_mask = np.zeros((h, w, 3)).astype(np.uint8)\n",
"rgb_mask[pred[:, :, 0] > 0.5] = [0, 255, 0]\n",
"\n",
"pred = cv2.resize(pred, (input_w, input_h), interpolation=cv2.INTER_CUBIC)\n",
"rgb_mask = cv2.resize(rgb_mask, (input_w, input_h), interpolation=cv2.INTER_CUBIC)\n",
"crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
"\n",
"roi_maskRGB = np.zeros_like(rgb_mask)\n",
"roi_maskRGB[roi_mask[:, :] == 1] = [40, 0, 255]\n",
"#crop_imge_masked_gt = cv2.addWeighted(roi_maskRGB, 1, rgb_mask, 0.4, 0)\n",
"\n",
"crop_imge_masked_gt = cv2.addWeighted(crop_imge_masked, 1, roi_maskRGB, 0.4, 0)\n",
"\n",
"ytrue = roi_maskRGB[..., 2]//255\n",
"ypred = rgb_mask[..., 1]//255\n",
"print(np.max(ytrue), ytrue.shape)\n",
"print(np.max(ypred), ypred.shape)\n",
"\n",
"iou = jaccard_score(ytrue, ypred, average=\"micro\")\n",
"\n",
"fig, ax = plt.subplots(4, 2, figsize=(15, 25))\n",
"\n",
"ax[0, 0].set_title(\"Input image\")\n",
"ax[0, 0].imshow(sample)\n",
"ax[0, 0].axis(\"off\")\n",
"\n",
"ax[0, 1].set_title(\"Output Prediction mask\")\n",
"maskDiplay = ax[0, 1].imshow(pred, cmap=\"inferno\",)\n",
"fig.colorbar(maskDiplay, ax=ax[0, 1], label=\"Confidence\", fraction=0.0387, pad=0.01)\n",
"ax[0, 1].axis(\"off\")\n",
"\n",
"ax[1, 0].set_title(\"Thresholded mask P(x) > 0.5\")\n",
"ax[1, 0].imshow(rgb_mask)\n",
"ax[1, 0].axis(\"off\")\n",
"\n",
"ax[1, 1].set_title(\"Result\")\n",
"ax[1, 1].imshow(crop_imge_masked)\n",
"ax[1, 1].axis(\"off\")\n",
"\n",
"ax[3, 0].set_title(\"GT mask\")\n",
"ax[3, 0].imshow(roi_maskRGB)\n",
"ax[3, 0].axis(\"off\")\n",
"\n",
"#crop_imge_masked_gt\n",
"\n",
"ax[3, 1].set_title(\"Result + GT orverlay\")\n",
"ax[3, 1].imshow(crop_imge_masked_gt)\n",
"ax[3, 1].axis(\"off\")\n",
"ax[3, 1].text(0, 1400, f\"Acc.:{iou:.4}\", fontsize = 22, color = 'w',bbox = dict(facecolor = 'red', alpha = 0.5))\n",
"\n",
"_, thresh = cv2.threshold(rgb_mask[..., 1], 0, 255, cv2.THRESH_BINARY)\n",
"kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))\n",
"morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)\n",
"\n",
"morph = cv2.GaussianBlur(morph, (5, 5), 0)\n",
"\n",
"contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
"contours = contours[0] if len(contours) == 2 else contours[1]\n",
"\n",
"roi = max(contours, key=cv2.contourArea)\n",
"x, y, w, h = cv2.boundingRect(roi)\n",
"\n",
"ax[2, 0].set_title(\"Crop Thresholded mask P(x) > 0.5\")\n",
"ax[2, 0].imshow(morph[y : y + h, x : x + w], cmap = \"gray\")\n",
"ax[2, 0].axis(\"off\")\n",
"\n",
"ax[2, 1].set_title(\"Crop Result\")\n",
"ax[2, 1].imshow(cv2.bitwise_and(sample, sample, mask=morph)[y : y + h, x : x + w])\n",
"ax[2, 1].axis(\"off\")\n",
"\n",
"print(frameToFind.shape, pred.shape)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def dice_score(mask1: np.ndarray, mask2: np.ndarray) -> float:\n",
" \"\"\"\n",
" Calculate the DICE (Dice Similarity Coefficient) between two binary masks.\n",
" \n",
" Parameters:\n",
" -----------\n",
" mask1 : np.ndarray\n",
" First binary mask (values must be 0 or 1).\n",
" mask2 : np.ndarray\n",
" Second binary mask (values must be 0 or 1).\n",
"\n",
" Returns:\n",
" --------\n",
" float\n",
" Dice score between mask1 and mask2.\n",
" \"\"\"\n",
" # Ensure the two masks are of the same shape\n",
" assert mask1.shape == mask2.shape, \"Masks must have the same shape\"\n",
" \n",
" # Convert to boolean if necessary (True where mask > 0)\n",
" mask1_bool = mask1.astype(bool)\n",
" mask2_bool = mask2.astype(bool)\n",
"\n",
" # Calculate intersection: number of elements where both masks are 1\n",
" intersection = np.logical_and(mask1_bool, mask2_bool).sum()\n",
"\n",
" # Calculate the sum of each mask\n",
" mask1_sum = mask1_bool.sum()\n",
" mask2_sum = mask2_bool.sum()\n",
"\n",
" # Handle edge case to avoid division by zero\n",
" if mask1_sum + mask2_sum == 0:\n",
" # If both masks are entirely zero, define Dice as 1 or 0 depending on convention\n",
" return 1.0 if (mask1_sum == 0 and mask2_sum == 0) else 0.0\n",
"\n",
" # Calculate the DICE score\n",
" dice = 2.0 * intersection / (mask1_sum + mask2_sum)\n",
" return dice"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import jaccard_score\n",
"\n",
"trainLog = [32, 86, 0.3223]\n",
"pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
"Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
"\n",
"Model.to(device)\n",
"Model.eval()\n",
"avg_IoU = 0.0\n",
"avg_DICE = 0.0\n",
"\n",
"for num_case in range (291, 316):\n",
" pathImg = Path(f\"D:/Users/user/Downloads/dataset/RARP291-340/{num_case}.tiff\")\n",
" pathJson = Path(str(pathImg.absolute()).replace(\".tiff\", \".json\"))\n",
" frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
"\n",
" archorPts = json.load(open(pathJson))\n",
" archorPts = np.array(archorPts[\"shapes\"][0][\"points\"])\n",
"\n",
" h, w, _ = frameToFind.shape\n",
" smood_perimeter = catmull_rom_closed_loop(archorPts, n_points=15)\n",
" roi_mask = create_mask_from_contour(smood_perimeter, (h, w))\n",
"\n",
" frameToFind, roi_mask = remove_Black_Border_mask(frameToFind, roi_mask)\n",
"\n",
" input_h, input_w, _ = frameToFind.shape \n",
"\n",
" frameToFind = valtransform(image=frameToFind)\n",
" frameToFind = frameToFind[\"image\"]\n",
" frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
" frameToFind = frameToFind.to(device)\n",
"\n",
" with torch.no_grad():\n",
" pred = Model(frameToFind)\n",
" pred = torch.sigmoid(pred)\n",
" \n",
" _, _, h, w= frameToFind.shape\n",
"\n",
" pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
" rgb_mask = np.zeros((h, w, 3)).astype(np.uint8)\n",
" rgb_mask[pred[:, :, 0] > 0.5] = [0, 255, 0]\n",
" rgb_mask = cv2.resize(rgb_mask, (input_w, input_h), interpolation=cv2.INTER_CUBIC)\n",
"\n",
" roi_maskRGB = np.zeros_like(rgb_mask)\n",
" roi_maskRGB[roi_mask[:, :] == 1] = [0, 0, 255]\n",
"\n",
" ytrue = roi_maskRGB[..., 2]//255\n",
" ypred = rgb_mask[..., 1]//255\n",
"\n",
" iou = jaccard_score(ytrue, ypred, average=\"micro\")\n",
" dice = dice_score(ytrue, ypred)\n",
" print(f\"{num_case} {iou:.4} {dice:.4}\")\n",
" avg_IoU += iou\n",
" avg_DICE += dice\n",
" \n",
"print(f\"Avg {avg_IoU/25:.4} {avg_DICE/25:.4}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import jaccard_score\n",
"\n",
"trainLog = [30, 97, 0.2978]\n",
"pathCkptFile = Path(f\"./logs_debug/lightning_logs/version_{trainLog[0]}/checkpoints/RARP-epoch={trainLog[1]}-val_loss={trainLog[2]:.4f}.ckpt\")\n",
"Model = RARP_NVB_Model.load_from_checkpoint(pathCkptFile)\n",
"\n",
"Model.to(device)\n",
"Model.eval()\n",
"\n",
"pathImg = Path(\"D:/Users/user/Downloads/dataset/RARP291-340/303.tiff\")\n",
"pathJson = Path(str(pathImg.absolute()).replace(\".tiff\", \".json\"))\n",
"frameToFind = cv2.imread(str(pathImg), cv2.IMREAD_COLOR)\n",
"\n",
"archorPts = json.load(open(pathJson))\n",
"archorPts = np.array(archorPts[\"shapes\"][0][\"points\"])\n",
"\n",
"h, w, _ = frameToFind.shape\n",
"smood_perimeter = catmull_rom_closed_loop(archorPts, n_points=15)\n",
"roi_mask = create_mask_from_contour(smood_perimeter, (h, w))\n",
"\n",
"frameToFind, roi_mask = remove_Black_Border_mask(frameToFind, roi_mask)\n",
"\n",
"sample = frameToFind.copy()\n",
"sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)\n",
"input_h, input_w, _ = sample.shape \n",
"\n",
"frameToFind = valtransform(image=frameToFind)\n",
"frameToFind = frameToFind[\"image\"]\n",
"frameToFind = frameToFind.repeat(1, 1, 1, 1)\n",
"frameToFind = frameToFind.to(device)\n",
"\n",
"with torch.no_grad():\n",
" pred = Model(frameToFind)\n",
" pred = torch.sigmoid(pred)\n",
"\n",
"_, _, h, w= frameToFind.shape\n",
"\n",
"pred = pred[0].cpu().numpy().transpose((1, 2, 0))\n",
"rgb_mask = np.zeros((h, w, 3)).astype(np.uint8)\n",
"rgb_mask[pred[:, :, 0] > 0.5] = [0, 255, 0]\n",
"rgb_mask = cv2.resize(rgb_mask, (input_w, input_h), interpolation=cv2.INTER_CUBIC)\n",
"\n",
"roi_maskRGB = np.zeros_like(rgb_mask)\n",
"roi_maskRGB[roi_mask[:, :] == 1] = [0, 0, 255]\n",
"\n",
"crop_imge_masked = cv2.addWeighted(sample, 1, rgb_mask, 0.3, 0)\n",
"crop_imge_masked_gt = cv2.addWeighted(crop_imge_masked, 1, roi_maskRGB, 0.4, 0)\n",
"\n",
"#crop_imge_masked_gt = cv2.addWeighted(roi_maskRGB, 1, rgb_mask, 0.4, 0)\n",
"\n",
"ytrue = roi_maskRGB[..., 2]//255\n",
"ypred = rgb_mask[..., 1]//255\n",
"\n",
"iou = jaccard_score(ytrue, ypred, average=\"micro\")\n",
"dice = dice_score(ytrue, ypred)\n",
"print(f\"IoU: {iou:.4}\")\n",
"print(f\"DICE: {dice:.4}\")\n",
"\n",
"fig, ax = plt.subplots(1, 1, figsize=(10, 15))\n",
"\n",
"ax.set_title(f\"Result {pathImg.name}\")\n",
"ax.imshow(crop_imge_masked_gt)\n",
"ax.axis(\"off\")\n",
"ax.text(0, 1400, f\"IoU:{iou:.4}; DICE:{dice:.4}\", fontsize = 22, color = 'w',bbox = dict(facecolor = 'red', alpha = 0.5))\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pyRARP",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}