import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from transformers import AutoTokenizer, AutoModel
import van
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger
import Models as M
from pathlib import Path
import pandas as pd
import Loaders
import numpy as np
import torchmetrics
import defs
import argparse
from tqdm import tqdm
from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt
LLM = "emilyalsentzer/Bio_ClinicalBERT"
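# Clinical prompt template below; the named placeholders are filled per image
# from the patient-metadata CSV before tokenization.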
PROMPT = [(
"Post-prostatectomy robotic laparoscopic view of the pelvic surgical bed in a {age}-year-old patient "
"(BMI {BMI}, PSA {PSA} ng/mL) with clinical stage {cT}, pathologic stage {pT} and Gleason score {GS}; "
"prostate size {prostate_size} mm; operating time was {surgery_time} min (console time {console_time} min); "
"blood loss {blood_loss} mL. Neurovascular bundle preserved: {NVB}"
)]
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
seed_everything(seed, workers=True)
torch.backends.cudnn.deterministic = True
class RARP_DatasetFolder_LLM(torchvision.datasets.DatasetFolder):
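    """DatasetFolder variant that pairs every image with a tokenized clinical prompt.

    The metadata row whose ``raw_name`` matches the file name is formatted into
    PROMPT and tokenized, so ``__getitem__`` returns ``((image, token_dict), target)``.
    """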
def __init__(self,
root: str,
loader,
Extra_Data: pd.DataFrame,
tokenizer = None,
extensions = None,
transform = None,
target_transform = None,
is_valid_file = None
) -> None:
super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)
self.Extra_Data = Extra_Data
self.tokenizer = tokenizer
def __getitem__(self, index: int):
path, target = self.samples[index]
name = Path(path).name
Extra_data = self.Extra_Data[self.Extra_Data["raw_name"] == name].fillna('').to_dict("records")[0]
prompt_text = PROMPT[0].format(**Extra_data)
text_data = self.tokenizer(prompt_text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
for k in text_data:
text_data[k] = text_data[k].squeeze(0)
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return (sample, text_data), target
class RARP_CLIP_loss(nn.Module):
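    """Symmetric CLIP loss: cross-entropy over the image->text and text->image
    similarity logits, with each sample matched to its own pair on the diagonal."""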
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lossFN = torch.nn.CrossEntropyLoss()
def forward(self, clip_logits:torch.Tensor):
image_logits, text_logits = clip_logits
label = torch.arange(image_logits.size(0), device=image_logits.device)
loss = (self.lossFN(image_logits, label) + self.lossFN(text_logits, label)) / 2
return loss
class RARP_CLIP(nn.Module):
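    """Two MLP projection heads that map image and text features into a shared
    latent space and return temperature-scaled cosine-similarity logits.

    Example (shapes only): ``logits_img, logits_text = clip(img_emb, text_emb)``,
    where both outputs are ``(B, B)`` similarity matrices.

    Note: ``embed_dim`` is accepted but currently unused by this module.
    """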
def __init__(
self,
text_output_feat_dim:int,
img_output_feat_dim:int,
embed_dim:int,
latent_space_dim:int
):
super().__init__()
self.image_latent_space = M.RARP_NVB_MLP(img_output_feat_dim, latent_space_dim, n_layers=2)
self.text_latent_space = M.RARP_NVB_MLP(text_output_feat_dim, latent_space_dim, n_layers=2)
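        # Learnable temperature: exp(logit_scale) scales the cosine similarities;
        # the 1/0.07 initialization follows the original CLIP paper.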
self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1/0.07)))
def forward(self, image_emb, text_emb):
x_img = self.image_latent_space(image_emb)
x_text = self.text_latent_space(text_emb)
#Normalize
x_img = F.normalize(x_img, dim=-1)
x_text = F.normalize(x_text, dim=-1)
#Scaled cosine logits
scale = self.logit_scale.exp()
logits_img = scale * x_img @ x_text.t()
logits_text = logits_img.t()
return logits_img, logits_text
class RARP_VAN_BERT(L.LightningModule):
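    """NVB-preservation classifier with a CLIP-style auxiliary objective.

    A VAN-B2 encoder embeds the surgical image and a frozen Bio_ClinicalBERT
    embeds the clinical prompt; the classification head sees only the image
    features, while the contrastive loss aligns the two modalities. The two
    loss terms are re-weighted by SoftAdapt every other training epoch.
    """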
def __init__(
self,
bert_model_name:str = LLM,
van_model:str = "",
lr:float = 1e-4,
latent_space_dim:int = 512,
        hidden_dim:int = 256,
        classifier_layers = [],
        softAdaptAlgo:int = 0,
        softAdaptBeta:float = 0.1,
lambda_L1:float = 1.31E-04,
):
super().__init__()
self.save_hyperparameters(ignore=["bert_model_name", "van_model"])
self.loss_weights = [1, 1]
        self.softAdapt = NormalizedSoftAdapt(softAdaptBeta) if softAdaptAlgo == 1 else LossWeightedSoftAdapt(softAdaptBeta)
self.loss_history = {
0 : [], #'loss_Binary'
1 : [], #'loss_CLIP'
}
self.train_acc = torchmetrics.Accuracy('binary')
self.val_acc = torchmetrics.Accuracy('binary')
self.test_acc = torchmetrics.Accuracy('binary')
self.f1ScoreTest = torchmetrics.F1Score('binary')
self.bert_llm = AutoModel.from_pretrained(bert_model_name)
        self.text_emb = self.bert_llm.config.hidden_size  # 768 for Bio_ClinicalBERT
        # Freeze the text encoder; only the projection heads, VAN encoder and classifier train
        for param in self.bert_llm.parameters():
            param.requires_grad = False
        self.van_encoder = van.van_b2(pretrained=False, num_classes=0)
        if van_model:
            # Load pretrained encoder weights only when a path is given; when restoring
            # via load_from_checkpoint the weights come from the checkpoint itself.
            self.van_encoder.load_state_dict(torch.load(van_model, map_location="cpu"))
        self.image_emb = 512  # channel width of the final van_b2 stage
        self.clip = RARP_CLIP(self.text_emb, self.image_emb, hidden_dim, latent_space_dim)
        self.classifier = M.RARP_NVB_Classification_Head(self.image_emb, 1, classifier_layers, torch.nn.SiLU(inplace=True))
        self.lossFN_classifier = torch.nn.BCEWithLogitsLoss()
self.lossFN_CLIP = RARP_CLIP_loss()
def forward(self, data):
img_data, text_data = data
img_data = img_data.float()
ids_tokens = text_data["input_ids"]
att_mask = text_data["attention_mask"]
img_features = self.van_encoder(img_data)
x_text = self.bert_llm(
input_ids=ids_tokens,
attention_mask=att_mask
)
x_text = x_text.last_hidden_state[:,0] #[CLS] token
logits_img, logits_text = self.clip(img_features, x_text)
        pred = self.classifier(img_features)
return pred, (logits_img, logits_text)
def _calc_L1(self, params):
l1 = 0
for p in params:
l1 += torch.sum(torch.abs(p))
return self.hparams.lambda_L1 * l1
def _calc_weights(self, log_weights:bool = True):
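        """Update the SoftAdapt loss weights from the recorded training-loss
        histories (each trimmed to an odd length before the call) and reset
        the histories for the next window."""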
        hist_binary = self.loss_history[0][:-1] if len(self.loss_history[0]) % 2 == 0 else self.loss_history[0]
        hist_clip = self.loss_history[1][:-1] if len(self.loss_history[1]) % 2 == 0 else self.loss_history[1]
        self.loss_weights = self.softAdapt.get_component_weights(
            torch.tensor(hist_binary),
            torch.tensor(hist_clip),
            verbose=False
        )
if log_weights:
self.log("W_loss_img", self.loss_weights[0], on_epoch=True, on_step=False)
self.log("W_loss_CLIP", self.loss_weights[1], on_epoch=True, on_step=False)
self.loss_history = {
0: [],
1: [],
}
def _shared_step(self, batch, val_step:bool = False):
data, label = batch
label = label.float()
prediction, clip_logits = self(data)
prediction = prediction.flatten()
predicted_labels = torch.sigmoid(prediction)
loss_list = [
            self.lossFN_classifier(prediction, label),
self.lossFN_CLIP(clip_logits)
]
if not val_step:
if self.hparams.lambda_L1 is not None:
                loss_list[0] += self._calc_L1(self.classifier.parameters())
for i, l in enumerate(loss_list):
self.loss_history[i].append(l.item())
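        # Weighted sum of the task losses using the current SoftAdapt weights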
loss = 0
for l, w in zip(loss_list, self.loss_weights):
loss += w * l
return loss, label, predicted_labels, loss_list
def on_train_epoch_start(self):
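        # Recompute the SoftAdapt weights every second epoch, once a history window exists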
if self.current_epoch % 2 == 0 and self.current_epoch != 0:
self._calc_weights()
def training_step(self, batch, batch_idx):
loss, true_labels, pred_labels, losses = self._shared_step(batch, False)
self.log("train_loss", loss, on_epoch=True)
self.train_acc.update(pred_labels, true_labels)
self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
self.log("train_loss_img", losses[0], on_epoch=True, on_step=False)
self.log("train_loss_clip", losses[1], on_epoch=True, on_step=False)
return loss
def validation_step(self, batch, batch_idx):
loss, true_labels, pred_labels, losses = self._shared_step(batch, True)
self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
self.val_acc.update(pred_labels, true_labels)
self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
self.log("val_loss_img", losses[0], on_epoch=True, on_step=False)
self.log("val_loss_clip", losses[1], on_epoch=True, on_step=False)
def test_step(self, batch, batch_idx):
_, true_labels, predicted_labels, _ = self._shared_step(batch, True)
self.test_acc.update(predicted_labels, true_labels)
self.f1ScoreTest.update(predicted_labels, true_labels)
self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
return [optimizer]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
parser.add_argument("--Fold", type=int, default=0)
parser.add_argument("-lv", "--Log_version", type=int)
args = parser.parse_args()
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All_no20Crop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_AndoAll20_crop",
Fold=5,
removeBlackBar=False,
)
Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
Dataset.CreateFolds()
setup_seed(2023)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DumpCSV = pd.read_csv(Dataset.CVS_File)
Extradata = pd.read_csv(Path("./DataSet_Ando_All_no20Crop/data_source_prompt.csv"))
DumpCSV["raw_name"] = "Img0-" + DumpCSV["id"].astype(str) + ".npy"
DumpCSV = DumpCSV.drop(columns=["id", "path", "mean_1", "mean_2", "mean_3", "std_1", "std_2", "std_3"])
NewData = pd.merge(Extradata, DumpCSV, on="name")
Fold = args.Fold
InitResize=(512, 512)
batchSize = 16
numWorkers = 5
MaxEpochs = 100
LogFileName = "log_X22"
rootFile = Dataset.CVS_File.parent.parent/f"fold_{Fold}"
checkPtCallback = [
callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
]
traintransform = torch.nn.Sequential(
transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.RandomRotation(
degrees=(-15, 15),
fill=5
),
transforms.RandomResizedCrop(
224,
scale=(0.4, 1),
antialias=True,
interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.RandomApply([
transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
], 0.3),
transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)
valtransform = torch.nn.Sequential(
transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)
testtransform = torch.nn.Sequential(
transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)
bert_tokenizer = AutoTokenizer.from_pretrained(LLM)
trainDataset = RARP_DatasetFolder_LLM(
str (rootFile/"train"),
loader = defs.load_file_tensor,
Extra_Data = NewData,
tokenizer = bert_tokenizer,
extensions = "npy",
transform = traintransform
)
valDataset = RARP_DatasetFolder_LLM(
str (rootFile/"val"),
loader = defs.load_file_tensor,
Extra_Data = NewData,
tokenizer = bert_tokenizer,
extensions = "npy",
transform = valtransform
)
testDataset = RARP_DatasetFolder_LLM(
str (rootFile/"test"),
loader=defs.load_file_tensor,
Extra_Data = NewData,
tokenizer = bert_tokenizer,
extensions="npy",
transform=testtransform
)
Train_DataLoader = DataLoader(
trainDataset,
batch_size=batchSize,
num_workers=numWorkers,
shuffle=True,
drop_last=True,
pin_memory=True,
persistent_workers=numWorkers>0
)
Val_DataLoader = DataLoader(
valDataset,
batch_size=batchSize,
num_workers=numWorkers,
shuffle=False,
pin_memory=True,
persistent_workers=numWorkers>0
)
Test_DataLoader = DataLoader(
testDataset,
batch_size=batchSize,
num_workers=numWorkers,
shuffle=False,
pin_memory=True,
persistent_workers=numWorkers>0
)
    match args.Phase:
case "train":
            Model = RARP_VAN_BERT(LLM, "van_b2_teacher_90_D2.pth", classifier_layers=[])
trainer = L.Trainer(
deterministic=True,
accelerator='gpu',
devices=1,
logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
log_every_n_steps=5,
callbacks=checkPtCallback,
max_epochs=MaxEpochs,
)
print("Train Phase")
trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")
case "eval_all":
print("Evaluation Phase")
rows = []
pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
for ckpFile in pathCkptFile.glob("*.ckpt"):
print(ckpFile.name)
Model = RARP_VAN_BERT.load_from_checkpoint(ckpFile)
Model.to(device)
Model.eval()
Predictions = []
Labels = []
                with torch.no_grad():
                    for (img, text), label in tqdm(Test_DataLoader):
                        # Batches are ((image, token_dict), label); move every tensor to the device
                        img = img.float().to(device)
                        text = {k: v.to(device) for k, v in text.items()}
                        label = label.to(device)
                        pred, _ = Model((img, text))
                        pred = pred.flatten()
                        Predictions.append(torch.sigmoid(pred))
                        Labels.append(label)
Predictions = torch.cat(Predictions)
Labels = torch.cat(Labels)
                acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
                precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
                recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
                auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
                f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
                specificity = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
table = [
["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
]
                # Re-evaluate at two alternative operating points: the Youden-optimal
                # ROC threshold (i == 0) and a heuristic cut-off derived from Youden's J
                # at the default 0.5 threshold (i == 1).
                aucCurve = torchmetrics.ROC("binary").to(device)
                fpr, tpr, thresholds = aucCurve(Predictions, Labels)
                index = torch.argmax(tpr - fpr)
                for i in range(2):
                    th2 = (recall + specificity - 1).item()
                    th2 = 0.5 if th2 <= 0 else th2
                    th1 = thresholds[index].item() if i == 0 else th2
                    accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
                    precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
                    recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
                    specificityY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
                    f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
#cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
#cm2.update(Predictions, Labels)
#_, ax = cm2.plot()
#ax.set_title(f"NVB Classifier (th={th1:.4f})")
table.append([
f"{th1:.4f}",
f"{accY.item():.4f}",
f"{precisionY.item():.4f}",
f"{recallY.item():.4f}",
f"{f1ScoreY.item():.4f}",
f"{auc.item():.4f}",
f"{specifictyY.item():.4f}",
ckpFile.name
])
rows += table
            df = pd.DataFrame(rows, columns=["Threshold", "Acc", "Precision", "Recall", "F1", "AUROC", "Specificity", "CheckPoint"])
            print(df)