import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse
# Enable TF32 tensor-core math for matmuls on Ampere+ GPUs: faster at a small
# precision cost (matches the 'high' float32 matmul precision below).
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
# Force deterministic cuDNN kernel selection for reproducibility
# (re-asserted inside setup_seed as well).
torch.backends.cudnn.deterministic = True
def setup_seed(seed):
    """Seed all random number generators (NumPy, torch CPU/CUDA, Lightning).

    Also re-asserts deterministic cuDNN kernels so repeated runs with the
    same seed produce identical results.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Lightning's helper additionally seeds Python's `random` module and
    # propagates seeding to DataLoader worker processes (workers=True).
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
def Calc_Eval_table(
    TrainModel: M.RARP_NVB_Model,
    TestDataLoadre: DataLoader,
    Youden=False,
    modelName="",
):
    """Evaluate a binary classifier on a dataloader and build a metrics table.

    Runs `TrainModel` over every batch of `TestDataLoadre` (no gradients),
    collects sigmoid probabilities, and computes accuracy / precision /
    recall / F1 / AUROC / specificity at the default 0.5 threshold.

    When `Youden` is True, two extra rows are appended:
      - metrics at the ROC-optimal threshold (argmax of tpr - fpr), and
      - metrics at `recall + specificity - 1` (Youden's J at threshold 0.5),
        clamped to 0.5 when non-positive.

    Args:
        TrainModel: model whose forward returns the prediction first; the
            first output is assumed to be a logit (sigmoid is applied here).
        TestDataLoadre: dataloader yielding (data, label) batches.
        Youden: if True, also report metrics at the alternative thresholds.
        modelName: tag stored in the Youden rows (e.g. checkpoint name).

    Returns:
        list[list[str]]: rows of formatted metric strings, columns
        [threshold, acc, precision, recall, f1, auroc, specificity, name].
    """
    TrainModel.to(device)
    TrainModel.eval()
    Predictions = []
    Labels = []
    with torch.no_grad():
        for data, label in tqdm(TestDataLoadre):
            data = data.float().to(device)
            label = label.to(device)
            # Forward returns (prediction, *auxiliary outputs); keep the first.
            pred, *_ = TrainModel(data)
            Predictions.append(torch.sigmoid(pred.flatten()))
            Labels.append(label)
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels)
    acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
    ]
    if Youden:
        # The ROC curve and both candidate thresholds do not depend on the
        # loop index, so compute them once instead of once per iteration.
        fpr, tpr, thhols = torchmetrics.ROC("binary").to(device)(Predictions, Labels)
        index = torch.argmax(tpr - fpr)
        # Youden's J statistic evaluated at the default 0.5 threshold,
        # used as a second (fallback) candidate threshold.
        th2 = (recall + specificty - 1).item()
        th2 = 0.5 if th2 <= 0 else th2
        for i in range(2):
            th1 = thhols[index].item() if i == 0 else th2
            accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
            precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
            recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
            specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
            f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
            # AUROC is threshold-free, so the 0.5-row value is reused here.
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
    return table
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
parser.add_argument("--Fold", type=int, default=0)
parser.add_argument("-lv","--Log_version", type=int, default=None)
parser.add_argument("--Workers", type=int, default=0)
parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
parser.add_argument("--Head", type=int, default=None)
parser.add_argument("-me", "--maxEpochs", type=int, default=None)
parser.add_argument("-b", "--Batch_size", type=int, default=8)
parser.add_argument("--GPU", type=int, default=0)
args = parser.parse_args()
setup_seed(2023)
device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")
df = pd.read_csv("./Dataset_RARP_video/dataset_videos_frames_folds.csv")
df = df.loc[df["type"] == "f"]
FOLD = args.Fold
WORKERS = args.Workers
BATCH_SIZE = args.Batch_size
MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
Mean = [30.38144216, 42.03988769, 97.8896116]
Std = [40.63141752, 44.26910074, 50.29294373]
print(f"Fold_{FOLD}")
train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
valtransform = torch.nn.Sequential(
transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize(Mean, Std)
).to(device)
testtransform = torch.nn.Sequential(
transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize(Mean, Std)
).to(device)
TrainDINOTransforms = Loaders.RARP_DINO_Augmentation(
GloblaCropsScale = (0.4, 1),
LocalCropsScale = (0.05, 0.4),
NumLocalCrops = 4,
Size = 224,
device = device,
mean = Mean,
std = Std,
Tranform_0 = testtransform
)
train_dataset = Loaders.RARP_Video_Frames_Dataset(train_set, TrainDINOTransforms, True)
val_dataset = Loaders.RARP_Video_Frames_Dataset(val_set, valtransform, True)
test_dataset = Loaders.RARP_Video_Frames_Dataset(test_set, testtransform, True)
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
val_loader = DataLoader(
val_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
test_loader = DataLoader(
test_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
Model = M.RARP_NVB_DINO_MultiTask_A5_MAE(
M.TypeLossFunction.BCEWithLogits,
std=Std,
mean=Mean,
L1= 1.31E-04,
L2= 0,
lr= 1e-4,
SoftAdptAlgo=0
)
print(f"Model Used: {type(Model).__name__}")
LogFileName = f"{args.Log_Name}"
checkPtCallback = [
callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
]
match(args.Phase):
case "train":
trainer = L.Trainer(
deterministic=True,
accelerator="gpu",
devices=[args.GPU],
logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
log_every_n_steps=5,
callbacks=checkPtCallback,
max_epochs=MAX_EPOCHS
)
print("Train Phase")
trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
case "eval_all":
print("Evaluation Phase")
rows = []
pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
for ckpFile in pathCkptFile.glob("*.ckpt"):
print(ckpFile.name)
Model = M.RARP_NVB_DINO_MultiTask_A5_MAE.load_from_checkpoint(ckpFile)
temp = Calc_Eval_table(
Model,
test_loader,
True,
ckpFile.name
)
rows += temp
df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])
df.style.highlight_max(color="red", axis=0)
print(df)