{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "c7b9e30c",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"from torchvision import transforms\n",
"from transformers import AutoTokenizer, AutoModel\n",
"import van\n",
"import lightning as L\n",
"from lightning.pytorch import seed_everything\n",
"import lightning.pytorch.callbacks as callbk\n",
"from lightning.pytorch.loggers import TensorBoardLogger\n",
"import Models as M\n",
"from pathlib import Path\n",
"import pandas as pd\n",
"import Loaders\n",
"import numpy as np\n",
"import torchmetrics\n",
"import defs\n",
"\n",
"\n",
"LLM = \"emilyalsentzer/Bio_ClinicalBERT\"\n",
"\n",
"PROMPT = [(\n",
" \"Post-prostatectomy robotic laparoscopic view of the pelvic surgical bed in a {age}-year-old patient \"\n",
" \"(BMI {BMI}, PSA {PSA} ng/mL) with clinical stage {cT}, pathologic stage {pT} and Gleason score {GS}; \"\n",
" \"prostate size {prostate_size} mm; operating time was {surgery_time} min (console time {console_time} min); \"\n",
" \"blood loss {blood_loss} mL. Neurovascular bundle preserved: {NVB}\"\n",
")]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6dcc3aca",
"metadata": {},
"outputs": [],
"source": [
"def setup_seed(seed):\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed_all(seed)\n",
" np.random.seed(seed)\n",
" seed_everything(seed, workers=True)\n",
" torch.backends.cudnn.deterministic = True\n",
"\n",
"Dataset = Loaders.RARP_DatasetCreator(\n",
" \"./DataSet_Ando_All_no20Crop\",\n",
" FoldSeed=505,\n",
" createFile=True,\n",
" SavePath=\"./DataSet_AndoAll20_crop\",\n",
" Fold=5,\n",
" removeBlackBar=False,\n",
")\n",
"\n",
"Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])\n",
"\n",
"Dataset.CreateFolds()\n",
" \n",
"setup_seed(2023)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"DumpCSV = pd.read_csv(Dataset.CVS_File)\n",
"Extradata = pd.read_csv(Path(\"./DataSet_Ando_All_no20Crop/data_source_prompt.csv\"))\n",
"\n",
"DumpCSV[\"raw_name\"] = \"Img0-\" + DumpCSV[\"id\"].astype(str) + \".npy\"\n",
"DumpCSV = DumpCSV.drop(columns=[\"id\", \"path\", \"mean_1\", \"mean_2\", \"mean_3\", \"std_1\", \"std_2\", \"std_3\"])\n",
"\n",
"NewData = pd.merge(Extradata, DumpCSV, on=\"name\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c806faed",
"metadata": {},
"outputs": [],
"source": [
"NewData[NewData[\"raw_name\"] == \"Img0-33.npy\"].fillna('').to_dict(\"records\")[0]"
]
},
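{
"cell_type": "code",
"execution_count": null,
"id": "a1d4c9f2",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (added for illustration, not part of the original\n",
"# pipeline): confirm the merge kept every imaging row and surface missing\n",
"# clinical fields before they reach the prompt template.\n",
"print(f\"imaging rows: {len(DumpCSV)}, clinical rows: {len(Extradata)}, merged: {len(NewData)}\")\n",
"print(NewData.isna().sum().sort_values(ascending=False).head(10))"
]
},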
{
"cell_type": "code",
"execution_count": null,
"id": "c5a13167",
"metadata": {},
"outputs": [],
"source": [
"class RARP_DatasetFolder_LLM(torchvision.datasets.DatasetFolder):\n",
" def __init__(self, \n",
" root: str, \n",
" loader, \n",
" Extra_Data: pd.DataFrame,\n",
" tokenizer = None, \n",
" extensions = None, \n",
" transform = None, \n",
" target_transform = None, \n",
" is_valid_file = None\n",
" ) -> None:\n",
" super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)\n",
" self.Extra_Data = Extra_Data\n",
" self.tokenizer = tokenizer \n",
" \n",
" def __getitem__(self, index: int):\n",
" path, target = self.samples[index]\n",
" \n",
" name = Path(path).name\n",
" Extra_data = self.Extra_Data[self.Extra_Data[\"raw_name\"] == name].fillna('').to_dict(\"records\")[0]\n",
" prompt_text = PROMPT[0].format(**Extra_data)\n",
" \n",
" text_data = self.tokenizer(prompt_text, padding=\"max_length\", truncation=True, max_length=128, return_tensors=\"pt\")\n",
" \n",
" sample = self.loader(path) \n",
" if self.transform is not None:\n",
" sample = self.transform(sample)\n",
" if self.target_transform is not None:\n",
" target = self.target_transform(target)\n",
"\n",
" return (sample, text_data), target\n",
"\n",
"class RARP_CLIP_loss(nn.Module):\n",
" def __init__(self, *args, **kwargs):\n",
" super().__init__(*args, **kwargs)\n",
" self.lossFN = torch.nn.BCEWithLogitsLoss()\n",
" \n",
" def forward(self, clip_logits:torch.Tensor, label:torch.Tensor):\n",
" image_logits, text_logits = clip_logits\n",
" \n",
" image_logits = image_logits.flatten()\n",
" text_logits = text_logits.flatten()\n",
" \n",
" loss = (self.lossFN(image_logits, label) + self.lossFN(text_logits, label)) / 2\n",
" \n",
" return loss\n",
"\n",
"class RARP_CLIP(nn.Module):\n",
" def __init__(self, text_encoder:nn.Module, text_output_feat_dim:int, img_output_feat_dim:int, embed_dim:int, latent_space_dim:int):\n",
" super().__init__()\n",
" \n",
" self.image_latent_space = nn.Sequential(\n",
" nn.Linear(img_output_feat_dim, embed_dim),\n",
" nn.LayerNorm(embed_dim),\n",
" nn.GELU(),\n",
" nn.Linear(embed_dim, latent_space_dim)\n",
" )\n",
" \n",
" self.text_encoder = text_encoder\n",
" self.text_latent_space = nn.Sequential(\n",
" nn.Linear(text_output_feat_dim, embed_dim),\n",
" nn.LayerNorm(embed_dim),\n",
" nn.GELU(),\n",
" nn.Linear(embed_dim, latent_space_dim)\n",
" )\n",
" \n",
" self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1/0.07)))\n",
" \n",
" def forward(self, image_emb, text_tokens):\n",
" \n",
" x_img = self.image_latent_space(image_emb)\n",
" \n",
" x_text = self.text_encoder(\n",
" input_ids=text_tokens[\"input_ids\"],\n",
" attention_mask=text_tokens[\"attention_mask\"]\n",
" )\n",
" x_text = x_text.last_hidden_state[:,0] #[CLS] token\n",
" x_text = self.text_latent_space(x_text)\n",
" \n",
" #Normalize\n",
" x_img = F.normalize(x_img, dim=-1)\n",
" x_text = F.normalize(x_text, dim=-1)\n",
" \n",
" #Scaled cosine logits\n",
" scale = self.logit_scale.exp()\n",
" logits_img = scale * x_img @ x_text.t()\n",
" logits_text = logits_img.t()\n",
" \n",
" return logits_img, logits_text\n",
" \n",
"class RARP_VAN_BERT(L.LightningModule):\n",
" def __init__(\n",
" self, \n",
" bert_model_name:str,\n",
" van_model:str,\n",
" lr:float = 1e-4,\n",
" latent_space_dim:int = 512,\n",
" hiden_dim:int = 256,\n",
" clasiffier_layers = []\n",
" ):\n",
" super().__init__()\n",
" \n",
" self.save_hyperparameters(ignore=[\"bert_model_name\", \"van_model\"])\n",
" \n",
" self.train_acc = torchmetrics.Accuracy('binary')\n",
" self.val_acc = torchmetrics.Accuracy('binary')\n",
" \n",
" self.bert_llm = AutoModel.from_pretrained(bert_model_name)\n",
" self.text_emb = 768\n",
" \n",
" for parms in self.bert_llm.parameters():\n",
" parms.requires_grad = False\n",
" \n",
" self.van_encoder = van.van_b2(pretrained=False, num_classes=0)\n",
" self.van_encoder.load_state_dict(torch.load(van_model))\n",
" self.image_emb = 512\n",
" \n",
" self.clip = RARP_CLIP(self.bert_llm, self.text_emb, self.image_emb, hiden_dim, latent_space_dim)\n",
" \n",
" self.clasiffier = M.RARP_NVB_Classification_Head(self.image_emb, 1, clasiffier_layers, torch.nn.SiLU(True))\n",
" \n",
" self.lossFN_clasiffier = torch.nn.BCEWithLogitsLoss()\n",
" self.lossFN_CLIP = RARP_CLIP_loss()\n",
" \n",
" \n",
" def forward(self, data):\n",
" img_data, text_data = data\n",
" img_data = img_data.float()\n",
" \n",
" img_features = self.van_encoder(img_data)\n",
" \n",
" logits_img, logits_text = self.clip(img_features, text_data)\n",
" \n",
" pred = self.clasiffier(img_features)\n",
" \n",
" return pred, (logits_img, logits_text)\n",
" \n",
" def _shared_step(self, batch):\n",
" data, label = batch\n",
" label = label.float()\n",
" \n",
" prediction, clip_logits = self(data)\n",
" \n",
" prediction = prediction.flatten()\n",
" predicted_labels = torch.sigmoid(prediction)\n",
" \n",
" loss = self.lossFN_clasiffier(prediction, label) + self.lossFN_CLIP(clip_logits, label)\n",
" \n",
" return loss, label, predicted_labels\n",
" \n",
" def training_step(self, batch, batch_idx):\n",
" loss, true_labels, pred_labels = self._shared_step(batch)\n",
" \n",
" self.log(\"train_loss\", loss, on_epoch=True)\n",
" self.train_acc.update(pred_labels, true_labels)\n",
" self.log(\"train_acc\", self.train_acc, on_epoch=True, on_step=False)\n",
" \n",
" return loss\n",
" \n",
" def validation_step(self, batch, batch_idx):\n",
" loss, true_labels, pred_labels = self._shared_step(batch)\n",
" \n",
" self.log(\"val_loss\", loss, on_epoch=True, on_step=False, prog_bar=True)\n",
" self.val_acc.update(pred_labels, true_labels)\n",
" self.log(\"val_acc\", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)\n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr) #, weight_decay=self.Lambda_L2\n",
" \n",
" return [optimizer]"
]
},
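{
"cell_type": "code",
"execution_count": null,
"id": "b7e2a4c1",
"metadata": {},
"outputs": [],
"source": [
"# Minimal smoke test (illustrative sketch): RARP_CLIP emits a pair of [B, B]\n",
"# similarity matrices, and RARP_CLIP_loss expands the [B] labels to pairwise\n",
"# targets. Random tensors stand in for real logits here.\n",
"with torch.no_grad():\n",
"    B = 4\n",
"    fake_logits = torch.randn(B, B)\n",
"    fake_labels = torch.randint(0, 2, (B,)).float()\n",
"    print(RARP_CLIP_loss()((fake_logits, fake_logits.t()), fake_labels))"
]
},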
{
"cell_type": "code",
"execution_count": null,
"id": "17435249",
"metadata": {},
"outputs": [],
"source": [
"Fold = 0\n",
"InitResize=(512, 512)\n",
"\n",
"batchSize = 16\n",
"numWorkers = 0\n",
"MaxEpochs = 100\n",
"LogFileName = \"logs_debug\"\n",
"\n",
"rootFile = Dataset.CVS_File.parent.parent/f\"fold_{Fold}\"\n",
"checkPtCallback = callbk.ModelCheckpoint(monitor='val_acc', filename=\"RARP-{epoch}\", save_top_k=10, mode='max')\n",
"\n",
"Model = RARP_VAN_BERT(LLM, \"van_b2_teacher_98.pth\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0245feb",
"metadata": {},
"outputs": [],
"source": [
"traintransform = torch.nn.Sequential(\n",
" transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),\n",
" transforms.RandomErasing(0.8, value=\"random\"),\n",
" transforms.RandomAffine(degrees=(-45, 45), scale=(0.8, 1.2), fill=5),\n",
" transforms.GaussianBlur(5),\n",
" transforms.RandomCrop(224),\n",
" transforms.Normalize(Dataset.mean, Dataset.std)\n",
").to(device)\n",
"\n",
"valtransform = torch.nn.Sequential(\n",
" transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),\n",
" transforms.CenterCrop(224),\n",
" transforms.Normalize(Dataset.mean, Dataset.std)\n",
").to(device)"
]
},
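{
"cell_type": "code",
"execution_count": null,
"id": "c9f1b3d7",
"metadata": {},
"outputs": [],
"source": [
"# Quick shape check (illustrative): both pipelines should map an arbitrary\n",
"# 3-channel tensor image to a normalized 3x224x224 crop. The transforms hold\n",
"# no buffers, so they also run on CPU tensors despite the .to(device) above.\n",
"dummy = torch.rand(3, 600, 800)\n",
"print(traintransform(dummy).shape, valtransform(dummy).shape)"
]
},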
{
"cell_type": "code",
"execution_count": null,
"id": "78b32c2c",
"metadata": {},
"outputs": [],
"source": [
"bert_tokenizer = AutoTokenizer.from_pretrained(LLM)\n",
"\n",
"trainDataset = RARP_DatasetFolder_LLM(\n",
" str (rootFile/\"train\"),\n",
" loader = defs.load_file_tensor,\n",
" Extra_Data = NewData,\n",
" tokenizer = bert_tokenizer,\n",
" extensions = \"npy\",\n",
" transform = traintransform\n",
")\n",
"\n",
"valDataset = RARP_DatasetFolder_LLM(\n",
" str (rootFile/\"val\"),\n",
" loader = defs.load_file_tensor,\n",
" Extra_Data = NewData,\n",
" tokenizer = bert_tokenizer,\n",
" extensions = \"npy\",\n",
" transform = valtransform\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0155e78a",
"metadata": {},
"outputs": [],
"source": [
"Train_DataLoader = DataLoader(\n",
" trainDataset, \n",
" batch_size=batchSize, \n",
" num_workers=numWorkers, \n",
" shuffle=True, \n",
" drop_last=True,\n",
" pin_memory=True,\n",
" persistent_workers=numWorkers>0\n",
")\n",
"Val_DataLoader = DataLoader(\n",
" valDataset, \n",
" batch_size=batchSize, \n",
" num_workers=numWorkers, \n",
" shuffle=False, \n",
" pin_memory=True,\n",
" persistent_workers=numWorkers>0\n",
")"
]
},
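{
"cell_type": "code",
"execution_count": null,
"id": "d2a8e6b4",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (illustrative): pull one batch to confirm the collated shapes\n",
"# before training -- images [B, 3, 224, 224], tokens [B, 128], labels [B].\n",
"# The squeeze(0) in RARP_DatasetFolder_LLM.__getitem__ keeps the default\n",
"# collate_fn from adding a spurious dimension to the token tensors.\n",
"(imgs, text), labels = next(iter(Train_DataLoader))\n",
"print(imgs.shape, text[\"input_ids\"].shape, labels.shape)"
]
},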
{
"cell_type": "code",
"execution_count": null,
"id": "4d261d6b",
"metadata": {},
"outputs": [],
"source": [
"trainer = L.Trainer(\n",
" deterministic=True,\n",
" accelerator='gpu', \n",
" devices=1, \n",
" logger=TensorBoardLogger(save_dir=f\"./{LogFileName}\"),\n",
" log_every_n_steps=5, \n",
" callbacks=[checkPtCallback],\n",
" max_epochs=MaxEpochs,\n",
")\n",
"print(\"Train Phase\")\n",
"trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)"
]
},
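{
"cell_type": "code",
"execution_count": null,
"id": "e4c7a9f3",
"metadata": {},
"outputs": [],
"source": [
"# Restoring the best weights (a sketch): ModelCheckpoint tracks the top\n",
"# checkpoints by val_acc. Because bert_model_name and van_model were excluded\n",
"# from save_hyperparameters, they must be passed again at load time.\n",
"print(checkPtCallback.best_model_path)\n",
"best_model = RARP_VAN_BERT.load_from_checkpoint(\n",
"    checkPtCallback.best_model_path,\n",
"    bert_model_name=LLM,\n",
"    van_model=\"van_b2_teacher_98.pth\",\n",
")"
]
},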
{
"cell_type": "code",
"execution_count": null,
"id": "7f82b862",
"metadata": {},
"outputs": [],
"source": [
"sample_dic = [{\n",
" \"age\": 64,\n",
" \"BMI\": 19.8,\n",
" \"PSA\": 14.31,\n",
" \"cT\": \"2a\",\n",
" \"GS\": \"3+4\",\n",
" \"prostate_size\": 14,\n",
" \"surgery_time\": 275,\n",
" \"console_time\": 213,\n",
" \"pT\": \"2c\",\n",
" \"blood_loss\": 200,\n",
" \"NVB\": \"YES\"\n",
"}]\n",
"\n",
"sample_prompt = PROMPT[0].format(**sample_dic[0])\n",
"print (len(sample_prompt))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83ad1595",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(LLM)\n",
"bert = AutoModel.from_pretrained(LLM)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28dc51dc",
"metadata": {},
"outputs": [],
"source": [
"text_inputs = tokenizer(sample_prompt, padding=\"max_length\", truncation=True, max_length=128, return_tensors=\"pt\")\n",
"text_inputs[\"input_ids\"].shape\n"
]
},
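{
"cell_type": "code",
"execution_count": null,
"id": "f8b3d5c2",
"metadata": {},
"outputs": [],
"source": [
"# Optional check: decode the tokens back to text to confirm the 128-token\n",
"# budget does not truncate clinically relevant fields of the prompt.\n",
"print(tokenizer.decode(text_inputs[\"input_ids\"][0], skip_special_tokens=True))"
]
},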
{
"cell_type": "code",
"execution_count": null,
"id": "edcecf9d",
"metadata": {},
"outputs": [],
"source": [
"x = bert(input_ids=text_inputs[\"input_ids\"], attention_mask=text_inputs[\"attention_mask\"])\n",
"x.last_hidden_state[:,0]"
]
},
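{
"cell_type": "code",
"execution_count": null,
"id": "a5e9c1b6",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative extra: embed two prompts that differ only in NVB status and\n",
"# compare their [CLS] vectors -- the representation the text branch of\n",
"# RARP_CLIP projects into the shared latent space.\n",
"other = dict(sample_dic[0], NVB=\"NO\")\n",
"pair = tokenizer(\n",
"    [sample_prompt, PROMPT[0].format(**other)],\n",
"    padding=\"max_length\", truncation=True, max_length=128, return_tensors=\"pt\"\n",
")\n",
"with torch.no_grad():\n",
"    cls = bert(**pair).last_hidden_state[:, 0]\n",
"print(F.cosine_similarity(cls[0:1], cls[1:2]))"
]
}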
],
"metadata": {
"kernelspec": {
"display_name": "pyRARP",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}