{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "265c6583",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import csv\n",
"\n",
"# Source video folders, the output CSV, and the folder holding the query images.\n",
"root = Path(\"/mnt/Data/Urology/RARP/RARP movie 181-555/\")\n",
"csvFile = Path(\"dataset.csv\")\n",
"imgRoot = Path(\"DataSet/\")\n",
"\n",
"# Recreate the CSV from scratch so open(..., \"x\") below cannot fail.\n",
"if csvFile.exists():\n",
"    csvFile.unlink()\n",
"\n",
"with open(csvFile, \"x\", newline='') as file:\n",
"    writerCSV = csv.writer(file)\n",
"    writerCSV.writerow([\"case\", \"video_path\", \"image_path\", \"best_frame_seg\", \"error\", \"sim\"])\n",
"\n",
"    for r in sorted(root.iterdir()):\n",
"        # Folder names look like \"RARP<case>_...\"; keep only the numeric case id.\n",
"        caseId = r.name.split(\"_\")[0].replace(\"RARP\", \"\")\n",
"        # First matching query image for this case, or \"\" when none exists.\n",
"        matches = list(imgRoot.glob(f\"**/{caseId}.tiff\"))\n",
"        pathImg = matches[0].absolute() if matches else \"\"\n",
"        # best_frame_seg / error / sim start at 0 and are filled in later.\n",
"        writerCSV.writerow([caseId, r, pathImg, 0, 0, 0])\n",
"    # NOTE: the redundant file.close() was removed - the with-block closes the file."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e09b6bc",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# Split the master dataset into (nearly) equal chunks, one CSV per chunk.\n",
"csvFile = Path(\"dataset.csv\")\n",
"ds = pd.read_csv(csvFile)\n",
"\n",
"splits_sub_sets = 3  # number of sub-datasets to produce\n",
"splis = np.array_split(ds, splits_sub_sets)\n",
"\n",
"for i in range(splits_sub_sets):\n",
"    splis[i].to_csv(f\"sub_dataset_{i}.csv\", index=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27c540fc",
"metadata": {},
"outputs": [],
"source": [
"splis[2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d189105",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"np.random.seed(2023)\n",
"\n",
"angles = [-90, 0, 90, 180]\n",
"batch = 16\n",
"labels_angels = torch.from_numpy(LabelEncoder().fit_transform(angles))\n",
"len(labels_angels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e96b5cf",
"metadata": {},
"outputs": [],
"source": [
"angles[np.random.randint(len(angles))]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "629a6eb2",
"metadata": {},
"outputs": [],
"source": [
"torch.tensor([labels_angels[np.random.randint(len(labels_angels))] for _ in range(batch)])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87857835",
"metadata": {},
"outputs": [],
"source": [
"temp = torch.tensor([labels_angels[i % len(labels_angels)] for i in range(batch)])\n",
"temp"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad6bbac8",
"metadata": {},
"outputs": [],
"source": [
"for i in temp:\n",
" print (angles[i])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec193e0e",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pathlib import Path\n",
"import numpy as np\n",
"\n",
"file_csv = pd.read_csv(Path(\"dataset.csv\"))\n",
"copy_csv_file = file_csv.copy()\n",
"\n",
"for i, row in file_csv.iterrows():\n",
" if pd.isnull(row[\"image_path\"]):\n",
" print(f\"case {row['case']} skipped\")\n",
" continue\n",
" \n",
" for videos in sorted(list(Path(row[\"video_path\"]).glob(\"**/*.[mM][pP]4\"))):\n",
" print(videos, row[\"image_path\"])\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "281de82b",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"from pathlib import Path\n",
"import numpy as np\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def centerCrop(img:np, size:tuple = None, pct_size:float = 0.3):\n",
" if size is None:\n",
" size = int(img.shape[0] * (1 - pct_size)), int(img.shape[1] * (1 - pct_size)//1.5)\n",
" x, y = (img.shape[1] - size[1]) // 2, (img.shape[0] - size[0]) // 2\n",
" \n",
" return img[y:y+size[0], x:x+size[1]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "312a6fb9",
"metadata": {},
"outputs": [],
"source": [
"img = cv2.imread(str(Path(\"./Dataset_video/555/frame_0021014.webp\")))\n",
"\n",
"frame = centerCrop(img, (300, 420))\n",
"print (img.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0198fb4",
"metadata": {},
"outputs": [],
"source": [
"\n",
"fig, ax = plt.subplots(1, 3, figsize=(15, 8))\n",
"\n",
"ax[0].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))\n",
"ax[0].set_title(f\"Original Image \")\n",
"ax[0].axis(\"off\")\n",
"\n",
"ax[1].imshow(cv2.cvtColor(cv2.resize(frame, (300, 300), interpolation=cv2.INTER_CUBIC), cv2.COLOR_BGR2RGB))\n",
"ax[1].set_title(f\"Center crop (300 x 420) and resize (300 x 300)\")\n",
"ax[1].axis(\"off\")\n",
"\n",
"ax[2].imshow(cv2.cvtColor(centerCrop(img, (300, 300)), cv2.COLOR_BGR2RGB))\n",
"ax[2].set_title(f\"Center crop (300 x 300)\")\n",
"ax[2].axis(\"off\")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75ab41a8",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import torchvision\n",
"from torchvision import transforms\n",
"import numpy as np\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"import Loaders\n",
"import Models_from_server\n",
"import defs\n",
"import van\n",
"\n",
"def ViewImgDINO(dataset, std, mean):\n",
" _, axis = plt.subplots(4, 9, figsize=(25, 25))\n",
" for i in range(4):\n",
" random_index = np.random.randint(0, len(dataset.targets))\n",
" imgCrops, label = dataset[random_index]\n",
" for j, img in enumerate(imgCrops):\n",
" img = img.numpy().transpose((1, 2, 0))\n",
" img = np.clip((std * img + mean), 0, 1)\n",
" \n",
" axis[i][j].imshow(img)\n",
" axis[i][j].set_title(f\"Label:{label}\")\n",
" axis[i][j].axis(\"off\")\n",
" \n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"def setup_seed(seed):\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed_all(seed)\n",
" np.random.seed(seed)\n",
" torch.backends.cudnn.deterministic = True\n",
" \n",
"setup_seed(2023)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"batchSize = 16 #17 #8, 32\n",
"\n",
"Mean = [0.485, 0.456, 0.406]\n",
"Std = [0.229, 0.224, 0.225]\n",
"\n",
"angles = [-90, 0, 90, 180]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e8d9465",
"metadata": {},
"outputs": [],
"source": [
"default_transform = torch.nn.Sequential(\n",
" transforms.CenterCrop(300),\n",
" transforms.Normalize(Mean, Std)\n",
").to(device)\n",
"\n",
"TrainDINOTransforms = Loaders.RARP_DINO_Augmentation(\n",
" Init_Resize=(300, 420),\n",
" GloblaCropsScale = (0.4, 1),\n",
" LocalCropsScale = (0.05, 0.4),\n",
" NumLocalCrops = 6,\n",
" Size = 224, \n",
" device = device,\n",
" mean = Mean,\n",
" std = Std,\n",
" Tranform_0 = default_transform\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e317ee06",
"metadata": {},
"outputs": [],
"source": [
"TrainDINOTransforms = Loaders.RARP_MAE_Augmentation(\n",
" Init_Resize=(300, 420),\n",
" GloblaCropsScale = (0.4, 1),\n",
" Size = 224, \n",
" device = device,\n",
" mean = Mean,\n",
" std = Std\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e54a03cd",
"metadata": {},
"outputs": [],
"source": [
"trainDataset = torchvision.datasets.DatasetFolder(\n",
" \"./Dataset_video/\",\n",
" loader=defs.load_Img_RBG_tensor_norm,\n",
" extensions=\"webp\",\n",
" transform=TrainDINOTransforms\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90087946",
"metadata": {},
"outputs": [],
"source": [
"def mask_patches(imgs, patch_size=16, mask_ratio=0.25):\n",
" imgs = imgs.repeat(1, 1, 1, 1)\n",
" B, C, H, W = imgs.shape\n",
" ph, pw = patch_size, patch_size\n",
" gh, gw = H // ph, W // pw\n",
" \n",
" mask:torch.Tensor = (torch.rand(B, gh * gw, device=imgs.device) >= mask_ratio)\n",
" mask = mask.reshape(B, 1, gh, gw) # [B,1,gh,gw]\n",
" \n",
" # expand mask to full image\n",
" mask = mask.repeat_interleave(ph, dim=2) # [B,1,gh*ph,gw]\n",
" mask = mask.repeat_interleave(pw, dim=3) # [B,1,gh*ph,gw*pw] == [B,1,H,W]\n",
" \n",
" mask = mask.expand(-1, C, -1, -1)\n",
" \n",
" imgs_masked = imgs * mask\n",
" return imgs_masked"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8eb2fb3",
"metadata": {},
"outputs": [],
"source": [
"def ViewImgDINO_axu(dataset, std, mean):\n",
"    \"\"\"Show 4 random samples: every DINO crop plus one extra rotated global crop.\"\"\"\n",
"    _, axis = plt.subplots(4, 10, figsize=(25, 15))\n",
"    for row in range(4):\n",
"        sample_idx = np.random.randint(0, len(dataset.targets))\n",
"        imgCrops, label = dataset[sample_idx]\n",
"\n",
"        # Columns 0..len(crops)-1: the augmented crops, denormalized for display.\n",
"        for col, crop in enumerate(imgCrops):\n",
"            shown = crop.numpy().transpose((1, 2, 0))\n",
"            shown = np.clip(std * shown + mean, 0, 1)\n",
"\n",
"            axis[row][col].imshow(shown)\n",
"            axis[row][col].set_title(f\"Label:{label}\")\n",
"            axis[row][col].axis(\"off\")\n",
"\n",
"        # Last column: the first global crop rotated by a random auxiliary angle.\n",
"        nextAngle = angles[np.random.randint(len(angles))]\n",
"        rotated = torchvision.transforms.functional.rotate(imgCrops[0], nextAngle)\n",
"\n",
"        rotated = rotated.numpy().transpose((1, 2, 0))\n",
"        rotated = np.clip(std * rotated + mean, 0, 1)\n",
"\n",
"        axis[row][9].imshow(rotated)\n",
"        axis[row][9].set_title(f\"Label:{nextAngle}\")\n",
"        axis[row][9].axis(\"off\")\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cb095f1",
"metadata": {},
"outputs": [],
"source": [
"setup_seed(2023)\n",
"_, axis = plt.subplots(4, 2, figsize=(25, 25))\n",
"for i in range(4):\n",
" random_index = np.random.randint(0, len(trainDataset.targets))\n",
" imgCrops, label = trainDataset[random_index]\n",
" for j, img in enumerate(imgCrops):\n",
" img = mask_patches(img)[0] if j == 1 else img\n",
" img = img.numpy().transpose((1, 2, 0))\n",
" img = np.clip((Std * img + Mean), 0, 1)\n",
" \n",
" axis[i][j].imshow(img)\n",
" axis[i][j].set_title(f\"Label:{label}\")\n",
" axis[i][j].axis(\"off\")\n",
" \n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7830e80d",
"metadata": {},
"outputs": [],
"source": [
"ViewImgDINO_axu(trainDataset, Std, Mean)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6b4e579",
"metadata": {},
"outputs": [],
"source": [
"model = van.van_b1(pretrained=False, num_classes = 0)\n",
"model.load_state_dict(torch.load(\"./van_b1_student.pth\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18b3ac57",
"metadata": {},
"outputs": [],
"source": [
"img = defs.load_Img_RBG_tensor_norm(\"./DataSet_Ando_All_no20/NVB/371.tiff\")\n",
"\n",
"dtransform = torch.nn.Sequential(\n",
" transforms.Resize(360, interpolation=transforms.InterpolationMode.BICUBIC),\n",
" transforms.CenterCrop(300),\n",
" transforms.Normalize(Mean, Std)\n",
").to(device)\n",
"\n",
"print(img.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e234c0a2",
"metadata": {},
"outputs": [],
"source": [
"img = dtransform(img)\n",
"print (img.shape)\n",
"\n",
"_, axis = plt.subplots(1, 1, figsize=(25, 25))\n",
"\n",
"img = img.numpy().transpose((1, 2, 0))\n",
"img = np.clip((Std * img + Mean), 0, 1)\n",
"\n",
"axis.imshow(img)\n",
"axis.axis(\"off\")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "044c7214",
"metadata": {},
"outputs": [],
"source": [
"img = dtransform(img)\n",
"img = img.repeat(1, 1, 1, 1)\n",
"print(img.shape)\n",
"\n",
"img = img.to(device)\n",
"model = model.to(device)\n",
"model.eval()\n",
"\n",
"with torch.no_grad():\n",
" res = model(img)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6db462f",
"metadata": {},
"outputs": [],
"source": [
"res.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0289bc82",
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"model = SentenceTransformer(\"Vinit3241/clinical_trials_all-MiniLM-L6-v2\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d911372",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd1cb5e1",
"metadata": {},
"outputs": [],
"source": [
"import torchvision\n",
"import torch\n",
"\n",
"model = torchvision.models.convnext_small(weights=torchvision.models.ConvNeXt_Small_Weights.DEFAULT)\n",
"model.classifier[-1] = torch.nn.Identity()\n",
"\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb2cc158",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import Models_from_server\n",
"\n",
"#model = Models.RARP_Encoder_DINO.load_from_checkpoint(\"/home/diego/research/log_SSL_X1_DINO/lightning_logs/version_5/checkpoints/DINO_S-epoch=90-val_silhouette_student=0.5181.ckpt\")\n",
"model = Models_from_server.RARP_Encoder_DINO_AUX_task.load_from_checkpoint(\"/home/diego/research/log_SSL_X1_DINO/lightning_logs/version_10/checkpoints/DINO_S-epoch=54-val_acc=0.9999.ckpt\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54a11749",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.student.backbone.state_dict(), \"convnext_s_student_aux_54.pth\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01310d47",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.teacher.backbone.state_dict(), \"convnext_s_teacher_aux_54.pth\")"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "35239c6a",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import cv2\n",
"from pathlib import Path\n",
"import math\n",
"import numpy as np\n",
"import csv\n",
"\n",
"csvFile = Path(\"frames_from_videos.csv\")\n",
"\n",
"# Start from a clean output file.\n",
"if csvFile.exists():\n",
"    csvFile.unlink()\n",
"\n",
"def read_frame_by_seg(pth_video: str, timestamp: float):\n",
"    \"\"\"Grab the frame at `timestamp` seconds of a video, resized to 640x360.\n",
"\n",
"    Raises:\n",
"        ValueError: when no frame can be decoded at that position.\n",
"    \"\"\"\n",
"    cap = cv2.VideoCapture(pth_video)\n",
"    try:\n",
"        # Seek by time (ms) rather than frame index: robust to variable FPS.\n",
"        cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)\n",
"        ok, frame = cap.read()\n",
"        if not ok or frame is None:\n",
"            # Fix: previously crashed inside cv2.resize with frame=None.\n",
"            raise ValueError(f\"could not read frame at {timestamp}s from {pth_video}\")\n",
"        return cv2.resize(frame, (640, 360), interpolation=cv2.INTER_CUBIC)\n",
"    finally:\n",
"        cap.release()  # fix: the decoder handle used to leak\n",
"\n",
"db_file = pd.read_csv(\"dataset_RARP_videos_folds.csv\")\n",
"\n",
"root = Path(\"./frames_from_videos_Merge/\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0dbdb88",
"metadata": {},
"outputs": [],
"source": [
"def rewrite_file(path: str):\n",
"    \"\"\"Delete `path` if it exists and return it as a Path, ready for mode \\\"x\\\".\"\"\"\n",
"    f_path = Path(path)\n",
"    if f_path.exists():\n",
"        f_path.unlink()\n",
"    return f_path\n",
"\n",
"output_root = Path(\"./script_files/\")\n",
"output_root.mkdir(exist_ok=True)\n",
"\n",
"csvFile = Path(\"./script_files/videos_files_merge.csv\")\n",
"\n",
"if csvFile.exists():\n",
"    csvFile.unlink()\n",
"\n",
"with open(csvFile, \"x\", newline='') as file:\n",
"    writerCSV = csv.writer(file)\n",
"    writerCSV.writerow([\"case\", \"merge_file\", \"num_videos\"])\n",
"\n",
"    prev_case = 0\n",
"    c = 0  # NOTE(review): counts videos beyond the first one of a case - confirm intended\n",
"    f_path = rewrite_file(output_root / f\"case_{prev_case}.txt\")\n",
"    f = open(f_path, \"x\")\n",
"    for i, row in db_file.iterrows():\n",
"        try:\n",
"            if prev_case != row['case']:\n",
"                # Case changed: flush the finished case and start a new list file.\n",
"                f.close()\n",
"                writerCSV.writerow([prev_case, f\"case_{prev_case}.txt\", c])\n",
"                prev_case = row['case']\n",
"                c = 0\n",
"\n",
"                f_path = rewrite_file(output_root / f\"case_{prev_case}.txt\")\n",
"                f = open(f_path, \"x\")\n",
"                f.write(f\"file '{row['video_path']}'\\n\")\n",
"            else:\n",
"                c += 1\n",
"                f.write(f\"file '{row['video_path']}'\\n\")\n",
"\n",
"        except Exception as e:\n",
"            print(f\"row id{i}, case:{row['case']}\", e)\n",
"\n",
"    # Fix: the loop only flushes a case when the NEXT one starts, so the\n",
"    # final case must be closed and written out explicitly.\n",
"    f.close()\n",
"    writerCSV.writerow([prev_case, f\"case_{prev_case}.txt\", c])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8e9d683a",
"metadata": {},
"outputs": [],
"source": [
"prev_case = 0\n",
"c = 0\n",
"for i, row in db_file.iterrows():\n",
" try:\n",
" if prev_case != row['case']:\n",
" prev_case = row['case']\n",
" c = 0\n",
" case_folder:Path = root/str(row['case'])\n",
" case_folder.mkdir(exist_ok=True)\n",
" \n",
" img = cv2.imread(row['image_path'], cv2.IMREAD_COLOR)\n",
" img = cv2.resize(img, (640, 360), interpolation=cv2.INTER_CUBIC)\n",
" \n",
" frame = read_frame_by_seg(row['video_path'], row['best_frame_seg'])\n",
" \n",
" cv2.imwrite(str(case_folder/f\"source_img_{row['case']}.webp\"), img, [cv2.IMWRITE_WEBP_QUALITY, 101])\n",
" cv2.imwrite(str(case_folder/f\"frame_img_{row['case']}_{c}.webp\"), frame, [cv2.IMWRITE_WEBP_QUALITY, 101])\n",
" else:\n",
" c += 1\n",
" print(f\"case: {row['case']} counter: {c}\")\n",
" frame = read_frame_by_seg(row['video_path'], row['best_frame_seg'])\n",
" \n",
" cv2.imwrite(str(case_folder/f\"frame_img_{row['case']}_{c}.webp\"), frame, [cv2.IMWRITE_WEBP_QUALITY, 101])\n",
" except Exception as e:\n",
" print (f\"row id{i}, case:{row['case']}\",e) \n",
" \n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "743bfe0b",
"metadata": {},
"outputs": [],
"source": [
"def make_grid(tiles, cols=5, bg_color=(0, 0, 0)):\n",
" \"\"\"\n",
" Stack tiles into a single grid.\n",
" cols fixed number of columns (rows computed automatically)\n",
" bg_color RGB tuple for padding when #tiles isn't a multiple of cols\n",
" \"\"\"\n",
" n = len(tiles)\n",
" rows = math.ceil(n / cols)\n",
" h, w, c = tiles[0].shape\n",
"\n",
" # pad with blank tiles so we have rows*cols exactly\n",
" pad = rows * cols - n\n",
" if pad:\n",
" blank = np.full((h, w, c), bg_color, dtype=np.uint8)\n",
" tiles.extend([blank] * pad)\n",
"\n",
" # build the grid row-by-row\n",
" row_imgs = [\n",
" np.hstack(tiles[r * cols:(r + 1) * cols]) # horizontal stack\n",
" for r in range(rows)\n",
" ]\n",
" return np.vstack(row_imgs)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2d85d0f6",
"metadata": {},
"outputs": [],
"source": [
"def resize_image_aspect_ratio(image, target_width=None, target_height=None, interpolation=cv2.INTER_CUBIC):\n",
"    \"\"\"\n",
"    Resizes an image while maintaining its aspect ratio.\n",
"\n",
"    Args:\n",
"        image (numpy.ndarray): The input image.\n",
"        target_width (int, optional): The desired target width. If None, target_height must be provided.\n",
"        target_height (int, optional): The desired target height. If None, target_width must be provided.\n",
"        interpolation (int, optional): The interpolation method to use. Defaults to cv2.INTER_CUBIC.\n",
"\n",
"    Returns:\n",
"        numpy.ndarray: The resized image (the original when no target is given).\n",
"    \"\"\"\n",
"    h, w = image.shape[:2]\n",
"\n",
"    if target_width is None and target_height is None:\n",
"        return image  # No target dimensions provided, return original image\n",
"\n",
"    if target_width is not None:\n",
"        # Scale relative to the width; height follows the aspect ratio.\n",
"        scale = target_width / float(w)\n",
"        dim = (target_width, int(h * scale))\n",
"    else:  # target_height is not None\n",
"        scale = target_height / float(h)\n",
"        dim = (int(w * scale), target_height)\n",
"\n",
"    return cv2.resize(image, dim, interpolation=interpolation)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a5dd54ba",
"metadata": {},
"outputs": [],
"source": [
"if csvFile.exists():\n",
"    csvFile.unlink()\n",
"\n",
"with open(csvFile, \"x\", newline='') as file:\n",
"    writerCSV = csv.writer(file)\n",
"    writerCSV.writerow([\"case\", \"image_path\", \"sim\"])\n",
"\n",
"    prev_case = 0\n",
"    for i, row in db_file.iterrows():\n",
"        caseId = str(row['case'])\n",
"        if prev_case == caseId:\n",
"            continue  # rows of a case are consecutive; handle each case once\n",
"\n",
"        prev_case = caseId\n",
"        case_folder = root / caseId\n",
"        source = cv2.imread(str(case_folder.absolute() / f\"source_img_{caseId}.webp\"), cv2.IMREAD_COLOR)\n",
"\n",
"        listFramesPth = []\n",
"        listFrames = []\n",
"        for img_pth in sorted(case_folder.glob(\"frame_*[0-9].webp\")):\n",
"            listFramesPth.append(str(img_pth.absolute()))\n",
"            listFrames.append(cv2.imread(str(img_pth.absolute()), cv2.IMREAD_COLOR))\n",
"\n",
"        lenList = len(listFrames)\n",
"        if lenList == 0:\n",
"            # Fix: previously fell through with grid_cols unbound and\n",
"            # np.argmax crashing on an empty similarity list.\n",
"            print(f\"case {caseId}: no extracted frames, skipped\")\n",
"            continue\n",
"        if lenList <= 2:\n",
"            grid_cols = 1\n",
"        elif lenList <= 6:\n",
"            grid_cols = 2\n",
"        else:\n",
"            grid_cols = 4\n",
"\n",
"        # Normalized cross-correlation of every frame against the source image\n",
"        # (frames and source share the same size, so the result map is 1x1).\n",
"        simFrames = [cv2.matchTemplate(source, img, cv2.TM_CCORR_NORMED)[0][0] for img in listFrames]\n",
"\n",
"        best = np.argmax(simFrames)\n",
"        writerCSV.writerow([caseId, listFramesPth[best], simFrames[best]])\n",
"\n",
"        grid = make_grid(listFrames, grid_cols)\n",
"        # Fix: cv2.imwrite expects a str filename, not a Path.\n",
"        cv2.imwrite(str(case_folder / f\"grid_frames_{caseId}.webp\"), grid, [cv2.IMWRITE_WEBP_QUALITY, 101])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1df0c5cb",
"metadata": {},
"outputs": [],
"source": [
"for i, row in db_file.iterrows():\n",
" caseId = str(row['case'])\n",
" case_folder = root/caseId\n",
" source_img = cv2.imread(str(case_folder/f\"source_img_{caseId}.webp\"), cv2.IMREAD_COLOR)\n",
" grid_img = cv2.imread(str(case_folder/f\"grid_frames_{caseId}.webp\"), cv2.IMREAD_COLOR)\n",
" \n",
" try: \n",
" comp_grid = np.hstack([source_img, grid_img])\n",
" except: \n",
" source_img = resize_image_aspect_ratio(source_img, target_height=grid_img.shape[0])\n",
" comp_grid = np.hstack([source_img, grid_img])\n",
" \n",
" \n",
" cv2.imwrite(case_folder/f\"final_grid_{caseId}.webp\", comp_grid, [cv2.IMWRITE_WEBP_QUALITY, 101])\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4c64d81",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"from pathlib import Path\n",
"\n",
"root = Path(\"./Dataset_RARP_video/\")\n",
"\n",
"for video_path in sorted(root.glob(\"**/*.mp4\")):\n",
" print (video_path)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5171e913",
"metadata": {},
"outputs": [],
"source": [
"import ffmpeg\n",
"\n",
"wind_frame = 15 * 60 #15min in seg\n",
"\n",
"for i, row in db_file.iterrows():\n",
" caseId = str(row['case'])\n",
" case_folder = root/caseId\n",
" source_video = row['video_path']\n",
" b = row['best_frame_seg'] - 300 # 5 min prev the query frame\n",
" a = b - wind_frame\n",
" output_file_str = str((case_folder/f\"clip_{caseId}.mp4\").absolute())\n",
" \n",
" try:\n",
" (\n",
" ffmpeg\n",
" .input(source_video)\n",
" .trim(start=a, end=b)\n",
" .setpts('PTS-STARTPTS') \n",
" .output(output_file_str)\n",
" .run(overwrite_output=True)\n",
" )\n",
" print (f\"Video successfully croped A:{a} to B:{b}\")\n",
" except ffmpeg.Error as e:\n",
" print (f\"Error: {e.stderr.decode()}\")\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "827dc53a",
"metadata": {},
"outputs": [],
"source": [
"# 20 seg text videos @ 30 fps\n",
"\n",
"import shlex\n",
"from pathlib import Path\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"\n",
"\n",
"db_file = pd.read_csv(\"o_dataset_extra.csv\")\n",
"wind_frame = 20 # 15 minutes in seconds\n",
"\n",
"# --- config ---\n",
"out_root = Path(\"frames_from_videos_clips_out\") # output folder root\n",
"script_copy = Path(\"run_clips_copy.sh\") # time-trim only, no re-encode\n",
"script_1fps = Path(\"run_clips_30fps.sh\") # time-trim + 30 FPS H.264\n",
"\n",
"out_root.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def fmt_ts(sec: float) -> str:\n",
" \"\"\"Format seconds -> HH:MM:SS.mmm (ffmpeg-friendly).\"\"\"\n",
" if sec < 0: sec = 0.0\n",
" h = int(sec // 3600)\n",
" m = int((sec % 3600) // 60)\n",
" s = sec % 60\n",
" return f\"{h:02d}:{m:02d}:{s:06.3f}\"\n",
"\n",
"def q(path: str) -> str:\n",
" \"\"\"Shell-escape safely for bash.\"\"\"\n",
" return shlex.quote(str(path))\n",
"\n",
"copy_lines = []\n",
"fps1_lines = []\n",
"\n",
"for _, row in db_file.iterrows():\n",
" \n",
" case_id = str(row[\"case\"])\n",
" src = Path(row[\"video_path\"])\n",
" case_dir = out_root / case_id\n",
" case_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
" # Your A/B logic\n",
" key_frame = float(row[\"best_frame_seg\"])\n",
" a = max(0.0, key_frame - wind_frame / 2)\n",
" b = key_frame + wind_frame / 2\n",
" out_copy = case_dir / f\"clip_{case_id}.mp4\" # time-trim only\n",
" out_1fps = case_dir / f\"clip_{case_id}_30fps.mp4\" # 1-FPS archive\n",
"\n",
" # Variant 1: fastest, no re-encode (keeps original codec; only trims time)\n",
" # Note: using input-side -ss for speed, -to for end time; good enough for most cases.\n",
" cmd_copy = (\n",
" f\"ffmpeg -hide_banner -loglevel error \"\n",
" f\"-ss {fmt_ts(a)} -to {fmt_ts(b)} -i {q(src)} \"\n",
" f\"-c copy {q(out_copy)}\"\n",
" )\n",
" copy_lines.append(cmd_copy)\n",
"\n",
" # Variant 2: archive at 1 FPS (time-trim + re-encode to H.264 near-lossless)\n",
" # Use libx264 + CRF 18; tweak preset as needed (faster/slower).\n",
" cmd_1fps = (\n",
" f\"ffmpeg -hide_banner -loglevel error \"\n",
" f\"-ss {fmt_ts(a)} -to {fmt_ts(b)} -i {q(src)} \"\n",
" f\"-vf fps=30 \"\n",
" f\"-an -c:v libx264 -crf 18 -preset slow -pix_fmt yuv420p -movflags +faststart \"\n",
" f\"{q(out_1fps)}\"\n",
" )\n",
" fps1_lines.append(cmd_1fps)\n",
"\n",
"# Write scripts\n",
"script_copy.write_text(\n",
" \"#!/usr/bin/env bash\\nset -euo pipefail\\n\\n\" + \"\\n\".join(copy_lines) + \"\\n\",\n",
" encoding=\"utf-8\"\n",
")\n",
"script_1fps.write_text(\n",
" \"#!/usr/bin/env bash\\nset -euo pipefail\\n\\n\" + \"\\n\".join(fps1_lines) + \"\\n\",\n",
" encoding=\"utf-8\"\n",
")\n",
"\n",
"print(f\"Generated:\\n {script_copy}\\n {script_1fps}\\n\")\n",
"print(\"Make executable: chmod +x run_clips_copy.sh run_clips_1fps.sh\")\n",
"print(\"Run sequentially: ./run_clips_copy.sh\")\n",
"print(\"Run in parallel: cat run_clips_copy.sh | parallel -j 6\") # needs GNU parallel"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "6d2c86e2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generated:\n",
" run_clips_copy.sh\n",
" run_clips_1fps.sh\n",
"\n",
"Make executable: chmod +x run_clips_copy.sh run_clips_1fps.sh\n",
"Run sequentially: ./run_clips_copy.sh\n",
"Run in parallel: cat run_clips_copy.sh | parallel -j 6\n"
]
}
],
"source": [
"import shlex\n",
"from pathlib import Path\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"\n",
"\n",
"db_file = pd.read_csv(\"dataset_RARP_videos_folds.csv\")\n",
"\n",
"\n",
"wind_frame = 20 * 60 # 15 minutes in seconds\n",
"\n",
"# --- config ---\n",
"out_root = Path(\"frames_from_videos_clips_out\") # output folder root\n",
"script_copy = Path(\"run_clips_copy.sh\") # time-trim only, no re-encode\n",
"script_1fps = Path(\"run_clips_1fps.sh\") # time-trim + 1 FPS H.264\n",
"\n",
"out_root.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def fmt_ts(sec: float) -> str:\n",
" \"\"\"Format seconds -> HH:MM:SS.mmm (ffmpeg-friendly).\"\"\"\n",
" if sec < 0: sec = 0.0\n",
" h = int(sec // 3600)\n",
" m = int((sec % 3600) // 60)\n",
" s = sec % 60\n",
" return f\"{h:02d}:{m:02d}:{s:06.3f}\"\n",
"\n",
"def q(path: str) -> str:\n",
" \"\"\"Shell-escape safely for bash.\"\"\"\n",
" return shlex.quote(str(path))\n",
"\n",
"copy_lines = []\n",
"fps1_lines = []\n",
"\n",
"for _, row in db_file.iterrows():\n",
" case_id = str(row[\"case\"])\n",
" src = Path(row[\"video_path\"])\n",
" case_dir = out_root / case_id\n",
" case_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
" # Your A/B logic\n",
" b = float(row[\"best_frame_seg\"]) - 300.0 # 5 min before the query frame\n",
" a = max(0.0, b - wind_frame) # 15-min window ending at b\n",
" out_copy = case_dir / f\"clip_{case_id}.mp4\" # time-trim only\n",
" out_1fps = case_dir / f\"clip_{case_id}_1fps.mp4\" # 1-FPS archive\n",
"\n",
" # Variant 1: fastest, no re-encode (keeps original codec; only trims time)\n",
" # Note: using input-side -ss for speed, -to for end time; good enough for most cases.\n",
" cmd_copy = (\n",
" f\"ffmpeg -hide_banner \"\n",
" f\"-ss {fmt_ts(a)} -to {fmt_ts(b)} -i {q(src)} \"\n",
" f\"-c copy {q(out_copy)}\"\n",
" )\n",
" copy_lines.append(cmd_copy)\n",
"\n",
" # Variant 2: archive at 1 FPS (time-trim + re-encode to H.264 near-lossless)\n",
" # Use libx264 + CRF 18; tweak preset as needed (faster/slower).\n",
" cmd_1fps = (\n",
" f\"ffmpeg -hide_banner \"\n",
" f\"-ss {fmt_ts(a)} -to {fmt_ts(b)} -i {q(src)} \"\n",
" f\"-vf \\\"fps=1,scale=640:360:force_original_aspect_ratio=decrease,pad=640:360:(ow-iw)/2:(oh-ih)/2\\\" \"\n",
" f\"-an -c:v libx264 -crf 18 -preset slow -pix_fmt yuv420p -movflags +faststart \"\n",
" f\"{q(out_1fps)}\"\n",
" )\n",
" fps1_lines.append(cmd_1fps)\n",
"\n",
"# Write scripts\n",
"script_copy.write_text(\n",
" \"#!/usr/bin/env bash\\nset -euo pipefail\\n\\n\" + \"\\n\".join(copy_lines) + \"\\n\",\n",
" encoding=\"utf-8\"\n",
")\n",
"script_1fps.write_text(\n",
" \"#!/usr/bin/env bash\\nset -euo pipefail\\n\\n\" + \"\\n\".join(fps1_lines) + \"\\n\",\n",
" encoding=\"utf-8\"\n",
")\n",
"\n",
"print(f\"Generated:\\n {script_copy}\\n {script_1fps}\\n\")\n",
"print(\"Make executable: chmod +x run_clips_copy.sh run_clips_1fps.sh\")\n",
"print(\"Run sequentially: ./run_clips_copy.sh\")\n",
"print(\"Run in parallel: cat run_clips_copy.sh | parallel -j 6\") # needs GNU parallel"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d0c3501",
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"from pathlib import Path\n",
"import pandas as pd\n",
"\n",
"merge_file_dir = Path(\"./script_files/files\")\n",
"merge_file_dir.mkdir(exist_ok=True)\n",
"\n",
"db_file_merge = pd.read_csv(\"./script_files/videos_files_merge.csv\")\n",
"\n",
"for i, row in db_file_merge.iterrows():\n",
" if int(row[\"num_videos\"]) > 0:\n",
" shutil.move(merge_file_dir.parent/row[\"merge_file\"], merge_file_dir/f\"case_{row['case']}.txt\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc14e4c6",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import pandas as pd\n",
"\n",
"root = Path(\"./Dataset_RARP_video/\")\n",
"dataset = pd.read_csv(\"./Dataset_RARP_video/dataset_videos_folds.csv\")\n",
"videos_list = []\n",
"for n in sorted(root.glob(\"**/*.mp4\")):\n",
" videos_list.append(n.parent.name)\n",
"\n",
"dataset_list = [str(n) for n in dataset[\"case\"].to_list()]\n",
"\n",
"missing = list (set(dataset_list) - set(videos_list))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3bdca274",
"metadata": {},
"outputs": [],
"source": [
"missing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f623018",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"\n",
"def seconds_to_hms(seconds):\n",
" hours = seconds // 3600\n",
" minutes = (seconds % 3600) // 60\n",
" secs = seconds % 60\n",
" return f'{int(hours)}:{int(minutes):02}:{int(secs):02}'\n",
"\n",
"def find_ROI_rec(base_img: np.ndarray, roi_img: np.ndarray, fx_roi: float = None, fy_roi: float = None):\n",
"    \"\"\"Locate `roi_img` inside `base_img` via normalized cross-correlation.\n",
"\n",
"    Args:\n",
"        base_img: BGR image to search in.\n",
"        roi_img: BGR template; optionally rescaled by (fx_roi, fy_roi) first.\n",
"        fx_roi, fy_roi: optional resize factors applied to the template.\n",
"\n",
"    Returns:\n",
"        (top_left, bottom_right, res, max_val): best-match rectangle corners\n",
"        as (x, y) points, the raw correlation map, and its peak value.\n",
"    \"\"\"\n",
"    if fx_roi is not None and fy_roi is not None:\n",
"        roi_img = cv2.resize(roi_img, None, fx=fx_roi, fy=fy_roi, interpolation=cv2.INTER_CUBIC)\n",
"\n",
"    # Fix: numpy shape is (rows=H, cols=W, C); the old `w, h, _ = shape`\n",
"    # swapped them, skewing the rectangle for non-square templates.\n",
"    h, w = roi_img.shape[:2]\n",
"\n",
"    assert h < base_img.shape[0] and w < base_img.shape[1], \"base image is smaller than the template\"\n",
"\n",
"    gray_base = cv2.cvtColor(base_img, cv2.COLOR_BGR2GRAY)\n",
"    gray_roi = cv2.cvtColor(roi_img, cv2.COLOR_BGR2GRAY)\n",
"\n",
"    res = cv2.matchTemplate(gray_base, gray_roi, cv2.TM_CCORR_NORMED)\n",
"\n",
"    # minMaxLoc -> (min_val, max_val, min_loc, max_loc); keep the maximum.\n",
"    _, max_val, _, top_left = cv2.minMaxLoc(res)\n",
"\n",
"    bottom_right = (top_left[0] + w, top_left[1] + h)\n",
"\n",
"    return top_left, bottom_right, res, max_val\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caa1d1c4",
"metadata": {},
"outputs": [],
"source": [
"\n",
"pth_str_base = \"./DataSet/NVB/423.tiff\"\n",
"pth_str_ROI = \"./DataSet_crop/NVB/423.tiff\"\n",
"\n",
"base_img = cv2.imread(pth_str_base, cv2.IMREAD_COLOR)\n",
"roi_img = cv2.imread(pth_str_ROI, cv2.IMREAD_COLOR)\n",
"\n",
"base_img = cv2.resize(base_img, (1920, 1080), interpolation=cv2.INTER_CUBIC)\n",
"\n",
"top_left, bottom_right, res, _ = find_ROI_rec(base_img, roi_img, fx_roi=0.6, fy_roi=0.6)\n",
"#top_left, bottom_right, res, _ = find_ROI_rec(base_img, roi_img)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41e10527",
"metadata": {},
"outputs": [],
"source": [
"# Keep an unannotated copy before the match rectangle is drawn onto base_img.\n",
"crop_frame = base_img.copy()\n",
"cv2.rectangle(base_img, top_left, bottom_right, 255, 3)\n",
"\n",
"fig, ax = plt.subplots(4, 1, figsize=(25, 15))\n",
"\n",
"ax[0].imshow(cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB))\n",
"ax[0].set_title(\"Result\")\n",
"ax[0].axis(\"off\")\n",
"\n",
"ax[1].imshow(cv2.cvtColor(roi_img, cv2.COLOR_BGR2RGB))\n",
"ax[1].set_title(\"ROI\")\n",
"ax[1].axis(\"off\")\n",
"\n",
"ax[2].imshow(res)\n",
"ax[2].set_title(\"Cross Correlation Map\")\n",
"ax[2].axis(\"off\")\n",
"\n",
"ax[3].imshow(cv2.cvtColor(crop_frame, cv2.COLOR_BGR2RGB))\n",
"# Title was a copy-pasted \"Cross Correlation Map\"; this panel is the clean frame.\n",
"ax[3].set_title(\"Original Frame\")\n",
"ax[3].axis(\"off\")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "759f2ad6",
"metadata": {},
"outputs": [],
"source": [
"case_num = 347\n",
"label_class = \"NOT_NVB\"\n",
"\n",
"pth_str_base = f\"./DataSet/{label_class}/{case_num}.tiff\"\n",
"pth_str_ROI = f\"./DataSet_crop/{label_class}/{case_num}.tiff\"\n",
"\n",
"base_img = cv2.imread(pth_str_base, cv2.IMREAD_GRAYSCALE)\n",
"roi_img = cv2.imread(pth_str_ROI, cv2.IMREAD_COLOR)\n",
"\n",
"# Reference-frame size, used to scale the ROI template up to the video resolution.\n",
"H, W = base_img.shape\n",
"\n",
"root = Path(f\"./Dataset_RARP_video/{case_num}\")\n",
"cap = cv2.VideoCapture(str(root / f\"clip_{case_num}_30fps.mp4\"))\n",
"total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
"video_fps = cap.get(cv2.CAP_PROP_FPS)\n",
"print(f\" Video FPS: {video_fps:.2f}, Total frames: {total_frames}, Video length: {seconds_to_hms(total_frames / video_fps)}\")\n",
"\n",
"# Read the first frame only to learn the video resolution, then rewind so the\n",
"# processing loop below sees every frame.\n",
"_, frame_0 = cap.read()\n",
"cap.set(cv2.CAP_PROP_POS_MSEC, 0)\n",
"\n",
"frame_H, frame_W = frame_0.shape[:-1]\n",
"# Per-axis scale factors from the reference image to the video frames.\n",
"fx, fy = frame_W / W, frame_H / H\n",
"\n",
"out_video_pth = str(root / f\"clip_{case_num}_30fps.avi\")\n",
"codec = cv2.VideoWriter_fourcc(*'XVID')\n",
"out_video = cv2.VideoWriter(out_video_pth, codec, video_fps, (frame_W, frame_H))\n",
"\n",
"font = cv2.FONT_HERSHEY_SIMPLEX\n",
"font_scale = 2\n",
"color = (0, 255, 0) # Green color\n",
"thickness = 2\n",
"\n",
"while cap.isOpened():\n",
"    ret, frame = cap.read()\n",
"    if not ret:\n",
"        break\n",
"\n",
"    # Fix: the vertical factor was passed as fx (typo); fy was computed but unused.\n",
"    tl, br, _, corr = find_ROI_rec(frame, roi_img, fx_roi=fx, fy_roi=fy)\n",
"\n",
"    cv2.rectangle(frame, tl, br, 255, 3)\n",
"    cv2.putText(frame, f\"{corr:.4f}\", br, font, font_scale, color, thickness, cv2.LINE_AA)\n",
"    out_video.write(frame)\n",
"\n",
"# Release both capture and writer so the output file is properly finalized.\n",
"out_video.release()\n",
"cap.release()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a26e5329",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pathlib import Path\n",
"\n",
"# Inventory of merged per-case videos and the directory that should contain them.\n",
"dataset_list = Path(\"./script_files/videos_files_merge.csv\")\n",
"root = Path(\"/mnt/Data/Urology/RARP/Merge_Videos_RARP\")\n",
"\n",
"db = pd.read_csv(dataset_list)\n",
"# Keep only the cases for which at least one video was merged.\n",
"has_videos = db[\"num_videos\"] > 0\n",
"db_mergerd = db.loc[has_videos]\n",
"\n",
"list_new = []"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3c44c66",
"metadata": {},
"outputs": [],
"source": [
"# Check that each merged case has its video file on disk; report the gaps and\n",
"# collect (case, path) records for the ones that exist.\n",
"for rec in db_mergerd.itertuples(index=False):\n",
"    video = root / f\"RARP_{rec.case}.mp4\"\n",
"    if video.exists():\n",
"        list_new.append([rec.case, str(video)])\n",
"    else:\n",
"        print(\"Not exists:\", str(video))\n",
"\n",
"df = pd.DataFrame(list_new, columns=[\"case\", \"path_mergerd_video\"])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60c4a629",
"metadata": {},
"outputs": [],
"source": [
"# Persist the case -> merged-video path mapping for downstream scripts.\n",
"df.to_csv(Path(\"./script_files/videos_files_merge_paths.csv\"), index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd3a6744",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"#video = Path(\"/mnt/Data/Urology/RARP/RARP movie 181-555/RARP519_20211223\")\n",
"# NOTE(review): `video` now points at a single .mp4 file; the glob in the next\n",
"# cell only makes sense for the directory variant commented out above.\n",
"video = Path(\"/mnt/Data/Urology/RARP/Merge_Videos_RARP/RARP_520.mp4\")\n",
"\n",
"video.is_file()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acd9fc32",
"metadata": {},
"outputs": [],
"source": [
"# Case-insensitive listing of .mp4 entries under `video`. NOTE(review): `video`\n",
"# is currently a file, not a directory, so this glob presumably yields nothing\n",
"# -- verify `video` points at a case directory before relying on this cell.\n",
"for ele in video.glob(\"*.[mM][pP]4\"):\n",
"    print (ele)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# ==== 5-fold results ====\n",
"gru = {\n",
"    \"J\": [0.5591, 0.1171, 0.5000, 0.4189, 0.1225],\n",
"    \"Acc\": [0.6809, 0.6391, 0.6755, 0.7817, 0.6402],\n",
"    \"Precision\":[0.7292, 0.5899, 0.6289, 0.7911, 0.5818],\n",
"    \"Recall\": [0.5660, 0.8977, 0.8561, 0.7513, 0.9974],\n",
"    \"F1\": [0.6373, 0.7120, 0.7251, 0.7707, 0.7349],\n",
"    \"AUROC\": [0.7174, 0.6334, 0.6741, 0.8177, 0.6499],\n",
"}\n",
"\n",
"tcn = {\n",
"    \"J\": [0.4506, 0.1176, 0.5000, 0.4209, 0.9545],\n",
"    \"Acc\": [0.7170, 0.6341, 0.7020, 0.7829, 0.6508],\n",
"    \"Precision\":[0.7278, 0.5837, 0.6538, 0.7885, 0.7415],\n",
"    \"Recall\": [0.6846, 0.9182, 0.8586, 0.7593, 0.4630],\n",
"    \"F1\": [0.7056, 0.7137, 0.7424, 0.7736, 0.5700],\n",
"    \"AUROC\": [0.7413, 0.6154, 0.6827, 0.8251, 0.6766],\n",
"}\n",
"\n",
"metrics = [\"J\", \"Acc\", \"Precision\", \"Recall\", \"F1\", \"AUROC\"]\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)\n",
"\n",
"# One panel per model; both share identical formatting.\n",
"for ax, (panel_title, scores) in zip(axes, [(\"RN50 + GRU\", gru), (\"RN50 + TCN\", tcn)]):\n",
"    ax.boxplot([scores[m] for m in metrics],\n",
"               tick_labels=metrics,\n",
"               showmeans=True)\n",
"    ax.set_title(panel_title)\n",
"    ax.grid(True, axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
"    ax.tick_params(axis='x', rotation=45)\n",
"\n",
"# The y-axis is shared, so only the left panel needs a label.\n",
"axes[0].set_ylabel(\"Metric value\")\n",
"\n",
"fig.suptitle(\"5-fold Cross-Validation Metrics\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d883eb47",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# ==== 5-fold results ====\n",
"gru = {\n",
" \"J\": [0.6061,0.1014,0.764,0.6315,0.9574],\n",
" \"Acc\": [0.6996,0.6696,0.6919,0.8101,0.6587],\n",
" \"Precision\":[0.7517,0.618,0.7159,0.8367,0.7655],\n",
" \"Recall\": [0.5876,0.8772,0.6364,0.7593,0.4577],\n",
" \"F1\": [0.6596,0.7252,0.6738,0.7961,0.5728],\n",
" \"AUROC\": [0.7322,0.6166,0.6808,0.8291,0.6583],\n",
"}\n",
"\n",
"tcn = {\n",
" \"J\": [0.5552,0.188,0.3152,0.6319,0.1065],\n",
" \"Acc\": [0.6769,0.629,0.7071,0.7713,0.619],\n",
" \"Precision\":[0.7074,0.5832,0.704,0.8018,0.5697],\n",
" \"Recall\": [0.593,0.8875,0.7146,0.7063,0.9735],\n",
" \"F1\": [0.6452,0.7039,0.7093,0.7511,0.7188],\n",
" \"AUROC\": [0.7187,0.6308,0.7001,0.8255,0.601],\n",
"}\n",
"\n",
"metrics = [\"J\", \"Acc\", \"Precision\", \"Recall\", \"F1\", \"AUROC\"]\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)\n",
"\n",
"# ---- Left: RN18+GRU ---- (comment said RN50, copy-pasted; titles below are RN18)\n",
"ax = axes[0]\n",
"ax.boxplot([gru[m] for m in metrics],\n",
" tick_labels=metrics,\n",
" showmeans=True)\n",
"ax.set_title(\"RN18 + GRU\")\n",
"ax.set_ylabel(\"Metric value\")\n",
"ax.grid(True, axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
"ax.tick_params(axis='x', rotation=45)\n",
"\n",
"# ---- Right: RN18+TCN ----\n",
"ax = axes[1]\n",
"ax.boxplot([tcn[m] for m in metrics],\n",
" tick_labels=metrics,\n",
" showmeans=True)\n",
"ax.set_title(\"RN18 + TCN\")\n",
"ax.grid(True, axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
"ax.tick_params(axis='x', rotation=45)\n",
"\n",
"fig.suptitle(\"5-fold Cross-Validation Metrics\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20ac439c",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# ==== 5-fold results ====\n",
"gru = {\n",
" \"J\": [0.3129,0.1405,0.6162,0.4348,0.7944],\n",
" \"Acc\": [0.6929,0.676,0.7134,0.7687,0.668],\n",
" \"Precision\":[0.6556,0.6164,0.6833,0.7918,0.7068],\n",
" \"Recall\": [0.8005,0.9207,0.7955,0.7143,0.5741],\n",
" \"F1\": [0.7209,0.7385,0.7351,0.751,0.6336],\n",
" \"AUROC\": [0.7292,0.6598,0.7075,0.7899,0.6923],\n",
"}\n",
"\n",
"tcn = {\n",
" \"J\": [0.4599,0.2373,0.4899,0.471,0.9453],\n",
" \"Acc\": [0.7076,0.648,0.697,0.7674,0.6349],\n",
" \"Precision\":[0.6751,0.6048,0.6548,0.7705,0.7107],\n",
" \"Recall\": [0.7898,0.8414,0.8333,0.746,0.455],\n",
" \"F1\": [0.728,0.7037,0.7333,0.7581,0.5548],\n",
" \"AUROC\": [0.7565,0.6419,0.7141,0.8224,0.6689],\n",
"}\n",
"\n",
"metrics = [\"J\", \"Acc\", \"Precision\", \"Recall\", \"F1\", \"AUROC\"]\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)\n",
"\n",
"# ---- Left: VAN_b2+GRU ---- (comment said RN50, copy-pasted; titles below are VAN_b2)\n",
"ax = axes[0]\n",
"ax.boxplot([gru[m] for m in metrics],\n",
" tick_labels=metrics,\n",
" showmeans=True)\n",
"ax.set_title(\"VAN_b2 + GRU\")\n",
"ax.set_ylabel(\"Metric value\")\n",
"ax.grid(True, axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
"ax.tick_params(axis='x', rotation=45)\n",
"\n",
"# ---- Right: VAN_b2+TCN ----\n",
"ax = axes[1]\n",
"ax.boxplot([tcn[m] for m in metrics],\n",
" tick_labels=metrics,\n",
" showmeans=True)\n",
"ax.set_title(\"VAN_b2 + TCN\")\n",
"ax.grid(True, axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
"ax.tick_params(axis='x', rotation=45)\n",
"\n",
"fig.suptitle(\"5-fold Cross-Validation Metrics\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e0f23f3d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_1825081/1377775816.py:25: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" ckpt = torch.load(\"EfficientViT/EfficientViT_GSViT.pth\", map_location=\"cpu\")\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from collections import OrderedDict\n",
"from EfficientViT.classification.model.build import EfficientViT_M5\n",
"\n",
"def remap_sequential_keys(sd, index_to_name):\n",
"    \"\"\"Translate a DataParallel / nn.Sequential state dict onto named submodules.\n",
"\n",
"    Strips a leading \"module.\" prefix, then replaces a leading numeric key\n",
"    component (a Sequential index) with its name from index_to_name. Entries\n",
"    whose index is absent from index_to_name are dropped; keys that do not\n",
"    start with a digit pass through unchanged.\n",
"    \"\"\"\n",
"    remapped = OrderedDict()\n",
"    for key, value in sd.items():\n",
"        if key.startswith(\"module.\"):\n",
"            key = key[len(\"module.\"):]\n",
"        head, *rest = key.split(\".\")\n",
"        if head.isdigit():\n",
"            if int(head) not in index_to_name:\n",
"                continue  # index not mapped -> discard this entry\n",
"            key = \".\".join([index_to_name[int(head)], *rest])\n",
"        remapped[key] = value\n",
"    return remapped\n",
"\n",
"evit = EfficientViT_M5(pretrained=\"efficientvit_m5\")\n",
"\n",
"# Drop the classification head so the backbone outputs raw features.\n",
"evit.head = nn.Identity()\n",
"\n",
"# NOTE(review): torch.load without weights_only=True unpickles arbitrary objects\n",
"# (see the FutureWarning this cell emits); use weights_only=True if the GSViT\n",
"# checkpoint is plain tensors -- confirm before changing.\n",
"ckpt = torch.load(\"EfficientViT/EfficientViT_GSViT.pth\", map_location=\"cpu\")\n",
"# Checkpoints may nest the weights under \"state_dict\"; fall back to the raw dict.\n",
"sd = ckpt.get(\"state_dict\", ckpt)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "706736c0",
"metadata": {},
"outputs": [],
"source": [
"# Map nn.Sequential positions to the backbone's child-module names,\n",
"# e.g. {0: 'patch_embed', 1: 'blocks1', ...}.\n",
"index_to_name = dict(enumerate(name for name, _ in evit.named_children()))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a47f0a48",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{0: 'patch_embed', 1: 'blocks1', 2: 'blocks2', 3: 'blocks3', 4: 'head'}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index_to_name  # display the index -> module-name mapping built above"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "142aa44c",
"metadata": {},
"outputs": [],
"source": [
"# Rewrite GSViT checkpoint keys (sequential indices) to EfficientViT module names.\n",
"sd2 = remap_sequential_keys(sd, index_to_name)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "34f5148b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Missing: []\n",
"Unexpected: []\n"
]
}
],
"source": [
"# load_state_dict returns an _IncompatibleKeys named tuple; with strict=False\n",
"# mismatches are reported here instead of raising.\n",
"load_result = evit.load_state_dict(sd2, strict=False)\n",
"print(\"Missing:\", load_result.missing_keys)\n",
"print(\"Unexpected:\", load_result.unexpected_keys)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad5d63d7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}