{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import lightning as L\n",
    "\n",
    "modelR50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)\n",
    "modelEV2 = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import lightning as L\n",
    "import van\n",
    "\n",
    "modelR50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)\n",
    "modelEV2 = van.van_b2()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelEV2.block1[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelR50.layer4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dummyLoad = torch.rand(1, 3, 224, 224)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize a variable to store the output\n",
    "last_conv_output = None\n",
    "\n",
    "# Define a hook function to capture the output\n",
    "def hook_fn(module, input, output):\n",
    "    global last_conv_output\n",
    "    last_conv_output = output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelEV2.eval()\n",
    "modelEV2.block4[-1].mlp.dwconv.register_forward_hook(hook_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelEV2(dummyLoad).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "last_conv_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelR50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# adaptive_avg_pool2d requires an input tensor and an output size;\n",
    "# pool the dummy batch to (1, 1) per channel and show the result shape.\n",
    "torch.nn.functional.adaptive_avg_pool2d(input=dummyLoad, output_size=(1, 1)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelR50.layer4[-1].conv3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelEV2.eval()\n",
    "salida = modelEV2(dummyLoad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelEV2.forward_features(dummyLoad).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelEV2Sub = torch.nn.Sequential(*list(modelEV2.children())[:-1])\n",
    "modelEV2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "salida.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelR50.eval()\n",
    "subModel = torch.nn.Sequential(*list(modelR50.children())[:-2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subModel(dummyLoad).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for params in modelEV2.parameters():\n",
    "    print(params)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.norm(params, p=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.linalg.norm(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelR50_Decoder = torch.nn.Sequential(*list(modelR50.children())[:-2])\n",
    "modelEV2_Decoder = modelEV2.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelEV2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "\n",
    "model = timm.create_model(\"davit_small.msft_in1k\", pretrained=True, num_classes=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelEV2.fc2= torch.nn.Linear(12,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "pred = torch.tensor([0.3137, 0.3161, 0.4395, 0.4636, 0.4427, 0.3616, 0.3256, 0.3367])\n",
    "trueVal = torch.tensor([0.1577, 0.8148, 0.4295, 0.4841, 0.5133, 0.2474, 0.6306, 0.3618])\n",
    "\n",
    "PredBach = torch.tensor([[0.3137, 0.3161, 0.4395, 0.4636, 0.4427, 0.3616, 0.3256, 0.3367],\n",
    "        [0.2372, 0.0121, 0.5123, 0.4597, 0.3809, 0.3405, 0.3988, 0.3922],\n",
    "        [0.0436, 0.0425, 0.2698, 0.5095, 0.5610, 0.2571, 0.6514, 0.4645]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise absolute differences within each row:\n",
    "# res[b, i, j] = |PredBach[b, i] - PredBach[b, j]|\n",
    "# (the original `PredBach - PredBach` was identically zero, and the later\n",
    "# res.mean(1) raised on the resulting 1-D tensor)\n",
    "res = torch.abs(PredBach.unsqueeze(2) - PredBach.unsqueeze(1))\n",
    "print(res)\n",
    "# Total disagreement of each prediction with the rest of its row -> (B, N)\n",
    "res = torch.sum(res, 2)\n",
    "\n",
    "mean = res.mean(1).unsqueeze(1)\n",
    "\n",
    "# Keep only predictions whose disagreement is at or below the row mean.\n",
    "f = (res <= mean).float()\n",
    "\n",
    "# Normalize so the kept weights in each row sum to 1.\n",
    "f = f / f.count_nonzero(dim=1).unsqueeze(1)\n",
    "\n",
    "# Batched version of CalcAgreement: weighted mean of the kept predictions.\n",
    "torch.sum(PredBach * f, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def CalcAgreement(x: torch.Tensor):\n",
    "    \"\"\"Agreement-weighted average of a 1-D prediction vector.\n",
    "\n",
    "    Predictions whose total absolute distance to all others is at or below\n",
    "    the mean total distance are kept with equal weight; the rest get weight\n",
    "    zero. Returns the weighted sum as a scalar tensor.\n",
    "    \"\"\"\n",
    "    # deltas[i] = sum_j |x[i] - x[j]|\n",
    "    deltas = torch.sum(torch.abs(x.unsqueeze(1) - x), 1)\n",
    "    deltasMean = deltas.mean()\n",
    "    \n",
    "    factor = (deltas <= deltasMean).float()\n",
    "    # Equal weights over the kept entries (they sum to 1).\n",
    "    factor = factor / factor.count_nonzero()\n",
    "    \n",
    "    return torch.dot(factor, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.tensor([CalcAgreement(d.squeeze()) for d in PredBach.split(1, 0)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = []\n",
    "for d in PredBach.split(1, 0):\n",
    "    res.append(CalcAgreement(d.squeeze()))\n",
    "\n",
    "torch.tensor(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CalcAgreement(pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.abs(pred.unsqueeze(1) - pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.sum(torch.abs(pred.unsqueeze(1) - pred), 1)\n",
    "aMean = a.mean()\n",
    "aMean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "factor = (a <= aMean).float()\n",
    "factor = factor / factor.count_nonzero()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "factor, a, pred, factor * pred, torch.dot(factor, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kl_diver(a, b):\n",
    "    # KL divergence D_KL(a || b) = sum(a * log(a / b)).\n",
    "    # NOTE(review): assumes a and b are strictly positive; log(a/b) yields\n",
    "    # nan/inf where either has zeros — TODO confirm inputs are normalized\n",
    "    # probability distributions.\n",
    "    return torch.sum(a * torch.log(a/b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kl_diver(pred, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Dkl = torch.nn.KLDivLoss(reduction=\"sum\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): torch.nn.KLDivLoss expects `input` to be log-probabilities\n",
    "# (and `target` to be probabilities by default); `pred` here holds raw\n",
    "# values, so this likely should be Dkl(pred.log(), trueVal) — confirm intent.\n",
    "Dkl(pred, trueVal)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import Models\n",
    "\n",
    "modelo = Models.RARP_NVB_EfficientNetV2_Deep()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "modelo.model.classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RARP_NVB_NN(L.LightningModule):\n",
    "    def __init__(self, input_size:int=4) -> None:\n",
    "        super().__init__()\n",
    "        \n",
    "        self.model = torch.nn.Sequential(\n",
    "            torch.nn.Linear(input_size, 3),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(3, 1)\n",
    "        )\n",
    "        \n",
    "    def forward(self, data):\n",
    "        x = self.model(data)\n",
    "        return x\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RARP_NVB_ResNet50_CAM(L.LightningModule):\n",
    "    \"\"\"ResNet-50 binary head that also returns the last conv feature map (for CAM).\"\"\"\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        \n",
    "        self.model = torchvision.models.resnet50()\n",
    "        tempFC_ft = self.model.fc.in_features \n",
    "        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)\n",
    "        \n",
    "        # Everything up to (but excluding) avgpool and fc.\n",
    "        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])\n",
    "        \n",
    "    def forward(self, data):\n",
    "        featureMap = self.feature_map(data)\n",
    "        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1)) \n",
    "        # Flatten from dim 1 so the batch dimension is preserved; flattening\n",
    "        # everything (the previous behavior) broke for batch sizes > 1.\n",
    "        Cont_Net = torch.flatten(Cont_Net, 1)\n",
    "        \n",
    "        pred = self.model.fc(Cont_Net)\n",
    "        \n",
    "        return pred, featureMap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "\n",
    "\n",
    "class Modelo_er(L.LightningModule):\n",
    "    \"\"\"Dual-backbone extractor: pooled ResNet-50 features plus raw VAN features.\n",
    "\n",
    "    Captures the notebook-level `modelR50` / `modelEV2` models at construction\n",
    "    time — those cells must have run first.\n",
    "    \"\"\"\n",
    "    def __init__(self, *args: Any, **kwargs: Any) -> None:\n",
    "        super().__init__(*args, **kwargs)\n",
    "               \n",
    "        self.feature_map1 = torch.nn.Sequential(*list(modelR50.children())[:-2])\n",
    "        self.feature_map2 = modelEV2.features\n",
    "        \n",
    "    def forward(self, data):\n",
    "        featureMap1 = self.feature_map1(data)\n",
    "        featureMap2 = self.feature_map2(data)\n",
    "        \n",
    "        Cont_Net1 = torch.nn.functional.adaptive_avg_pool2d(input=featureMap1, output_size=(1, 1)) \n",
    "        Cont_Net2 = torch.nn.functional.adaptive_avg_pool2d(input=featureMap2, output_size=(1, 1)) \n",
    "        \n",
    "        # Flatten from dim 1 to keep the batch dimension; a bare flatten\n",
    "        # collapsed batches > 1 into a single vector.\n",
    "        Cont_Net1 = torch.flatten(Cont_Net1, 1)\n",
    "        #Cont_Net2 = torch.flatten(Cont_Net2)\n",
    "        \n",
    "        return Cont_Net1, Cont_Net2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "import defs\n",
    "import Loaders\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "Dataset = Loaders.RARP_DatasetCreator(\n",
    "    \"./DataSet_big_2\",\n",
    "    FoldSeed=505,\n",
    "    createFile=True,\n",
    "    SavePath=\"./DatasetBig2\",\n",
    "    Fold=5,\n",
    "    removeBlackBar=True,\n",
    ")\n",
    "\n",
    "Dataset.CreateFolds()\n",
    "Device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])\n",
    "batchSize = 17\n",
    "Fold = 0\n",
    "rootFile = Dataset.CVS_File.parent.parent/f\"fold_{Fold}\"\n",
    "\n",
    "testtransform =  torch.nn.Sequential(\n",
    "    transforms.Resize(256, antialias=True),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.Normalize(Dataset.mean, Dataset.std)\n",
    ").to(Device)\n",
    "\n",
    "testDataset = torchvision.datasets.DatasetFolder(\n",
    "    str (rootFile/\"test\"),\n",
    "    loader=defs.load_file_tensor,\n",
    "    extensions=\"npy\",\n",
    "    transform=testtransform\n",
    ")\n",
    "\n",
    "Test_DataLoader = DataLoader(\n",
    "    testDataset, \n",
    "    batch_size=batchSize, \n",
    "    num_workers=0, \n",
    "    shuffle=False, \n",
    "    pin_memory=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ModeloObj = Modelo_er()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ModeloObj.to(Device)\n",
    "ModeloObj.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    for data, label in iter(Test_DataLoader):\n",
    "        data = data.float().to(Device)\n",
    "        label = label.to(Device)\n",
    "        pred = ModeloObj(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import Models as M\n",
    "\n",
    "lsita = [M.RARP_NVB_ResNet50(None, M.TypeLossFunction.BCEWithLogits), M.RARP_NVB_ResNet18(None, M.TypeLossFunction.BCEWithLogits)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lsita[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.flatten(pred[1]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "DumpCSV = pd.read_csv(Path(\"./DatasetSmallBalanced_seed_505/dump/dataset.csv\"))\n",
    "Extradata = pd.read_excel(Path(\"./DataSet_smallBalaced/data.xlsx\"))\n",
    "\n",
    "Extradata[\"name\"] =  Extradata[\"列1\"].astype(str) + \".tiff\"\n",
    "Extradata = Extradata.drop(columns=[\"列1\"])\n",
    "\n",
    "DumpCSV[\"raw_name\"] = \"Img0-\" + DumpCSV[\"id\"].astype(str) + \".npy\"\n",
    "DumpCSV = DumpCSV.drop(columns=[\"id\", \"path\", \"mean_1\", \"mean_2\", \"mean_3\", \"std_1\", \"std_2\", \"std_3\"])\n",
    "\n",
    "NewData = pd.merge(Extradata, DumpCSV, on=\"name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Path(\"./DatasetSmallBalanced_seed_505/dump/dataset.csv\").name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Extradata.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DumpCSV.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NewData.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NewData[NewData[\"name\"] == \"51.tiff\"].values.flatten().tolist()[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoFeatureExtractor, VanForImageClassification\n",
    "\n",
    "model = VanForImageClassification.from_pretrained(\"Visual-Attention-Network/van-base\")\n",
    "feature_extractor = AutoFeatureExtractor.from_pretrained(\"Visual-Attention-Network/van-base\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_extractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from timm.models.layers import DropPath, to_2tuple, trunc_normal_\n",
    "from timm.models.registry import register_model\n",
    "from timm.models.vision_transformer import _cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "checkpoint = torch.hub.load_state_dict_from_url(\n",
    "    url=\"https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar\", map_location=\"cpu\", check_hash=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "w = torch.tensor([2, 1, 1])\n",
    "\n",
    "w * torch.tensor([2, 2, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "import torch\n",
    "model = timm.create_model(\"levit_384.fb_dist_in1k\", pretrained=True, num_classes=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#model.head = torch.nn.Identity()\n",
    "x = torch.randn(1, 3, 224, 224)\n",
    "\n",
    "res = model(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(list(model.stages.children()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.stages[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import van\n",
    "\n",
    "model = van.van_b2(pretrained = True, num_classes = 0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "\n",
    "#resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')\n",
    "resnet50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)\n",
    "vit = torchvision.models.vit_b_16()\n",
    "\n",
    "vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "xcit_small_12_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xcit_small_12_p8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vitb8\n",
    "#vit\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "\n",
    "timm.create_model(\"levit_192.fb_dist_in1k\", pretrained=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import Loaders\n",
    "import defs\n",
    "import numpy as np\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "mean, std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])\n",
    "\n",
    "trainDataset = Loaders.RARP_DatasetFolder_ROIExtractor_OnlyROI(\n",
    "                \"DataSet_Kpts_FullSize_seed_505/fold_0/test\",\n",
    "                loader=defs.load_file,\n",
    "                extensions=\"npy\"\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import Loaders\n",
    "import defs\n",
    "import numpy as np\n",
    "import json\n",
    "import cv2\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def _catmull_rom_spline(P0, P1, P2, P3, n_points=20):\n",
    "    points = []\n",
    "    for t in np.linspace(0, 1, n_points):\n",
    "        # Catmull-Rom formula\n",
    "        t2 = t * t\n",
    "        t3 = t2 * t\n",
    "        x = 0.5 * ((2 * P1[0]) +\n",
    "                (-P0[0] + P2[0]) * t +\n",
    "                (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t2 +\n",
    "                (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t3)\n",
    "        \n",
    "        y = 0.5 * ((2 * P1[1]) +\n",
    "                (-P0[1] + P2[1]) * t +\n",
    "                (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t2 +\n",
    "                (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t3)\n",
    "        \n",
    "        points.append((x, y))\n",
    "        \n",
    "    return points\n",
    "\n",
    "def _catmull_rom_closed_loop(points, n_points=20):\n",
    "    spline_points = []\n",
    "    n = len(points)\n",
    "    \n",
    "    for i in range(n):\n",
    "        P0 = points[(i - 1) % n]\n",
    "        P1 = points[i]\n",
    "        P2 = points[(i + 1) % n]\n",
    "        P3 = points[(i + 2) % n]\n",
    "        spline_points += _catmull_rom_spline(P0, P1, P2, P3, n_points)\n",
    "        \n",
    "    return np.array(spline_points)\n",
    "\n",
    "def _create_mask_from_contour(spline_points: np.ndarray, mask_size=(0, 0)):\n",
    "    \"\"\"Rasterize a closed contour into a 0/1 uint8 mask of shape `mask_size`.\"\"\"\n",
    "    smooth_curve_int = np.round(spline_points).astype(np.int32)\n",
    "    mask = np.zeros(mask_size, dtype=np.uint8)\n",
    "    \n",
    "    return cv2.fillPoly(mask, [smooth_curve_int], 1)\n",
    "\n",
    "\n",
    "# Use forward slashes: the previous backslash paths contained invalid\n",
    "# Python escape sequences (\\d, \\j) and were Windows-only.\n",
    "img = np.load(\"DataSet_Kpts_FullSize_seed_505/dump/Img0-4.npy\").astype(float)\n",
    "data = json.load(open(\"DataSet_Kpts_FullSize_seed_505/json/Img0-4.json\"))\n",
    "kpts = data[\"shapes\"][0][\"points\"]\n",
    "\n",
    "h, w, _ = img.shape\n",
    "smood_perimeter = _catmull_rom_closed_loop(kpts, n_points=15)\n",
    "roi_mask = _create_mask_from_contour(smood_perimeter, (h, w))\n",
    "roi_mask = cv2.bitwise_and(img, img, mask=roi_mask)\n",
    "x, y, w, h = cv2.boundingRect(np.round(smood_perimeter).astype(np.int32))\n",
    "\n",
    "\n",
    "img = img[...,::-1].copy().astype(np.int32)\n",
    "roi_mask = roi_mask[..., ::-1].copy().astype(np.int32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(15, 8))\n",
    "\n",
    "ax[0].imshow(img)\n",
    "ax[0].axis(\"off\")\n",
    "\n",
    "ax[1].imshow(roi_mask[y : y + h, x : x + w])\n",
    "ax[1].axis(\"off\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.tensor(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def catmull_rom_spline(P0, P1, P2, P3, n_points=20):\n",
    "    \"\"\"Calculate Catmull-Rom points between P1 and P2.\"\"\"\n",
    "    points = []\n",
    "    for t in np.linspace(0, 1, n_points):\n",
    "        # Catmull-Rom formula\n",
    "        t2 = t * t\n",
    "        t3 = t2 * t\n",
    "        x = 0.5 * ((2 * P1[0]) +\n",
    "                   (-P0[0] + P2[0]) * t +\n",
    "                   (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t2 +\n",
    "                   (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t3)\n",
    "        y = 0.5 * ((2 * P1[1]) +\n",
    "                   (-P0[1] + P2[1]) * t +\n",
    "                   (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t2 +\n",
    "                   (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t3)\n",
    "        points.append((x, y))\n",
    "    return points\n",
    "\n",
    "def catmull_rom_closed_loop(points, n_points=20):\n",
    "    \"\"\"Calculate Catmull-Rom spline for a closed loop.\"\"\"\n",
    "    spline_points = []\n",
    "    n = len(points)\n",
    "    for i in range(n):\n",
    "        P0 = points[(i - 1) % n]\n",
    "        P1 = points[i]\n",
    "        P2 = points[(i + 1) % n]\n",
    "        P3 = points[(i + 2) % n]\n",
    "        spline_points += catmull_rom_spline(P0, P1, P2, P3, n_points)\n",
    "    return np.array(spline_points)\n",
    "\n",
    "# Hexagon points\n",
    "hexagon = np.array([\n",
    "[     290.62,      88.884],\n",
    "[     166.75,      107.33],\n",
    "[     45.791,      322.66],\n",
    "[     250.46,      383.89],\n",
    "[     498.39,       343.2],\n",
    "[     437.47,      109.45]\n",
    "])\n",
    "\n",
    "# Generate smooth curve\n",
    "smooth_curve = catmull_rom_closed_loop(hexagon, n_points=50)\n",
    "smooth_curve = np.round(smooth_curve).astype(np.int32)\n",
    "\n",
    "# Plot\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.plot(hexagon[:, 0], hexagon[:, 1], 'o-', label='Original Hexagon')\n",
    "plt.plot(smooth_curve[:, 0], smooth_curve[:, 1], '-', label='Smoothed Curve')\n",
    "plt.legend()\n",
    "plt.axis('equal')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Given sample points (x values) and target values (t values) from the supplementary material\n",
    "x_samples = np.array([-4, -1.0, 1.8180, 3.0, 4.9])\n",
    "t_values = np.array([9, 0, 7.9413, -20, 52.2])\n",
    "\n",
    "# Construct the design matrix H using the given basis functions h(x) = [1, x, x^2, x^3, x^4]\n",
    "H = np.vstack([\n",
    "    np.ones_like(x_samples),   # h1(x) = 1\n",
    "    x_samples,                 # h2(x) = x\n",
    "    x_samples**2,              # h3(x) = x^2\n",
    "    x_samples**3,              # h4(x) = x^3\n",
    "    x_samples**4               # h5(x) = x^4\n",
    "]).T\n",
    "\n",
    "# Regularization parameter (Weight Decay λ)\n",
    "lambda_val = 0.1  # Example value, can be adjusted\n",
    "\n",
    "# Construct the regularization matrix (Lambda)\n",
    "Lambda = lambda_val * np.eye(H.shape[1])\n",
    "\n",
    "# Compute the optimal weight vector w using Ridge Regression formula\n",
    "w_optimal = np.linalg.inv(H.T @ H + Lambda) @ H.T @ t_values\n",
    "\n",
    "# Display the optimal weights\n",
    "w_optimal\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "# Define RBF function (Gaussian kernel)\n",
    "def rbf(x, centers, sigma=1.0):\n",
    "    return np.exp(-cdist(x.reshape(-1,1), centers.reshape(-1,1))**2 / (2 * sigma**2))\n",
    "\n",
    "# Use the sample points as RBF centers\n",
    "centers = x_samples.copy()\n",
    "\n",
    "# Construct the design matrix H using RBFs\n",
    "sigma = 1.5  # Example sigma value, can be adjusted\n",
    "H_rbf = rbf(x_samples, centers, sigma)\n",
    "\n",
    "# Compute the optimal weights for RBF network\n",
    "w_rbf = np.linalg.inv(H_rbf.T @ H_rbf + Lambda) @ H_rbf.T @ t_values\n",
    "\n",
    "# Display the optimal weights for RBF network\n",
    "w_rbf\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# New"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import torch\n",
    "import cv2\n",
    "import numpy as np\n",
    "import Models as M\n",
    "\n",
    "import torchvision.transforms as T\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "def _removeBlackBorder(image):\n",
    "    \"\"\"Crop a BGR image to the bounding box of its largest non-border region.\n",
    "\n",
    "    Pixels whose HSV hue lies inside (25, 175) are masked out; the remainder\n",
    "    is thresholded, morphologically opened, and the crop is the bounding\n",
    "    rectangle of the largest surviving contour.\n",
    "    \"\"\"\n",
    "    image = np.array(image)\n",
    "    \n",
    "    # Keep only pixels whose hue falls OUTSIDE the (25, 175) band.\n",
    "    copyImg = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV)\n",
    "    h = copyImg[:,:,0]\n",
    "    mask = np.ones(h.shape, dtype=np.uint8) * 255\n",
    "    th = (25, 175)\n",
    "    mask[(h > th[0]) & (h < th[1])] = 0\n",
    "    copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)\n",
    "    resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)\n",
    "        \n",
    "    # Binarize, then open with a 15x15 rect kernel to drop small specks\n",
    "    # before contour extraction.\n",
    "    image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)\n",
    "    _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)\n",
    "    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))\n",
    "    morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)\n",
    "    # findContours returns (contours, hierarchy) on OpenCV 4.x and\n",
    "    # (image, contours, hierarchy) on 3.x; this handles both layouts.\n",
    "    contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
    "    contours = contours[0] if len(contours) == 2 else contours[1]\n",
    "    # NOTE(review): max() raises ValueError when no contour survives —\n",
    "    # assumes every frame has a non-masked region; confirm for edge cases.\n",
    "    bigCont = max(contours, key=cv2.contourArea)\n",
    "    x, y, w, h = cv2.boundingRect(bigCont)\n",
    "    crop = image[y : y + h, x : x + w]\n",
    "    return crop\n",
    "\n",
    "    \n",
    "\n",
    "#RN50ModelToEval = M.RARP_NVB_ResNet50_VAN.load_from_checkpoint(\"./log_X12_VAN_Review/lightning_logs/version_17/checkpoints/RARP-epoch=29.ckpt\")\n",
    "RN50ModelToEval = M.RARP_NVB_DINO_MultiTask.load_from_checkpoint(\"./log_X13_van_DINO/lightning_logs/version_33/checkpoints/RARP-epoch=32.ckpt\")\n",
    "th = 0.6299"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean, std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])\n",
    "transforms = T.Compose([\n",
    "    T.Resize((256,256), antialias=True, interpolation=T.InterpolationMode.BICUBIC),\n",
    "    T.CenterCrop(224),\n",
    "    T.Normalize(mean, std)\n",
    "])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = cv2.imread(str(Path(\"D:/Users/user/Documents/postata/RARP/Clasification/DataSet_big/NVB/187.tiff\")), cv2.IMREAD_COLOR)\n",
    "img = _removeBlackBorder(img)\n",
    "img = torch.Tensor(img)\n",
    "img = img.permute(2, 0, 1).float()\n",
    "\n",
    "originalSize = img.shape[-2:]\n",
    "originalImg = img\n",
    "\n",
    "img = transforms(img)\n",
    "\n",
    "img = img.repeat(1, 1, 1, 1)\n",
    "\n",
    "torch.set_float32_matmul_precision('high')\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "Device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "RN50ModelToEval.to(Device)\n",
    "RN50ModelToEval.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    img = img.to(Device)\n",
    "    \n",
    "    Doutput, _, _ = RN50ModelToEval(img)\n",
    "    \n",
    "    Doutput = Doutput.flatten()\n",
    "    \n",
    "    pred = torch.sigmoid(Doutput)\n",
    "    \n",
    "    print (pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, EigenGradCAM\n",
    "from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget\n",
    "from pytorch_grad_cam.utils.image import show_cam_on_image\n",
    "\n",
    "targetL = [RN50ModelToEval.student.backbone.block4[-1], RN50ModelToEval.teacher_Features.backbone.block4[-1]]\n",
    "\n",
    "for param in RN50ModelToEval.teacher_Features.backbone.parameters():\n",
    "    param.requires_grad = True\n",
    "\n",
    "#CAM = EigenGradCAM(model=RN50ModelToEval, target_layers=targetL)\n",
    "CAM = GradCAMPlusPlus(model=RN50ModelToEval, target_layers=targetL) \n",
    "#CAM = GradCAM(model=RN50ModelToEval, target_layers=targetL)\n",
    "tar = [ClassifierOutputTarget(0)]\n",
    "\n",
    "\n",
    "gi = CAM(img, targets=tar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Denorlalize (img:torch.Tensor, std, mean):\n",
    "    \"\"\"Undo channel-wise normalization and return an HWC image scaled to [0, 1].\n",
    "\n",
    "    Args:\n",
    "        img: CHW tensor that was normalized as (pixel - mean) / std.\n",
    "        std, mean: per-channel statistics used for the original normalization.\n",
    "\n",
    "    Returns:\n",
    "        numpy array, height x width x channel, with the channel order reversed\n",
    "        (presumably BGR -> RGB for an OpenCV-loaded source -- confirm upstream).\n",
    "    \"\"\"\n",
    "    # CHW -> HWC so the array matches the matplotlib/OpenCV memory layout.\n",
    "    hwc = img.numpy().transpose((1, 2, 0))\n",
    "    # Reverse the normalization, rescale to [0, 1] and clamp numeric overshoot.\n",
    "    hwc = np.clip((std * hwc + mean) / 255, 0, 1)\n",
    "    # Flip the channel order; .copy() materializes the negative-stride view.\n",
    "    return hwc[..., ::-1].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "\n",
    "# De-normalize the network input for display; the original image only needs\n",
    "# the channel flip, so identity statistics ([1,1,1] / [0,0,0]) are passed.\n",
    "rgb_img = Denorlalize(img.cpu()[0], std, mean)\n",
    "oriImg = Denorlalize(originalImg.cpu(), [1, 1, 1], [0, 0, 0])\n",
    "\n",
    "grayscale_cam = gi[0]\n",
    "\n",
    "visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)\n",
    "\n",
    "# Rebuild the CAM overlay at the original resolution: mark where the 224x224\n",
    "# center crop sits inside the 256x256 resize, scale that window back up, and\n",
    "# paste the CAM into the corresponding region of the original image.\n",
    "smallH, smallW, _ = rgb_img.shape\n",
    "x, y = (16, 16) #(256 - 224) // 2\n",
    "over = np.ones((256, 256), dtype=np.uint8)\n",
    "layer2 = np.zeros((224, 224), dtype=np.uint8)\n",
    "over[y:y + smallH, x:x + smallW] = layer2\n",
    "over = cv2.resize(over, (originalSize[1], originalSize[0]))\n",
    "# Zero out the crop window in the original (the mask keeps only the border).\n",
    "over = cv2.bitwise_and(oriImg, oriImg, mask=over)\n",
    "\n",
    "# Recover the blanked rectangle by finding the largest pure-black region.\n",
    "# NOTE(review): relies on exact (0,0,0) pixels surviving the resize above and\n",
    "# on no equally large black area existing elsewhere -- verify on new data.\n",
    "bkMask = cv2.inRange(over, (0,0,0), (0,0,0))\n",
    "contours, _ = cv2.findContours(bkMask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n",
    "\n",
    "xBK, yBK, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea)) \n",
    "\n",
    "layer2 = cv2.resize(visualization / 255, (w, h))\n",
    "\n",
    "over[yBK:yBK + h, xBK:xBK + w] = layer2\n",
    "\n",
    "fig, ax = plt.subplots(1, 2, figsize=(20, 8))\n",
    "\n",
    "ax[0].imshow(oriImg)\n",
    "ax[0].set_title(\"Original Image\")\n",
    "ax[0].axis(\"off\")\n",
    "\n",
    "ax[1].imshow(over)\n",
    "ax[1].set_title(f\"CAM Prediction: {'NVB' if pred.item() > th else 'NO_NVB'}; Conf.: {pred.item():.4f}\")\n",
    "ax[1].axis(\"off\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import torch\n",
    "\n",
    "LabelData =  pd.read_excel(Path(\"./DataSet_Ando_All_no20Crop/MultiLabels.xlsx\"))\n",
    "Dump = pd.read_csv(Path(\"./DataSet_AndoAll20_crop_seed_505/dump/dataset.csv\"))\n",
    "Dump[\"raw_name\"] = \"Img0-\" + Dump[\"id\"].astype(str) + \".npy\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clasiffier = torch.nn.Sequential(\n",
    "            torch.nn.Linear(1024, 128),\n",
    "            torch.nn.Dropout(0.4),\n",
    "            torch.nn.SiLU(True),\n",
    "            \n",
    "            torch.nn.Linear(128, 8),\n",
    "            torch.nn.Dropout(0.2),\n",
    "            torch.nn.SiLU(True),\n",
    "            \n",
    "            torch.nn.Linear(8, 3)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list(clasiffier.children())[-1].out_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clasiffier[-1].out_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LabelData"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DumpCSV = Dump.drop(columns=[\"mean_1\", \"mean_2\", \"mean_3\", \"std_1\", \"std_2\", \"std_3\", \"path\", \"class\", \"label\"])\n",
    "DumpCSV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outPut = pd.merge(LabelData, DumpCSV, on=\"name\", how=\"right\")\n",
    "outPut"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outList = [int(x) for x in str(outPut[outPut[\"raw_name\"] == \"Img0-15.npy\"][\"encode_l1\"].values[0]).split(\"|\")]\n",
    "torch.tensor(outList)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import Models as M\n",
    "import torchvision.transforms as T\n",
    "import cv2\n",
    "import csv\n",
    "\n",
    "\n",
    "DumpCSV = pd.read_csv(Path(\"./DataSet_AndoAll20_crop_seed_505/dump/dataset.csv\"), usecols=[\"id\", \"label\", \"class\", \"name\"])\n",
    "extra_imgCSV = pd.read_csv(Path(\"./extra_imges_fold.csv\"))\n",
    "arrFolds = np.load(Path(\"./DataSet_AndoAll20_crop_seed_505/dump/Folds.npy\"), allow_pickle=True)\n",
    "base_dir = Path(\"./DataSet_crop/\")\n",
    "out_File = Path(\"./outFile_report_LR.csv\")\n",
    "list_Of_Test = [arrFolds[2], arrFolds[5], arrFolds[8], arrFolds[11], arrFolds[14]]\n",
    "list_of_ckpt = [\n",
    "    {\n",
    "        \"model_pth\":Path(\"./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt\"),\n",
    "        \"th\":0.6190\n",
    "    },\n",
    "    {\n",
    "        \"model_pth\":Path(\"./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt\"),\n",
    "        \"th\":0.50\n",
    "    },\n",
    "    {\n",
    "        \"model_pth\":Path(\"./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt\"),\n",
    "        \"th\":0.7917\n",
    "    },\n",
    "    {\n",
    "        \"model_pth\":Path(\"./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt\"),\n",
    "        \"th\":0.4301\n",
    "    },\n",
    "    {\n",
    "        \"model_pth\":Path(\"./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt\"),\n",
    "        \"th\":0.6667\n",
    "    },\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Eval_NVB_precervation():\n",
    "    \"\"\"Run a trained NVB-preservation classifier on single image files.\n",
    "\n",
    "    Loads a RARP_NVB_DINO_MultiTask checkpoint, applies the evaluation\n",
    "    transform pipeline and thresholds the sigmoid output at `th`.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self, \n",
    "        model_pth:Path, \n",
    "        th:float = 0.5,\n",
    "        mean = [30.38144216, 42.03988769, 97.8896116], \n",
    "        std=[40.63141752, 44.26910074, 50.29294373]\n",
    "    ):\n",
    "        # Deterministic cuDNN + high matmul precision for reproducible scores.\n",
    "        torch.set_float32_matmul_precision('high')\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "        \n",
    "        self.model = M.RARP_NVB_DINO_MultiTask.load_from_checkpoint(model_pth)\n",
    "        self.model.to(self.device)\n",
    "        self.model.eval()\n",
    "        \n",
    "        # Decision threshold applied to the sigmoid confidence.\n",
    "        self.th = th\n",
    "        \n",
    "        # Same preprocessing as used elsewhere in this notebook; mean/std are\n",
    "        # on a 0-255 scale, presumably BGR-ordered (OpenCV load) -- confirm.\n",
    "        self.transforms = T.Compose([\n",
    "            T.Resize((256,256), antialias=True, interpolation=T.InterpolationMode.BICUBIC),\n",
    "            T.CenterCrop(224),\n",
    "            T.Normalize(mean, std)\n",
    "        ])\n",
    "        \n",
    "    def __call__(self, img_pth:Path):\n",
    "        \"\"\"Return (confidence, is_positive) for the image at `img_pth`.\"\"\"\n",
    "        img = cv2.imread(str(img_pth), cv2.IMREAD_COLOR)\n",
    "        #img = _removeBlackBorder(img)\n",
    "        img = torch.tensor(img, device=self.device, dtype=torch.float32)\n",
    "        # HWC -> CHW for torchvision transforms.\n",
    "        img = img.permute(2, 0, 1)\n",
    "        img = self.transforms(img)\n",
    "        \n",
    "        # Add a batch dimension (equivalent to unsqueeze(0)).\n",
    "        img = img.repeat(1, 1, 1, 1)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            # Model returns (logits, ...); only the first output is used here.\n",
    "            pred, _, _ = self.model(img)\n",
    "            pred = torch.sigmoid(pred.flatten())\n",
    "        \n",
    "        return pred.item(), (pred.item() > self.th)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write a per-fold prediction report. Note: mode \"x\" raises FileExistsError\n",
    "# if out_File already exists -- remove the old report before re-running.\n",
    "with open(out_File, \"x\", newline='') as csvfile:\n",
    "    writerOBJ = csv.writer(csvfile)\n",
    "    writerOBJ.writerow([\"fold\", \"th\", \"case\", \"conf.\", \"NVB_Precervation\"])\n",
    "    # One checkpoint (and tuned threshold) per test fold.\n",
    "    for fold, test_set in enumerate(list_Of_Test):\n",
    "        print(fold)\n",
    "        predictor = Eval_NVB_precervation(list_of_ckpt[fold][\"model_pth\"], list_of_ckpt[fold][\"th\"])\n",
    "        for i, row in DumpCSV.loc[DumpCSV[\"id\"].isin(test_set)].iterrows():\n",
    "            matckImge = list(base_dir.rglob(row[\"name\"]))\n",
    "            if matckImge:\n",
    "                conf, pred = predictor(matckImge[0])\n",
    "                # pred * 1 casts the bool to 0/1 for the CSV.\n",
    "                writerOBJ.writerow([fold, list_of_ckpt[fold][\"th\"], row[\"name\"], conf, pred * 1])\n",
    "            else:\n",
    "                raise Exception (f\"Not Found {row['name']}\")\n",
    "    print(\"Extra Test Image\")\n",
    "    # Extra images: the predictor is only reloaded when the row's fold differs\n",
    "    # from the current one (falls through from the last fold above), so the\n",
    "    # CSV is assumed to be ordered by fold -- verify the input file.\n",
    "    for _, row in extra_imgCSV.iterrows():\n",
    "        if fold != int(row[\"fold\"]):\n",
    "            fold = int(row[\"fold\"])\n",
    "            print(row[\"fold\"])\n",
    "            predictor = Eval_NVB_precervation(list_of_ckpt[fold][\"model_pth\"], list_of_ckpt[fold][\"th\"])\n",
    "        \n",
    "        matckImge = list(base_dir.rglob(row[\"name\"]))\n",
    "        if matckImge:\n",
    "            conf, pred = predictor(matckImge[0])\n",
    "            writerOBJ.writerow([fold, list_of_ckpt[fold][\"th\"], row[\"name\"], conf, pred * 1])\n",
    "        else:\n",
    "                raise Exception (f\"Not Found {row['name']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = [0.25, 0.75]\n",
    "[p > 0.5 for p in t]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Eval_NVB_precervation_RL():\n",
    "    \"\"\"Left/right-aware NVB-preservation evaluator.\n",
    "\n",
    "    For every image it scores two variants: one with the right half zeroed\n",
    "    out (keeps the left side) and one with the left half zeroed out, so the\n",
    "    two sides can be judged independently.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self, \n",
    "        model_pth:Path, \n",
    "        th:float = 0.5,\n",
    "        mean = [30.38144216, 42.03988769, 97.8896116], \n",
    "        std=[40.63141752, 44.26910074, 50.29294373],\n",
    "    ):\n",
    "        # Deterministic cuDNN + high matmul precision for reproducible scores.\n",
    "        torch.set_float32_matmul_precision('high')\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "        \n",
    "        self.model = M.RARP_NVB_DINO_MultiTask.load_from_checkpoint(model_pth)\n",
    "        self.model.to(self.device)\n",
    "        self.model.eval()\n",
    "        \n",
    "        # Decision threshold applied to each side's sigmoid confidence.\n",
    "        self.th = th\n",
    "        \n",
    "        self.transforms = T.Compose([\n",
    "            T.Resize((256,256), antialias=True, interpolation=T.InterpolationMode.BICUBIC),\n",
    "            T.CenterCrop(224),\n",
    "            T.Normalize(mean, std)\n",
    "        ])\n",
    "        \n",
    "    def _mask_LR(self, image:torch.Tensor, Left:bool = True):\n",
    "        \"\"\"Return a copy of the CHW `image` with one lateral half zeroed out.\n",
    "\n",
    "        BUGFIX: the previous version sliced at a hard-coded 512, but this is\n",
    "        called after the transforms, when the image is already 224 px wide:\n",
    "        the 'left' variant kept the whole image (and doubled its width) and\n",
    "        the 'right' variant was entirely black. Slice at the real mid-point\n",
    "        instead, so this works for any input width.\n",
    "        \"\"\"\n",
    "        half_w = image.shape[-1] // 2\n",
    "        pad_zeros = torch.zeros_like(image[:, :, :half_w])\n",
    "        \n",
    "        if Left:\n",
    "            listImgs = [image[:, :, :half_w], pad_zeros]\n",
    "        else:\n",
    "            listImgs = [pad_zeros, image[:, :, half_w:2 * half_w]]\n",
    "            \n",
    "        # Concatenate along the width axis (dim 2 of a CHW tensor).\n",
    "        return torch.cat(listImgs, dim=2)\n",
    "        \n",
    "    def __call__(self, img_pth:Path):\n",
    "        \"\"\"Return ([conf_L, conf_R], [pred_L, pred_R]) for the image file.\"\"\"\n",
    "        img = cv2.imread(str(img_pth), cv2.IMREAD_COLOR)\n",
    "        #img = _removeBlackBorder(img)\n",
    "        img = torch.tensor(img, device=self.device, dtype=torch.float32)\n",
    "        # HWC -> CHW for torchvision transforms.\n",
    "        img = img.permute(2, 0, 1)\n",
    "        img = self.transforms(img)\n",
    "                \n",
    "        LRImages = [self._mask_LR(img, True), self._mask_LR(img, False)]\n",
    "        LRPred = []\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            for _img in LRImages:\n",
    "                # Add a batch dimension.\n",
    "                _img = _img.repeat(1, 1, 1, 1)\n",
    "                pred, _, _ = self.model(_img)\n",
    "                pred = torch.sigmoid(pred.flatten())\n",
    "                LRPred.append(pred.item())\n",
    "        \n",
    "        return LRPred, [p > self.th for p in LRPred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-fold left/right prediction report. Note: mode \"x\" raises\n",
    "# FileExistsError if out_File already exists -- remove the old report first.\n",
    "with open(out_File, \"x\", newline='') as csvfile:\n",
    "    writerOBJ = csv.writer(csvfile)\n",
    "    writerOBJ.writerow([\"fold\", \"th\", \"case\", \"conf. L\", \"conf. R\", \"NVB_L\", \"NVB_R\", \"Extra\"])\n",
    "    # One checkpoint (and tuned threshold) per test fold.\n",
    "    for fold, test_set in enumerate(list_Of_Test):\n",
    "        print(fold)\n",
    "        predictor = Eval_NVB_precervation_RL(list_of_ckpt[fold][\"model_pth\"], list_of_ckpt[fold][\"th\"])\n",
    "        for i, row in DumpCSV.loc[DumpCSV[\"id\"].isin(test_set)].iterrows():\n",
    "            matckImge = list(base_dir.rglob(row[\"name\"]))\n",
    "            if matckImge:\n",
    "                conf, pred = predictor(matckImge[0])\n",
    "                # pred[i] * 1 casts the bool to 0/1; last column marks non-extra rows.\n",
    "                writerOBJ.writerow([fold, list_of_ckpt[fold][\"th\"], row[\"name\"], conf[0], conf[1], pred[0] * 1, pred[1] * 1, 0])\n",
    "            else:\n",
    "                raise Exception (f\"Not Found {row['name']}\")\n",
    "    print(\"Extra Test Image\")\n",
    "    # Extra images: the predictor is only reloaded when the row's fold differs\n",
    "    # from the current one, so the CSV is assumed to be ordered by fold.\n",
    "    for _, row in extra_imgCSV.iterrows():\n",
    "        if fold != int(row[\"fold\"]):\n",
    "            fold = int(row[\"fold\"])\n",
    "            print(row[\"fold\"])\n",
    "            predictor = Eval_NVB_precervation_RL(list_of_ckpt[fold][\"model_pth\"], list_of_ckpt[fold][\"th\"])\n",
    "        \n",
    "        matckImge = list(base_dir.rglob(row[\"name\"]))\n",
    "        if matckImge:\n",
    "            conf, pred = predictor(matckImge[0])\n",
    "            writerOBJ.writerow([fold, list_of_ckpt[fold][\"th\"], row[\"name\"], conf[0], conf[1], pred[0] * 1, pred[1] * 1, 1])\n",
    "        else:\n",
    "                raise Exception (f\"Not Found {row['name']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import Models\n",
    "\n",
    "model = Models.RARP_MAE.load_from_checkpoint(\"./log_SSL_X2_MAE/lightning_logs/version_6/checkpoints/MAE-epoch=13-val_acc=0.8719.ckpt\")\n",
    "\n",
    "#model = Models.RARP_NVB_DINO_MultiTask.load_from_checkpoint(\"./RARP-epoch=21.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.encoder.state_dict(), \"van_b2_MAE.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sm = model.student.backbone.state_dict()\n",
    "tm = model.teacher_Features.backbone.state_dict()\n",
    "\n",
    "torch.save(sm, \"Hybrid T-S Model_S.pth\")\n",
    "torch.save(tm, \"Hybrid T-S Model_T.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import cv2\n",
    "\n",
    "root = Path(\"../archive/\")\n",
    "save_dir = Path(\"Dataset_Video/endo/\")\n",
    "\n",
    "for img_endo in root.glob(\"**/*_endo.png\"):\n",
    "    img = cv2.imread(str(img_endo.absolute()), cv2.IMREAD_COLOR)\n",
    "    img = cv2.resize(img, (640, 360), interpolation=cv2.INTER_CUBIC)\n",
    "    \n",
    "    newImge = save_dir / f\"{img_endo.parent.parent.name}_{img_endo.name.replace('png', 'webp')}\"\n",
    "    cv2.imwrite(str(newImge.absolute()), img, [cv2.IMWRITE_WEBP_QUALITY, 100])\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import cv2\n",
    "\n",
    "root = Path(\"D:/Users/user/Downloads/dataset/images/real\")\n",
    "save_dir = Path(\"Dataset_Video/RAG/\")\n",
    "\n",
    "for img_endo in root.glob(\"**/*.jpg\"):\n",
    "    img = cv2.imread(str(img_endo.absolute()), cv2.IMREAD_COLOR)\n",
    "    img = cv2.resize(img, (450, 360), interpolation=cv2.INTER_CUBIC)\n",
    "    \n",
    "    newImge = save_dir / f\"{img_endo.name.replace('jpg', 'webp')}\"\n",
    "    cv2.imwrite(str(newImge.absolute()), img, [cv2.IMWRITE_WEBP_QUALITY, 100])\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import cv2\n",
    "\n",
    "def resize_image_aspect_ratio(image, target_width=None, target_height=None, interpolation=cv2.INTER_CUBIC):\n",
    "    \"\"\"\n",
    "    Resizes an image while maintaining its aspect ratio.\n",
    "\n",
    "    Args:\n",
    "        image (numpy.ndarray): The input image.\n",
    "        target_width (int, optional): The desired target width. If None, target_height must be provided.\n",
    "            If both are given, target_width takes precedence.\n",
    "        target_height (int, optional): The desired target height. If None, target_width must be provided.\n",
    "        interpolation (int, optional): The interpolation method to use. Defaults to cv2.INTER_CUBIC.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: The resized image (the original if no target given).\n",
    "    \"\"\"\n",
    "    h, w = image.shape[:2]\n",
    "    dim = None\n",
    "\n",
    "    if target_width is None and target_height is None:\n",
    "        return image  # No target dimensions provided, return original image\n",
    "\n",
    "    if target_width is not None:\n",
    "        r = target_width / float(w)\n",
    "        dim = (target_width, int(h * r))\n",
    "    else:  # target_height is not None\n",
    "        r = target_height / float(h)\n",
    "        dim = (int(w * r), target_height)\n",
    "\n",
    "    resized_image = cv2.resize(image, dim, interpolation=interpolation)\n",
    "    return resized_image\n",
    "\n",
    "root = Path(\"D:/Users/user/Downloads/Research/Dataset_Video/WCE\")\n",
    "save_dir = Path(\"Dataset_Video/WCE_webp\")\n",
    "\n",
    "# Convert every WCE jpg to a 360-px-tall, quality-100 webp.\n",
    "for img_endo in root.glob(\"**/*.jpg\"):\n",
    "    img = cv2.imread(str(img_endo.absolute()), cv2.IMREAD_COLOR)\n",
    "    img = resize_image_aspect_ratio(img, target_height=360)\n",
    "    #img = cv2.resize(img, (450, 360), interpolation=cv2.INTER_CUBIC)\n",
    "    \n",
    "    newImge = save_dir / f\"{img_endo.name.replace('jpg', 'webp')}\"\n",
    "    cv2.imwrite(str(newImge.absolute()), img, [cv2.IMWRITE_WEBP_QUALITY, 100])\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "root = Path(\"../dataset/Merge Videos/files\")\n",
    "\n",
    "# Produce a Windows-path variant of every ffmpeg concat list file.\n",
    "for f in root.glob(\"*.txt\"):\n",
    "    win_file = f.parent / f\"win/{f.name.replace('.txt', '_win.txt')}\"\n",
    "    # Re-create the windows variant from scratch on every run so the\n",
    "    # exclusive-create open below succeeds.\n",
    "    if win_file.exists():\n",
    "        win_file.unlink()\n",
    "    with open(f, \"r\") as lf, open(win_file, \"x\") as wf:\n",
    "        for line in lf:\n",
    "            # Rewrite the mount prefix, then switch to backslash separators.\n",
    "            wf.write(line.replace(\"mnt\", \"/gabor\").replace(\"/\", \"\\\\\"))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "\n",
    "folds = np.load(Path(\"./DataSet_AndoAll20_crop_seed_505/dump/Folds.npy\"), allow_pickle=True) \n",
    "database = pd.read_csv(Path(\"./DataSet_AndoAll20_crop_seed_505/dump/dataset.csv\"), usecols=[\"id\", \"label\", \"class\", \"name\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "name_set_type = [\"train\", \"val\", \"test\"]\n",
    "folds_dataset = []\n",
    "\n",
    "# Folds.npy stores three consecutive id-arrays (train/val/test) per fold, so\n",
    "# the flat array is regrouped into chunks of three. Integer division keeps\n",
    "# np.array_split's section count an int (the previous `/ 3` passed a float\n",
    "# and relied on numpy's implicit int() conversion).\n",
    "for fold_num, split in enumerate(np.array_split(folds, len(folds) // 3)):\n",
    "    folds_dic = {}\n",
    "    folds_dic[\"fold\"] = fold_num\n",
    "    folds_dic[\"splits\"] = {}\n",
    "    for set_type, subset in enumerate(split):\n",
    "        # Map numeric ids back to file stems (drop the .tiff extension).\n",
    "        folds_dic[\"splits\"][name_set_type[set_type]] = [str(n).replace(\".tiff\", \"\") for n in database.loc[database[\"id\"].isin(subset)][\"name\"].to_list()]\n",
    "    folds_dataset.append(folds_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inverted = []\n",
    "\n",
    "# Invert each fold's {split: [ids]} mapping into {id: split} for fast lookup.\n",
    "for f in folds_dataset:\n",
    "    id_to_split = {str(v): split_name\n",
    "                   for split_name, id_values in f[\"splits\"].items()\n",
    "                   for v in id_values}\n",
    "    inverted.append({\"fold\": f[\"fold\"], \"ids\": id_to_split})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One record per (image id, fold): the split name goes into a fold-specific\n",
    "# column so the frames can later be merged into a single wide table.\n",
    "rows = [\n",
    "    {\"ids\": img_id, f\"fold_{f['fold']}\": split}\n",
    "    for f in inverted\n",
    "    for img_id, split in f[\"ids\"].items()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(rows)\n",
    "df = df.groupby(\"ids\").first().reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv(\"folds_table.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inverted[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folds_dataset[0][\"splits\"][\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[\n",
    "    {\n",
    "        \"fold\":1,\n",
    "        \"splits\":{\n",
    "            \"train\":[1, 5, 8, 20, 21, 50],\n",
    "            \"val\":[30, 10, 2],\n",
    "            \"test\":[3, 22, 15]\n",
    "        }\n",
    "    },\n",
    "    \n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folds_dic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "root = Path(\"../dataset/Merge Videos/files/win\")\n",
    "\n",
    "output = Path(\"../dataset/Merge Videos/bash_files\")\n",
    "\n",
    "# Start from a clean slate so the exclusive-create (\"x\") opens below succeed.\n",
    "if output.exists():\n",
    "    for f in output.glob(\"*.bat\"):\n",
    "        f.unlink()\n",
    "\n",
    "\n",
    "# Group the concat list files five per .bat so the merges run in batches.\n",
    "count = 0\n",
    "for i, f in enumerate(root.glob(\"*.txt\")):\n",
    "    if (i % 5 == 0) and (i != 0):\n",
    "        # Every 5 entries: close the current batch file and start the next.\n",
    "        bf.close()\n",
    "        count += 1\n",
    "        bf = open (output / f\"merge_bash_{count}.bat\", \"x\")\n",
    "    elif i == 0:\n",
    "        bf = open (output / f\"merge_bash_{count}.bat\", \"x\")\n",
    "    \n",
    "    # One ffmpeg concat command per list file; the case id is taken from the\n",
    "    # second underscore-separated token of the list filename.\n",
    "    bf.write(f\"ffmpeg.exe -hide_banner -f concat -safe 0 -i \\\"{str(f.resolve())}\\\" -c copy \\\"D:\\\\Users\\\\user\\\\Downloads\\\\dataset\\\\Merge Videos\\\\output_videos\\\\RARP_{f.name.split('_')[1]}.mp4\\\"\\n\")\n",
    "    \n",
    "bf.close()\n",
    "    \n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "path  = \"C:\\\\Users\\\\user\\\\Desktop\\\\frame_0000118.webp\"\n",
    "\n",
    "img = cv2.imread(path, cv2.IMREAD_COLOR)\n",
    "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "plt.imshow(img)\n",
    "plt.axis('off')  \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path  = \"C:\\\\Users\\\\user\\\\Desktop\\\\frame_0000118.webp\"\n",
    "\n",
    "img = cv2.imread(path, cv2.IMREAD_COLOR)\n",
    "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "x, y, w, h = (139, 0, 360, 360)\n",
    "\n",
    "img = img[y:y+h, x:x+w]\n",
    "\n",
    "img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)\n",
    "\n",
    "plt.imshow(img)\n",
    "plt.axis('off')  \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import decord\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "Device = torch.device(f\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "w, h = (640, 360)\n",
    "vr = decord.VideoReader(\"./Dataset_RARP_video/181/clip_181_30fps.mp4\", ctx=decord.cpu(0), width=w, height=h)\n",
    "idxs = np.linspace(0, len(vr)-1, num=600, dtype=np.int64)\n",
    "frames = vr.get_batch(idxs).asnumpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.savez_compressed(\"./Dataset_RARP_video/181/raw_181.npz\", frames=frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video = torch.from_numpy(np.load(\"./Dataset_RARP_video/181/raw_181.npz\")[\"frames\"]).permute(0, 3, 1, 2)\n",
    "video = video[:, [2, 1, 0], :, : ] #RGB2BGR\n",
    "video = video.float()\n",
    "video = video.to(Device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import Loaders as lo\n",
    "traintransformT2 = torch.nn.Sequential(\n",
    "    transforms.RandomErasing(0.2, value=\"random\"),\n",
    "    transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=5),\n",
    "    transforms.GaussianBlur(5),\n",
    "    transforms.RandomHorizontalFlip(0.3),\n",
    "    transforms.CenterCrop((360, 360)),\n",
    "    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),\n",
    "    transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])\n",
    ").to(Device)\n",
    "\n",
    "testtransformT2 = torch.nn.Sequential(\n",
    "    transforms.CenterCrop((360, 360)),\n",
    "    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),\n",
    "    transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])\n",
    ").to(Device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_o = testtransformT2(video)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_o = traintransformT2(video)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "frame_idx = [0, 100, 200, 300, 400, 599] \n",
    "Mean = torch.tensor([30.38144216, 42.03988769, 97.8896116]).view(1,3,1,1)\n",
    "Std = torch.tensor([40.63141752, 44.26910074, 50.29294373]).view(1,3,1,1)\n",
    "\n",
    "frames = video_o[frame_idx].cpu()\n",
    "frames = frames * Std[0] + Mean[0]\n",
    "frames = torch.clamp(frames[:, [2, 1, 0], :, :]/255, 0, 1)\n",
    "print(frames.shape)\n",
    "\n",
    "grid = torchvision.utils.make_grid(frames, nrow=len(frame_idx))\n",
    "\n",
    "plt.figure(figsize=(25,15))\n",
    "plt.imshow(grid.permute(1,2,0).cpu().numpy())\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "root = Path(\"./Dataset_RARP_video/\")\n",
    "\n",
    "database = pd.read_csv((root/\"dataset_videos_folds.csv\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_rows = pd.DataFrame(columns=database.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expand each video row into one row per extracted .webp frame. Chunks are\n",
    "# collected in a list and concatenated ONCE at the end -- the previous\n",
    "# version called pd.concat inside the loop, which is quadratic in the number\n",
    "# of videos.\n",
    "frame_chunks = []\n",
    "for i, row in database.iterrows():\n",
    "    new_row = {col: [] for col in database.columns}\n",
    "    for file in sorted(Path(row[\"path\"]).parent.glob(\"*.webp\")):\n",
    "        for col in database.columns:\n",
    "            match(col):\n",
    "                case \"path\":\n",
    "                    val = \"./\" + file.as_posix()\n",
    "                case \"type\":\n",
    "                    val = \"f\"  # mark the record as a single frame, not a video\n",
    "                case _:\n",
    "                    # All other columns are inherited from the parent video row.\n",
    "                    val = row[col]\n",
    "                    \n",
    "            new_row[col].append(val)\n",
    "    frame_chunks.append(pd.DataFrame(new_row))\n",
    "new_rows = pd.concat([new_rows, *frame_chunks], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "database_fv = pd.concat([database, new_rows], ignore_index=True).sort_values(by=[\"case\", \"path\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "database_fv.to_csv((root/\"dataset_videos_frames_folds.csv\"), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "\n",
    "img = cv2.imread(\"./Dataset_RARP_video/181/frame_0004.webp\", cv2.IMREAD_COLOR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "\n",
    "def seconds_to_hms(seconds):\n",
    "    \"\"\"Format a duration in seconds as H:MM:SS (hours are not zero-padded).\"\"\"\n",
    "    minutes, secs = divmod(seconds, 60)\n",
    "    hours, minutes = divmod(minutes, 60)\n",
    "    return f'{int(hours)}:{int(minutes):02}:{int(secs):02}'\n",
    "\n",
    "def find_ROI_rec (base_img:np, roi_img:np, fx_roi:float=None, fy_roi:float=None):\n",
    "    \"\"\"Locate `roi_img` inside `base_img` via normalized cross-correlation.\n",
    "\n",
    "    Args:\n",
    "        base_img: image to search in.\n",
    "        roi_img: 3-channel template; optionally rescaled by (fx_roi, fy_roi) first.\n",
    "\n",
    "    Returns:\n",
    "        (top_left, bottom_right, response_map, best_score) of the best match.\n",
    "    \"\"\"\n",
    "    if fx_roi is not None and fy_roi is not None:\n",
    "        roi_img = cv2.resize(roi_img, None, fx=fx_roi, fy=fy_roi, interpolation=cv2.INTER_CUBIC)\n",
    "    \n",
    "    # BUGFIX: numpy shape is (rows, cols, ch) = (height, width, ch); the old\n",
    "    # code unpacked it as (w, h, _), so bottom_right mixed up the axes for any\n",
    "    # non-square template.\n",
    "    h, w = roi_img.shape[:2]\n",
    "    \n",
    "    assert h < base_img.shape[0] and w < base_img.shape[1], \"base image is smaller than the the template\"\n",
    "    \n",
    "    gray_base = cv2.cvtColor(base_img, cv2.COLOR_BGR2GRAY)\n",
    "    gray_roi = cv2.cvtColor(roi_img, cv2.COLOR_BGR2GRAY)\n",
    "\n",
    "    res = cv2.matchTemplate(gray_base, gray_roi, cv2.TM_CCORR_NORMED)\n",
    "    \n",
    "    # For TM_CCORR_NORMED the best match is the MAXIMUM of the response map.\n",
    "    _, vals, _, locs = cv2.minMaxLoc(res)\n",
    "\n",
    "    top_left = locs\n",
    "    bottom_right = (top_left[0] + w, top_left[1] + h)\n",
    "    \n",
    "    return top_left, bottom_right, res, vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Track the ROI template through a case video and write the cropped window\n",
    "# out as a new .avi clip.\n",
    "case_num = 423\n",
    "label_class = \"NVB\"\n",
    "\n",
    "pth_str_base = f\"./DataSet/{label_class}/{case_num}.tiff\"\n",
    "pth_str_ROI = f\"./DataSet_crop/{label_class}/{case_num}.tiff\"\n",
    "\n",
    "base_img = cv2.imread(pth_str_base, cv2.IMREAD_GRAYSCALE)\n",
    "roi_img = cv2.imread(pth_str_ROI, cv2.IMREAD_COLOR)\n",
    "assert base_img is not None and roi_img is not None, \"failed to read base/ROI image\"\n",
    "\n",
    "H, W = base_img.shape\n",
    "\n",
    "root = Path(f\"./Dataset_RARP_video/{case_num}\")\n",
    "cap = cv2.VideoCapture(str(root / f\"clip_{case_num}_30fps.mp4\"))\n",
    "total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
    "video_fps = cap.get(cv2.CAP_PROP_FPS)\n",
    "print(f\" Video FPS: {video_fps:.2f}, Total frames: {total_frames}, Video length: {seconds_to_hms(total_frames / video_fps)}\")\n",
    "\n",
    "# Read the first frame to learn the video resolution, then rewind to the start.\n",
    "_, frame_0 = cap.read()\n",
    "cap.set(cv2.CAP_PROP_POS_MSEC, 0)\n",
    "\n",
    "frame_H, frame_W = frame_0.shape[:-1]\n",
    "fx, fy = frame_W / W, frame_H / H\n",
    "\n",
    "# NOTE(review): only fx is used for both axes below (uniform scaling); fy is\n",
    "# computed but unused — confirm this is intentional.\n",
    "\n",
    "out_video_pth = str(root / f\"clip_{case_num}_30fps.avi\")\n",
    "codec = cv2.VideoWriter_fourcc(*'XVID')\n",
    "# NOTE(review): the writer size assumes the scaled ROI crop is 1024x1024 in\n",
    "# base-image coordinates — VideoWriter silently drops frames of any other\n",
    "# size. Verify against the ROI template dimensions.\n",
    "out_video = cv2.VideoWriter(out_video_pth, codec, video_fps, (int(1024*fx), int(1024*fx)))\n",
    "\n",
    "font = cv2.FONT_HERSHEY_SIMPLEX\n",
    "font_scale = 2\n",
    "color = (0, 255, 0)  # Green color\n",
    "thickness = 2\n",
    "\n",
    "while cap.isOpened():\n",
    "    ret, frame = cap.read()\n",
    "    if not ret:\n",
    "        break\n",
    "\n",
    "    # Match the ROI template in this frame and crop to the matched window.\n",
    "    tl, br, _, corr = find_ROI_rec(frame, roi_img, fx_roi=fx, fy_roi=fx)\n",
    "    crop_frame = frame[tl[1]:br[1], tl[0]:br[0]].copy()\n",
    "    out_video.write(crop_frame)\n",
    "\n",
    "# Release both handles so the .avi container is finalized and the input file\n",
    "# handle is freed (the original leaked `cap`).\n",
    "out_video.release()\n",
    "cap.release()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from EfficientViT.classification.model.build import EfficientViT_M5\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class SEAttention(nn.Module):\n",
    "    \"\"\"Squeeze-and-Excitation channel attention.\n",
    "\n",
    "    Globally average-pools each channel, feeds the pooled vector through a\n",
    "    bottleneck MLP with a sigmoid gate, and rescales the input channel-wise.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, reduction=16):\n",
    "        super().__init__()\n",
    "        # Submodule names (avg_pool, fc) are part of the state_dict layout —\n",
    "        # keep them stable.\n",
    "        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(in_channels, in_channels // reduction, bias=False),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(in_channels // reduction, in_channels, bias=False),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch, channels = x.shape[:2]\n",
    "        gate = self.avg_pool(x).view(batch, channels)\n",
    "        gate = self.fc(gate).view(batch, channels, 1, 1)\n",
    "        return x * gate\n",
    "\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"Upsampling decoder: maps a flattened 384*4*4 encoder feature to a 3-channel image.\n",
    "\n",
    "    Five stride-2 transposed convolutions (7x7 -> 224x224) with two SE-gated\n",
    "    residual refinement blocks in between. The output is sigmoid-activated\n",
    "    when reconstructing images, or tanh-activated when `predict_change` is set.\n",
    "\n",
    "    NOTE(review): `in_size` is used as a fixed batch size in forward(), so the\n",
    "    module only accepts batches of exactly `in_size` samples — confirm intended.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_size, predict_change=False):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.in_size = in_size  # fixed batch size used to reshape the input in forward()\n",
    "        self.predict_change = predict_change\n",
    "\n",
    "        # Initial representation: project the flat encoder feature to a 7x7x1024 map.\n",
    "        self.fc = nn.Linear(384*4*4, 7 * 7 * 1024)\n",
    "        self.bn1d = nn.BatchNorm1d(7 * 7 * 1024)\n",
    "        self.gelu = nn.GELU()\n",
    "\n",
    "        # Decoder layers: each ConvTranspose2d doubles the spatial resolution.\n",
    "        # (Commented lines preserve an earlier, narrower channel configuration.)\n",
    "        self.conv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, output_padding=0)\n",
    "        self.bn1 = nn.BatchNorm2d(512)\n",
    "        self.relu1 = nn.ReLU()\n",
    "        #self.conv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, output_padding=0)\n",
    "        self.conv2 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, output_padding=0)\n",
    "        #self.bn2 = nn.BatchNorm2d(256)\n",
    "        self.bn2 = nn.BatchNorm2d(512)\n",
    "        self.relu2 = nn.ReLU()\n",
    "        #self.conv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, output_padding=0)\n",
    "        self.conv3 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, output_padding=0)\n",
    "        #self.bn3 = nn.BatchNorm2d(128)\n",
    "        self.bn3 = nn.BatchNorm2d(256)\n",
    "        self.relu3 = nn.ReLU()\n",
    "        #self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0)\n",
    "        self.conv4 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1, output_padding=0)\n",
    "        #self.bn4 = nn.BatchNorm2d(64)\n",
    "        self.bn4 = nn.BatchNorm2d(64)\n",
    "        self.relu4 = nn.ReLU()\n",
    "        self.conv5 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, output_padding=0)\n",
    "\n",
    "        # Residual blocks with SE attention (applied after conv4 / conv2 stages).\n",
    "        # NOTE(review): Sigmoid directly before SEAttention+ReLU is unusual —\n",
    "        # confirm it is intentional rather than a leftover.\n",
    "        self.res2 = nn.Sequential(\n",
    "            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.Sigmoid(),\n",
    "            SEAttention(64),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "\n",
    "        # was 256\n",
    "        self.res1 = nn.Sequential(\n",
    "            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.Sigmoid(),\n",
    "            SEAttention(512),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        # Output activation: [0, 1] images vs. [-1, 1] change maps.\n",
    "        if not self.predict_change:\n",
    "            self.sigmoid = nn.Sigmoid()\n",
    "        else:\n",
    "            self.tanh = nn.Tanh()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Flatten to (in_size, 384*4*4); requires the batch to be exactly in_size.\n",
    "        x = self.fc(x.reshape(self.in_size, 384*4*4))\n",
    "        x = self.bn1d(x)\n",
    "        x = self.gelu(x)\n",
    "        # Reshape the projected vector into a 1024-channel 7x7 feature map.\n",
    "        x = x.view(-1, 1024, 7, 7)\n",
    "        x = self.relu1(self.bn1(self.conv1(x)))\n",
    "        x = self.relu2(self.bn2(self.conv2(x)))\n",
    "        x = self.res1(x) + x  # SE residual refinement at 512 channels\n",
    "        x = self.relu3(self.bn3(self.conv3(x)))\n",
    "        x = self.relu4(self.bn4(self.conv4(x)))\n",
    "        x = self.res2(x) + x  # SE residual refinement at 64 channels\n",
    "        x = self.conv5(x)\n",
    "        if not self.predict_change:\n",
    "            x = self.sigmoid(x)\n",
    "        else:\n",
    "            x = self.tanh(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class EfficientViTAutoEncoder(nn.Module):\n",
    "    \"\"\"EfficientViT-M5 encoder (classification head removed) plus convolutional decoder.\n",
    "\n",
    "    `in_size` is forwarded to Decoder, which uses it as a fixed batch size;\n",
    "    `predict_change` selects the decoder's output activation (sigmoid vs tanh).\n",
    "    \"\"\"\n",
    "    def __init__(self, in_size, predict_change=False):\n",
    "        super(EfficientViTAutoEncoder, self).__init__()\n",
    "        self.predict_change = predict_change\n",
    "        self.decoder = Decoder(in_size, predict_change)\n",
    "        self.evit = EfficientViT_M5(pretrained='efficientvit_m5')\n",
    "        # Drop the classification head, keeping only the feature extractor.\n",
    "        self.evit = torch.nn.Sequential(*list(self.evit.children())[:-1])\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.evit(x)\n",
    "        # Call the module, not .forward(), so any hooks registered on the\n",
    "        # decoder are honored (the original bypassed nn.Module.__call__).\n",
    "        decoded = self.decoder(out)\n",
    "        return decoded\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the autoencoder (batch size 16) and restore the GSViT checkpoint.\n",
    "evit_gsvit = EfficientViTAutoEncoder(16)\n",
    "# NOTE: torch.load unpickles arbitrary objects — only load checkpoints from\n",
    "# trusted sources (consider weights_only=True on newer PyTorch).\n",
    "evit_gsvit.load_state_dict(torch.load(\"./EfficientViT/GSViT.pkl\", map_location=\"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export only the encoder weights so they can be loaded standalone later.\n",
    "torch.save(evit_gsvit.evit.state_dict(), \"./EfficientViT/EfficientViT_GSViT.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from EfficientViT.classification.model.build import EfficientViT_M5\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class EfficientViT_GSViT(nn.Module):\n",
    "    \"\"\"EfficientViT-M5 backbone with the classification head removed.\n",
    "\n",
    "    An implementation based on the original paper and repo:\n",
    "    https://github.com/SamuelSchmidgall/GSViT.git\n",
    "\n",
    "    Args:\n",
    "        pre_trained: path to a saved encoder state_dict; pass \"\" (or None)\n",
    "            to keep the default EfficientViT-M5 initialization.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pre_trained: str = \"./EfficientViT/EfficientViT_GSViT.pth\"):\n",
    "        super().__init__()\n",
    "\n",
    "        self.evit = EfficientViT_M5(pretrained='efficientvit_m5')\n",
    "        # Drop the classification head, keeping only the feature extractor.\n",
    "        self.evit = torch.nn.Sequential(*list(self.evit.children())[:-1])\n",
    "\n",
    "        # Truthiness check also tolerates pre_trained=None (the original\n",
    "        # len(pre_trained) would raise TypeError on None).\n",
    "        if pre_trained:\n",
    "            self.evit.load_state_dict(torch.load(pre_trained, map_location=\"cpu\"))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.evit(x)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\user\\anaconda3\\envs\\pyRARP\\Lib\\site-packages\\albumentations\\__init__.py:28: UserWarning: A new version of Albumentations is available: '2.0.8' (you have '2.0.4'). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n",
      "  check_for_updates()\n"
     ]
    }
   ],
   "source": [
    "from EfficientViT.GSViT import EfficientViT_GSViT\n",
    "import torch\n",
    "import torchvision\n",
    "import pandas as pd\n",
    "import Loaders\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "\n",
    "# The f-prefix on the device string was a no-op (no placeholders).\n",
    "Device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "FOLD = 0\n",
    "\n",
    "BATCH_SIZE = 2\n",
    "KEY_FRAME = False\n",
    "WIN_LENGTH = 72\n",
    "NUM_WIN = 10\n",
    "WORKERS = 0\n",
    "\n",
    "df = pd.read_csv(\"../dataset/Dataset_RARP_video/dataset_videos_folds.csv\")\n",
    "\n",
    "# Test split for the chosen fold, ordered for reproducible iteration.\n",
    "test_set = df.loc[df[f\"Fold_{FOLD}\"] == \"test\"].sort_values(by=[\"label\", \"case\"]).to_dict(orient=\"records\")\n",
    "\n",
    "# Deterministic eval-time preprocessing (no augmentation; normalization\n",
    "# deliberately disabled — see commented line).\n",
    "testVal_transform = torch.nn.Sequential(\n",
    "    transforms.CenterCrop(300),\n",
    "    transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),\n",
    "    #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    ").to(Device)\n",
    "\n",
    "test_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(\n",
    "    test_set, \n",
    "    train_mode=False,\n",
    "    num_windows=NUM_WIN,\n",
    "    window_length=WIN_LENGTH, \n",
    "    transform=testVal_transform,\n",
    "    key_frames=KEY_FRAME,\n",
    "    Fold_index=FOLD,\n",
    ")\n",
    "\n",
    "test_loader1 = DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    pin_memory=True,\n",
    "    num_workers=WORKERS,\n",
    "    persistent_workers=WORKERS>0\n",
    ")\n",
    "\n",
    "model = EfficientViT_GSViT(\"./EfficientViT/EfficientViT_GSViT.pth\")\n",
    "model = model.to(Device)\n",
    "model.eval()\n",
    "\n",
    "# Accumulators filled by the extraction loop in the next cell.\n",
    "features = []\n",
    "meta = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    for batch in test_loader1:\n",
    "        winds, _, _, info = batch\n",
    "\n",
    "        # Take the middle frame of each window, then flatten the\n",
    "        # (batch, windows) dimensions into a single batch dimension.\n",
    "        mid = winds[:, :, winds.size(2)//2]\n",
    "        mid = mid.reshape(-1, *mid.shape[2:])\n",
    "\n",
    "        # L2-normalize the embeddings so downstream similarity is cosine.\n",
    "        f = model(mid.to(Device))\n",
    "        f = torch.nn.functional.normalize(f, dim=1)\n",
    "\n",
    "        features.append(f.cpu())\n",
    "        meta.append(info)\n",
    "\n",
    "X = torch.cat(features, dim=0).numpy()\n",
    "\n",
    "# TODO(review): the original cell ended with a bare `np.save` reference (a\n",
    "# no-op attribute access). Persist the features explicitly once the target\n",
    "# path is decided, e.g. np.save(f\"features_fold{FOLD}.npy\", X)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 384])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check the first batch of embeddings (recorded output: [20, 384],\n",
    "# i.e. BATCH_SIZE * NUM_WIN rows of 384-d features).\n",
    "features[0].shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pyRARP",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}