{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7b9e30c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import van\n",
    "import lightning as L\n",
    "from lightning.pytorch import seed_everything\n",
    "import lightning.pytorch.callbacks as callbk\n",
    "from lightning.pytorch.loggers import TensorBoardLogger\n",
    "import Models as M\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "import Loaders\n",
    "import numpy as np\n",
    "import torchmetrics\n",
    "import defs\n",
    "\n",
    "\n",
    "LLM = \"emilyalsentzer/Bio_ClinicalBERT\"\n",
    "\n",
    "PROMPT = [(\n",
    "    \"Post-prostatectomy robotic laparoscopic view of the pelvic surgical bed in a {age}-year-old patient \"\n",
    "    \"(BMI {BMI}, PSA {PSA} ng/mL) with clinical stage {cT}, pathologic stage {pT} and Gleason score {GS}; \"\n",
    "    \"prostate size {prostate_size} mm; operating time was {surgery_time} min (console time {console_time} min); \"\n",
    "    \"blood loss {blood_loss} mL. Neurovascular bundle preserved: {NVB}\"\n",
    ")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dcc3aca",
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    np.random.seed(seed)\n",
    "    seed_everything(seed, workers=True)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "Dataset = Loaders.RARP_DatasetCreator(\n",
    "    \"./DataSet_Ando_All_no20Crop\",\n",
    "    FoldSeed=505,\n",
    "    createFile=True,\n",
    "    SavePath=\"./DataSet_AndoAll20_crop\",\n",
    "    Fold=5,\n",
    "    removeBlackBar=False,\n",
    ")\n",
    "\n",
    "Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])\n",
    "\n",
    "Dataset.CreateFolds()\n",
    "    \n",
    "setup_seed(2023)\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "DumpCSV = pd.read_csv(Dataset.CVS_File)\n",
    "Extradata = pd.read_csv(Path(\"./DataSet_Ando_All_no20Crop/data_source_prompt.csv\"))\n",
    "\n",
    "DumpCSV[\"raw_name\"] = \"Img0-\" + DumpCSV[\"id\"].astype(str) + \".npy\"\n",
    "DumpCSV = DumpCSV.drop(columns=[\"id\", \"path\", \"mean_1\", \"mean_2\", \"mean_3\", \"std_1\", \"std_2\", \"std_3\"])\n",
    "\n",
    "NewData = pd.merge(Extradata, DumpCSV, on=\"name\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c806faed",
   "metadata": {},
   "outputs": [],
   "source": [
    "NewData[NewData[\"raw_name\"] == \"Img0-33.npy\"].fillna('').to_dict(\"records\")[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5a13167",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RARP_DatasetFolder_LLM(torchvision.datasets.DatasetFolder):\n",
    "    def __init__(self, \n",
    "                 root: str, \n",
    "                 loader, \n",
    "                 Extra_Data: pd.DataFrame,\n",
    "                 tokenizer = None, \n",
    "                 extensions = None, \n",
    "                 transform = None, \n",
    "                 target_transform = None, \n",
    "                 is_valid_file = None\n",
    "                ) -> None:\n",
    "        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)\n",
    "        self.Extra_Data = Extra_Data\n",
    "        self.tokenizer = tokenizer \n",
    "        \n",
    "    def __getitem__(self, index: int):\n",
    "        path, target = self.samples[index]\n",
    "        \n",
    "        name = Path(path).name\n",
    "        Extra_data = self.Extra_Data[self.Extra_Data[\"raw_name\"] == name].fillna('').to_dict(\"records\")[0]\n",
    "        prompt_text = PROMPT[0].format(**Extra_data)\n",
    "        \n",
    "        text_data = self.tokenizer(prompt_text, padding=\"max_length\", truncation=True, max_length=128, return_tensors=\"pt\")\n",
    "        \n",
    "        sample = self.loader(path) \n",
    "        if self.transform is not None:\n",
    "            sample = self.transform(sample)\n",
    "        if self.target_transform is not None:\n",
    "            target = self.target_transform(target)\n",
    "\n",
    "        return (sample, text_data), target\n",
    "\n",
    "class RARP_CLIP_loss(nn.Module):\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.lossFN = torch.nn.BCEWithLogitsLoss()\n",
    "        \n",
    "    def forward(self, clip_logits:torch.Tensor, label:torch.Tensor):\n",
    "        image_logits, text_logits = clip_logits\n",
    "        \n",
    "        image_logits = image_logits.flatten()\n",
    "        text_logits = text_logits.flatten()\n",
    "        \n",
    "        loss = (self.lossFN(image_logits, label) + self.lossFN(text_logits, label)) / 2\n",
    "        \n",
    "        return loss\n",
    "\n",
    "class RARP_CLIP(nn.Module):\n",
    "    def __init__(self, text_encoder:nn.Module, text_output_feat_dim:int, img_output_feat_dim:int, embed_dim:int, latent_space_dim:int):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.image_latent_space = nn.Sequential(\n",
    "            nn.Linear(img_output_feat_dim, embed_dim),\n",
    "            nn.LayerNorm(embed_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(embed_dim, latent_space_dim)\n",
    "        )\n",
    "        \n",
    "        self.text_encoder = text_encoder\n",
    "        self.text_latent_space = nn.Sequential(\n",
    "            nn.Linear(text_output_feat_dim, embed_dim),\n",
    "            nn.LayerNorm(embed_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(embed_dim, latent_space_dim)\n",
    "        )\n",
    "        \n",
    "        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1/0.07)))\n",
    "        \n",
    "    def forward(self, image_emb, text_tokens):\n",
    "        \n",
    "        x_img = self.image_latent_space(image_emb)\n",
    "        \n",
    "        x_text = self.text_encoder(\n",
    "            input_ids=text_tokens[\"input_ids\"],\n",
    "            attention_mask=text_tokens[\"attention_mask\"]\n",
    "        )\n",
    "        x_text = x_text.last_hidden_state[:,0] #[CLS] token\n",
    "        x_text = self.text_latent_space(x_text)\n",
    "        \n",
    "        #Normalize\n",
    "        x_img = F.normalize(x_img, dim=-1)\n",
    "        x_text = F.normalize(x_text, dim=-1)\n",
    "        \n",
    "        #Scaled cosine logits\n",
    "        scale = self.logit_scale.exp()\n",
    "        logits_img = scale * x_img @ x_text.t()\n",
    "        logits_text = logits_img.t()\n",
    "        \n",
    "        return logits_img, logits_text\n",
    "        \n",
    "class RARP_VAN_BERT(L.LightningModule):\n",
    "    def __init__(\n",
    "        self, \n",
    "        bert_model_name:str,\n",
    "        van_model:str,\n",
    "        lr:float = 1e-4,\n",
    "        latent_space_dim:int = 512,\n",
    "        hiden_dim:int = 256,\n",
    "        clasiffier_layers = []\n",
    "    ):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.save_hyperparameters(ignore=[\"bert_model_name\", \"van_model\"])\n",
    "        \n",
    "        self.train_acc = torchmetrics.Accuracy('binary')\n",
    "        self.val_acc = torchmetrics.Accuracy('binary')\n",
    "        \n",
    "        self.bert_llm = AutoModel.from_pretrained(bert_model_name)\n",
    "        self.text_emb = 768\n",
    "        \n",
    "        for parms in self.bert_llm.parameters():\n",
    "            parms.requires_grad = False\n",
    "        \n",
    "        self.van_encoder = van.van_b2(pretrained=False, num_classes=0)\n",
    "        self.van_encoder.load_state_dict(torch.load(van_model))\n",
    "        self.image_emb = 512\n",
    "                \n",
    "        self.clip = RARP_CLIP(self.bert_llm, self.text_emb, self.image_emb, hiden_dim, latent_space_dim)\n",
    "        \n",
    "        self.clasiffier = M.RARP_NVB_Classification_Head(self.image_emb, 1, clasiffier_layers, torch.nn.SiLU(True))\n",
    "        \n",
    "        self.lossFN_clasiffier = torch.nn.BCEWithLogitsLoss()\n",
    "        self.lossFN_CLIP = RARP_CLIP_loss()\n",
    "        \n",
    "        \n",
    "    def forward(self, data):\n",
    "        img_data, text_data = data\n",
    "        img_data = img_data.float()\n",
    "        \n",
    "        img_features = self.van_encoder(img_data)\n",
    "        \n",
    "        logits_img, logits_text = self.clip(img_features, text_data)\n",
    "        \n",
    "        pred = self.clasiffier(img_features)\n",
    "        \n",
    "        return pred, (logits_img, logits_text)\n",
    "    \n",
    "    def _shared_step(self, batch):\n",
    "        data, label = batch\n",
    "        label = label.float()\n",
    "        \n",
    "        prediction, clip_logits = self(data)\n",
    "        \n",
    "        prediction = prediction.flatten()\n",
    "        predicted_labels = torch.sigmoid(prediction)\n",
    "        \n",
    "        loss = self.lossFN_clasiffier(prediction, label) + self.lossFN_CLIP(clip_logits, label)\n",
    "        \n",
    "        return loss, label, predicted_labels\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        loss, true_labels, pred_labels = self._shared_step(batch)\n",
    "        \n",
    "        self.log(\"train_loss\", loss, on_epoch=True)\n",
    "        self.train_acc.update(pred_labels, true_labels)\n",
    "        self.log(\"train_acc\", self.train_acc, on_epoch=True, on_step=False)\n",
    "        \n",
    "        return loss\n",
    "    \n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        loss, true_labels, pred_labels = self._shared_step(batch)\n",
    "        \n",
    "        self.log(\"val_loss\", loss, on_epoch=True, on_step=False, prog_bar=True)\n",
    "        self.val_acc.update(pred_labels, true_labels)\n",
    "        self.log(\"val_acc\", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)\n",
    "        \n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)  #, weight_decay=self.Lambda_L2\n",
    "        \n",
    "        return [optimizer]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17435249",
   "metadata": {},
   "outputs": [],
   "source": [
    "Fold = 0\n",
    "InitResize=(512, 512)\n",
    "\n",
    "batchSize = 16\n",
    "numWorkers = 0\n",
    "MaxEpochs = 100\n",
    "LogFileName = \"logs_debug\"\n",
    "\n",
    "rootFile = Dataset.CVS_File.parent.parent/f\"fold_{Fold}\"\n",
    "checkPtCallback = callbk.ModelCheckpoint(monitor='val_acc', filename=\"RARP-{epoch}\", save_top_k=10, mode='max')\n",
    "\n",
    "Model = RARP_VAN_BERT(LLM, \"van_b2_teacher_98.pth\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0245feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "traintransform = torch.nn.Sequential(\n",
    "    transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),\n",
    "    transforms.RandomErasing(0.8, value=\"random\"),\n",
    "    transforms.RandomAffine(degrees=(-45, 45), scale=(0.8, 1.2), fill=5),\n",
    "    transforms.GaussianBlur(5),\n",
    "    transforms.RandomCrop(224),\n",
    "    transforms.Normalize(Dataset.mean, Dataset.std)\n",
    ").to(device)\n",
    "\n",
    "valtransform = torch.nn.Sequential(\n",
    "    transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.Normalize(Dataset.mean, Dataset.std)\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78b32c2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_tokenizer = AutoTokenizer.from_pretrained(LLM)\n",
    "\n",
    "trainDataset = RARP_DatasetFolder_LLM(\n",
    "    str (rootFile/\"train\"),\n",
    "    loader = defs.load_file_tensor,\n",
    "    Extra_Data = NewData,\n",
    "    tokenizer = bert_tokenizer,\n",
    "    extensions = \"npy\",\n",
    "    transform = traintransform\n",
    ")\n",
    "\n",
    "valDataset = RARP_DatasetFolder_LLM(\n",
    "    str (rootFile/\"val\"),\n",
    "    loader = defs.load_file_tensor,\n",
    "    Extra_Data = NewData,\n",
    "    tokenizer = bert_tokenizer,\n",
    "    extensions = \"npy\",\n",
    "    transform = valtransform\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0155e78a",
   "metadata": {},
   "outputs": [],
   "source": [
    "Train_DataLoader = DataLoader(\n",
    "    trainDataset, \n",
    "    batch_size=batchSize, \n",
    "    num_workers=numWorkers, \n",
    "    shuffle=True, \n",
    "    drop_last=True,\n",
    "    pin_memory=True,\n",
    "    persistent_workers=numWorkers>0\n",
    ")\n",
    "Val_DataLoader = DataLoader(\n",
    "    valDataset, \n",
    "    batch_size=batchSize, \n",
    "    num_workers=numWorkers, \n",
    "    shuffle=False, \n",
    "    pin_memory=True,\n",
    "    persistent_workers=numWorkers>0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d261d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = L.Trainer(\n",
    "    deterministic=True,\n",
    "    accelerator='gpu', \n",
    "    devices=1, \n",
    "    logger=TensorBoardLogger(save_dir=f\"./{LogFileName}\"),\n",
    "    log_every_n_steps=5,   \n",
    "    callbacks=[checkPtCallback],\n",
    "    max_epochs=MaxEpochs,\n",
    ")\n",
    "print(\"Train Phase\")\n",
    "trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f82b862",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_dic = [{\n",
    "    \"age\": 64,\n",
    "    \"BMI\": 19.8,\n",
    "    \"PSA\": 14.31,\n",
    "    \"cT\": \"2a\",\n",
    "    \"GS\": \"3+4\",\n",
    "    \"prostate_size\": 14,\n",
    "    \"surgery_time\": 275,\n",
    "    \"console_time\": 213,\n",
    "    \"pT\": \"2c\",\n",
    "    \"blood_loss\": 200,\n",
    "    \"NVB\": \"YES\"\n",
    "}]\n",
    "\n",
    "sample_prompt = PROMPT[0].format(**sample_dic[0])\n",
    "print (len(sample_prompt))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ad1595",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(LLM)\n",
    "bert   = AutoModel.from_pretrained(LLM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28dc51dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "text_inputs = tokenizer(sample_prompt, padding=\"max_length\", truncation=True, max_length=128, return_tensors=\"pt\")\n",
    "text_inputs[\"input_ids\"].shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edcecf9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = bert(input_ids=text_inputs[\"input_ids\"], attention_mask=text_inputs[\"attention_mask\"])\n",
    "x.last_hidden_state[:,0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pyRARP",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}