import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import torchmetrics
import lightning as L
from lightning.pytorch import seed_everything
from lightning.pytorch.tuner import Tuner
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import Loaders
import defs
import argparse
import seaborn as sn
import Models as M
import pandas as pd
import warnings
from ultralytics import YOLO
import yaml
import optuna
from optuna.integration import PyTorchLightningPruningCallback
# Global PyTorch configuration, applied once when this module is imported.
# Ask cuDNN for deterministic algorithms so repeated runs are reproducible.
torch.backends.cudnn.deterministic = True
# Permit reduced-internal-precision (e.g. TF32) kernels for float32 matmuls
# where the hardware supports them, trading a little precision for speed.
torch.set_float32_matmul_precision('high')
def objective(trail: optuna.trial.Trial) -> float:
    """Optuna objective: train one model with sampled hyper-parameters and score it.

    Samples ``lr``, ``L1``, a mixing weight ``Alpha`` (``Beta = 1 - Alpha``)
    and ``Thao`` (sampled under the name "Thao_KD"; presumably a distillation
    temperature). It then trains ``args.Model`` with a fresh Lightning
    ``Trainer``, which is pruned on ``val_acc`` and checkpoints the best
    ``val_acc`` epochs. Finally it evaluates the best checkpoint on the test
    loader.

    Relies on module-level globals: ``LogFileName``, ``MaxEpochs``, ``args``,
    ``InitWeight``, ``TypeLoss``, ``Train_DataLoader``, ``Val_DataLoader`` and
    ``Test_DataLoader``.

    Returns:
        The ``test_acc`` metric recorded by ``Trainer.test``.
    """
    # Keep the suggestion order fixed so the study's parameter sequence is unchanged.
    learning_rate = trail.suggest_float("lr", 1e-4, 1e-3, log=True)
    l1_weight = trail.suggest_float("L1", 1e-6, 1e-3, log=True)
    alpha = trail.suggest_float("W_Apha", 0, 1, step=0.05)
    temperature = trail.suggest_float("Thao_KD", 1, 7, step=0.25)

    tune_logger = TensorBoardLogger(save_dir=f"./{LogFileName}", name="Tune")
    trial_callbacks = [
        PyTorchLightningPruningCallback(trail, monitor="val_acc"),
        callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=2, mode='max'),
    ]
    trainer = L.Trainer(
        logger=tune_logger,
        max_epochs=MaxEpochs,
        accelerator="auto",
        log_every_n_steps=5,
        devices=1,
        callbacks=trial_callbacks,
    )

    hyperparameters = {
        "lr": learning_rate,
        "L1": l1_weight,
        "Alpha": alpha,
        "Beta": 1 - alpha,
        "Thao": temperature,
    }
    model, _ = getModel(args.Model, InitWeight, TypeLoss, OptConfig=hyperparameters)

    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
    # NOTE(review): the objective value is the *test* accuracy, so the test split
    # takes part in hyper-parameter selection; consider optimising val_acc instead.
    trainer.test(model, dataloaders=Test_DataLoader, ckpt_path="best")
    return trainer.callback_metrics["test_acc"].item()
def Calc_Eval_table_New(TrainModel: M.RARP_NVB_Model):
    """Evaluate ``TrainModel`` on the module-level ``Test_DataLoader``.

    ``M.RARP_NVB_Model_test2`` models are scored as a 2-class multiclass
    problem on softmax outputs. Every other model is scored as a binary problem
    on the sigmoid of its output squeezed along dim 1.

    Relies on the module-level ``device`` and ``Test_DataLoader``.

    Returns:
        ``[accuracy, precision, recall, f1, auroc]`` as Python floats.
    """
    TrainModel.to(device)
    TrainModel.eval()
    multiclass = isinstance(TrainModel, M.RARP_NVB_Model_test2)

    Predictions = []
    Labels = []
    with torch.no_grad():
        # Iterate the loader itself (not ``iter(...)``) so tqdm can show its length.
        for img, label in tqdm(Test_DataLoader):
            img = img.float().to(device)
            # torchmetrics expects integer targets. The binary path used to cast
            # labels to float, which the binary AUROC input validation rejects.
            label = label.long().to(device)
            pred = TrainModel(img)
            if multiclass:
                Predictions.append(torch.softmax(pred, dim=1))
            else:
                Predictions.append(torch.sigmoid(pred.squeeze(1)))
            Labels.append(label)
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels)
    print(Predictions, Labels)

    # Same metric set for both paths; only the task configuration differs.
    task = {"task": "multiclass", "num_classes": 2} if multiclass else {"task": "binary"}
    acc = torchmetrics.Accuracy(**task).to(device)(Predictions, Labels)
    precision = torchmetrics.Precision(**task).to(device)(Predictions, Labels)
    recall = torchmetrics.Recall(**task).to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC(**task).to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score(**task).to(device)(Predictions, Labels)
    return [acc.item(), precision.item(), recall.item(), f1Score.item(), auc.item()]
def Calc_EvalMulticlass_table(TrainModel: M.RARP_NVB_Model, TestDataLoadre: DataLoader, Youden=False, modelName="", NumClasses: int = 2):
    """Build a multiclass metrics table for ``TrainModel`` and plot its confusion matrix.

    Args:
        TrainModel: model returning class logits. ``M.RARP_NVB_DINO_MultiTask``
            returns a 3-tuple whose first element is used, and always forces
            ``NumClasses = 4``.
        TestDataLoadre: loader yielding ``(data, label)`` batches.
        Youden: also append two rows scored with binary metrics on the class-1
            probability: one at the ROC threshold maximising ``tpr - fpr``, and
            one at ``recall + specificity - 1`` (or 0.5 if that is <= 0). This
            is only defined for ``NumClasses == 2``; otherwise a warning is
            emitted and the rows are skipped.
        modelName: used in the plot title and in the table's last column.
        NumClasses: number of classes passed to the torchmetrics metrics.

    Relies on the module-level ``device``.

    Returns:
        Rows of formatted strings ``[threshold, acc, precision, recall, f1,
        auroc, specificity, name]``. The first row (labelled "0.5000") holds
        the plain multiclass scores.
    """
    TrainModel.to(device)
    TrainModel.eval()
    multitask = isinstance(TrainModel, M.RARP_NVB_DINO_MultiTask)
    if multitask:
        NumClasses = 4

    Predictions = []
    Labels = []
    with torch.no_grad():
        for data, label in tqdm(TestDataLoadre):
            data = data.float().to(device)
            label = label.to(device)
            if multitask:
                pred, _, _ = TrainModel(data)
            else:
                pred = TrainModel(data)
            Predictions.append(torch.softmax(pred, dim=1))
            Labels.append(label)
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels)
    print(Predictions, Labels)

    def _score(metric_cls):
        # One multiclass metric evaluated on the collected predictions.
        return metric_cls("multiclass", num_classes=NumClasses).to(device)(Predictions, Labels)

    acc = _score(torchmetrics.Accuracy)
    precision = _score(torchmetrics.Precision)
    recall = _score(torchmetrics.Recall)
    auc = _score(torchmetrics.AUROC)
    f1Score = _score(torchmetrics.F1Score)
    specificty = _score(torchmetrics.Specificity)
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", f"*{modelName}*"]
    ]

    cm2 = torchmetrics.ConfusionMatrix("multiclass", num_classes=NumClasses).to(device)
    cm2.update(Predictions, Labels)
    _, ax = cm2.plot()
    ax.set_title(f"NVB Classifier {modelName}")

    if not Youden:
        return table
    if NumClasses != 2:
        # Multiclass ROC returns per-class lists and multiclass metrics ignore
        # ``threshold``, so a single decision threshold is not defined here.
        warnings.warn(
            f"Youden threshold rows require NumClasses == 2 (got {NumClasses}); skipping them."
        )
        return table

    # Reduce the 2-class softmax to a binary score: the probability of class 1.
    positive = Predictions[:, 1]
    fpr, tpr, thhols = torchmetrics.ROC("binary").to(device)(positive, Labels)
    youden_th = thhols[torch.argmax(tpr - fpr)].item()
    fallback_th = (recall + specificty - 1).item()
    fallback_th = 0.5 if fallback_th <= 0 else fallback_th
    for th1 in (youden_th, fallback_th):
        accY = torchmetrics.Accuracy("binary", threshold=th1).to(device)(positive, Labels)
        precisionY = torchmetrics.Precision("binary", threshold=th1).to(device)(positive, Labels)
        recallY = torchmetrics.Recall("binary", threshold=th1).to(device)(positive, Labels)
        specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(positive, Labels)
        f1ScoreY = torchmetrics.F1Score("binary", threshold=th1).to(device)(positive, Labels)
        table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
    return table
def Calc_Eval_table(
    TrainModel: M.RARP_NVB_Model,
    TestDataLoadre: DataLoader,
    Youden=False, modelName="",
    Add_TestDataset: DataLoader = None,
    extraData: bool = False,
    PseudoLabel: bool = True,
    dataSetInfo: Loaders.RARP_DatasetCreator = None
):
    """Build a binary metrics table for ``TrainModel``.

    Args:
        TrainModel: model to evaluate. Several project model types return
            tuples and are unpacked specially (see the dispatch below).
        TestDataLoadre: loader yielding ``(data, label)`` batches.
        Youden: also append two rows scored at alternative thresholds: the ROC
            threshold maximising ``tpr - fpr``, and ``recall + specificity - 1``
            (or 0.5 if that is <= 0).
        modelName: written in the last column of the Youden rows.
        Add_TestDataset: optional extra loader of single-tensor batches whose
            predictions are appended. Only models returning a plain tensor are
            supported for it.
        extraData: batches carry ``(image, extra)`` pairs.
        PseudoLabel: for models that emit pseudo-labels, score against those
            instead of the loader labels.
        dataSetInfo: currently unused.

    Relies on the module-level ``device``.

    Returns:
        Rows of formatted strings ``[threshold, acc, precision, recall, f1,
        auroc, specificity, name]``. The first row uses the default 0.5 threshold.
    """
    TrainModel.to(device)
    TrainModel.eval()
    Predictions = []
    Labels = []
    # Decide the batch layout from the loader being evaluated. The previous code
    # checked the module-level ``testDataset`` global, which is wrong for any
    # other loader and raises NameError when that global is undefined.
    ban_ExtraImage = isinstance(
        getattr(TestDataLoadre, "dataset", None), Loaders.RARP_DatasetFolder_DobleTransform
    )
    with torch.no_grad():
        for data, label in tqdm(TestDataLoadre):
            if extraData or ban_ExtraImage:
                # (image, extra) pairs or 2/3 transformed views: move every part.
                data = tuple(part.float().to(device) for part in data)
            else:
                data = data.float().to(device)
            label = label.to(device)

            if isinstance(TrainModel, M.RARP_NVB_ResNet50_VAN):
                pred, Plabel, _ = TrainModel(data)
                pred = pred.flatten()
                label = Plabel.int() if PseudoLabel else label
            elif isinstance(TrainModel, (M.RARP_NVB_RN50_VAN_V2, M.RARP_NVB_DINO_MultiTask)):
                pred, _features, _new_img = TrainModel(data)
                pred = pred.flatten()
            elif isinstance(TrainModel, M.RARP_NVB_DINO_RestNet50_Deep):
                (pred, Plabel, _), _ = TrainModel(data)
                label = Plabel.int() if PseudoLabel else label
            else:
                pred = TrainModel(data).flatten()
            Predictions.append(torch.sigmoid(pred))
            Labels.append(label)

        if Add_TestDataset is not None:
            for data, label in tqdm(Add_TestDataset):
                data = data.float().to(device)
                label = label.to(device)
                pred = TrainModel(data).flatten()
                Predictions.append(torch.sigmoid(pred))
                Labels.append(label)

    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels)
    print(Predictions, Labels)
    acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
    ]

    if Youden:
        # The ROC curve does not depend on the threshold, so compute it once.
        fpr, tpr, thhols = torchmetrics.ROC("binary").to(device)(Predictions, Labels)
        youden_th = thhols[torch.argmax(tpr - fpr)].item()
        fallback_th = (recall + specificty - 1).item()
        fallback_th = 0.5 if fallback_th <= 0 else fallback_th
        for th1 in (youden_th, fallback_th):
            accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
            precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
            recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
            specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
            f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
    return table
def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and Lightning with ``seed``.

    Also asks cuDNN for deterministic algorithms so runs are reproducible.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)
    # workers=True also derives distinct, reproducible seeds for DataLoader workers.
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
def CAM(model: M.RARP_NVB_Model, img: torch.Tensor, device):
    """Compute a class activation map (CAM) for a single image.

    Runs ``model`` on ``img`` (no batch dimension; one is added here). It then
    projects the returned feature map onto the first parameter of the model's
    classification head (``fc``, ``classifier`` or ``head`` depending on the
    variant) and min-max normalises the result.

    Args:
        model: a supported CAM model variant returning ``(prediction, feature_map)``.
        img: image tensor without a batch dimension.
        device: device to run the model on.

    Returns:
        ``(cam_img, prob)``. ``cam_img`` is the ``(h, w)`` map on CPU and
        ``prob`` is ``sigmoid(prediction)``. The reshape to ``(h, w)`` assumes
        the head has a single output unit.

    Raises:
        NotImplementedError: if ``model`` is not a supported CAM variant.
    """
    with torch.no_grad():
        img = img.to(device).float().unsqueeze(0)
        if isinstance(model, M.RARP_NVB_VAN_CAM):
            # This variant takes a second tensor argument; a constant 0.0 is supplied.
            pred, feature = model(img, torch.tensor([0.0]).to(device))
        else:
            pred, feature = model(img)
        _, c, h, w = feature.shape
        feature = feature.reshape((c, h * w))

        if isinstance(model, (M.RARP_NVB_ResNet18_CAM, M.RARP_NVB_ResNet50_CAM)):
            wParams = list(model.model.fc.parameters())
        elif isinstance(model, (M.RARP_NVB_MobileNetV2_CAM, M.RARP_NVB_EfficientNetV2_CAM)):
            wParams = list(model.model.classifier.parameters())
        elif isinstance(model, M.RARP_NVB_VAN_CAM):
            wParams = list(model.model.head.parameters())
        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise NotImplementedError("Cam Not Implemented")

        pesos = wParams[0].detach()
        cam = torch.matmul(pesos, feature)
        cam = cam - torch.min(cam)
        peak = torch.max(cam)
        # A constant map would otherwise compute 0 / 0 and yield NaNs.
        cam_img = cam / peak if peak > 0 else cam
        cam_img = cam_img.reshape(h, w).cpu()
    return cam_img, torch.sigmoid(pred)
def CAMVisualizer(img, heatmap, pred, label, mean, std, ax, row):
    """Draw one image and its CAM overlay into a slot of a 4-row axes grid.

    ``img`` is a CHW tensor, converted to HWC, de-normalised with
    ``std``/``mean``, scaled from 0-255 to 0-1 and given reversed channel
    order. ``heatmap`` is resized to the image size. Slot ``row`` maps to grid
    row ``row % 4`` and a column pair: 0/1 for rows 0-3, 2/3 for rows 4-7 and
    4/5 for rows 8 and up. The raw image goes in the left column and the
    image plus heatmap in the right.
    """
    picture = img.numpy().transpose((1, 2, 0))
    height, width = picture.shape[0], picture.shape[1]
    heatmap = transforms.functional.resize(heatmap.unsqueeze(0), (height, width), antialias=True)[0]
    picture = np.clip((std * picture + mean) / 255, 0, 1)
    picture = picture[..., ::-1].copy()

    column = 4 if row > 7 else (2 if row > 3 else 0)
    raw_ax = ax[row % 4][column]
    overlay_ax = ax[row % 4][column + 1]

    raw_ax.imshow(picture)
    raw_ax.axis('off')
    overlay_ax.imshow(picture)
    overlay_ax.imshow(heatmap, alpha=0.5, cmap="jet")
    overlay_ax.axis('off')
    overlay_ax.set_title(f"Pred.: {pred.item():.4f}; Label: {label}")
def ShowCAM(TrainedModel: M.RARP_NVB_Model, DataSet, mean, std, title=""):
    """Plot CAM overlays for samples of ``DataSet`` on a 4x6 grid.

    For datasets with more than 12 samples, shows the first 6 indices of the
    first class block and the first 6 of the second. This assumes
    ``DataSet.targets`` is ordered by class (the index ranges are built from
    the class counts). Smaller datasets are shown in full. Relies on the
    module-level ``device``.
    """
    TrainedModel.to(device)
    TrainedModel.eval()
    grid_spacing = dict(left=0, bottom=0.01, right=1, top=0.914, wspace=0, hspace=0.164)
    fig, axis = plt.subplots(4, 6, gridspec_kw=grid_spacing)

    def _draw(slot, sample):
        # Compute the CAM for one (img, label) sample and draw it at ``slot``.
        picture, target = sample
        cam, pred = CAM(TrainedModel, picture, device)
        CAMVisualizer(picture, cam, pred, target, mean, std, axis, slot)

    with torch.no_grad():
        if len(DataSet) > 12:
            counts = np.unique(DataSet.targets, return_counts=True)[1]
            first_class = np.arange(counts[0])[:6]
            second_class = np.arange(counts[0], counts[0] + counts[1])[:6]
            chosen = np.concatenate((first_class, second_class))
            for slot, index in enumerate(chosen):
                _draw(slot, DataSet[index])
        else:
            for slot, sample in enumerate(tqdm(DataSet)):
                _draw(slot, sample)
    fig.suptitle(title)
def Calc_Eval(TrainModel: M.RARP_NVB_Model):
    """Evaluate ``TrainModel`` sample by sample on the module-level ``testDataset``.

    Prints binary accuracy, precision, recall, F1 and AUROC, then shows a
    seaborn confusion-matrix heatmap titled with ``args.Fold + 1``. Relies on
    the module-level ``device``, ``testDataset`` and ``args``.
    """
    TrainModel.to(device)
    TrainModel.eval()
    scores, targets = [], []
    with torch.no_grad():
        for sample, target in tqdm(testDataset):
            batch = sample.to(device).float().unsqueeze(0)
            # First (only) row of the batch output, moved to CPU, as a probability.
            scores.append(torch.sigmoid(TrainModel(batch)[0].cpu()))
            targets.append(target)
    Predictions = torch.cat(scores)
    Labels = torch.tensor(targets).int()
    print(Predictions, Labels)

    acc = torchmetrics.Accuracy('binary')(Predictions, Labels)
    precision = torchmetrics.Precision('binary')(Predictions, Labels)
    recall = torchmetrics.Recall('binary')(Predictions, Labels)
    cm = torchmetrics.ConfusionMatrix('binary')(Predictions, Labels)
    auc = torchmetrics.AUROC('binary')(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary')(Predictions, Labels)

    print(f"Val Accuracy: {acc:.4f}")
    print(f"Val Precision: {precision:.4f}")
    print(f"Val Recall: {recall:.4f}")
    print(f"F1 Score: {f1Score:.4f}")
    print(f"AUROC: {auc:.4f}")

    class_names = testDataset.classes
    print(class_names)
    ax = sn.heatmap(
        cm, cmap="Greens", cbar=False, annot=True, annot_kws={"size": 18}, fmt='g',
        xticklabels=class_names, yticklabels=class_names,
    )
    ax.set_title(f"NVB Classifier Split #{args.Fold+1}")
    ax.set_xlabel('Predict')
    ax.set_ylabel('True')
    plt.show()
def getModel (
Model_ID:int = 0,
InitWeight=torch.tensor([1,1]),
TypeLoss:M.TypeLossFunction = M.TypeLossFunction.CrossEntropy,
Ckpt_File:Path = None,
OptConfig:dict = {},
inputNeurons:int = 4,
mean:float = None,
std:float = None
):
Model = None
ModelCAM = None
match Model_ID:
case 0:
Model = M.RARP_NVB_ResNet50(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_ResNet50.load_from_checkpoint(ckpFile)
ModelCAM = None if Ckpt_File is None else M.RARP_NVB_ResNet50_CAM.load_from_checkpoint(ckpFile, strict=False)
case 1:
Model = M.RARP_NVB_ResNet18(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_ResNet18.load_from_checkpoint(ckpFile)
ModelCAM = None if Ckpt_File is None else M.RARP_NVB_ResNet18_CAM.load_from_checkpoint(ckpFile, strict=False)
case 2:
Model = M.RARP_NVB_MobileNetV2(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_MobileNetV2.load_from_checkpoint(ckpFile)
ModelCAM = None if Ckpt_File is None else M.RARP_NVB_MobileNetV2_CAM.load_from_checkpoint(ckpFile, strict=False)
case 3:
Model = M.RARP_NVB_EfficientNetV2(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_EfficientNetV2.load_from_checkpoint(ckpFile)
ModelCAM = None
case 4:
Model = M.RARP_NVB_Vit_b_16(InitWeight, TypeLoss) if Ckpt_File is None else M.RARP_NVB_Vit_b_16.load_from_checkpoint(ckpFile)
ModelCAM = None
case 5:
Model = M.RARP_NVB_DenseNet169(InitWeight, TypeLoss) if Ckpt_File is None else M.RARP_NVB_DenseNet169.load_from_checkpoint(ckpFile)
ModelCAM = None
case 6:
Model = M.RARP_NVB_ResNet50_V1(InitWeight, TypeLoss, schedulerLR=args.DyLr, InputNeurons=inputNeurons) if Ckpt_File is None else M.RARP_NVB_ResNet50_V1.load_from_checkpoint(ckpFile)
ModelCAM = None
case 7:
with open(f"train-EFold{args.Fold}.yaml") as file:
configFile = yaml.load(file, Loader=yaml.FullLoader)
Models = []
for models in configFile["models"]:
match models:
case "ResNet50":
if configFile["models"][models] is not None:
for pathckpt in configFile["models"][models]:
#Models.append(M.RARP_NVB_ResNet50.load_from_checkpoint(Path(pathckpt), strict=False))
Models.append(M.RARP_NVB_ResNet50(InitWeight, TypeLoss))
case "ResNet18":
if configFile["models"][models] is not None:
for pathckpt in configFile["models"][models]:
#Models.append(M.RARP_NVB_ResNet18.load_from_checkpoint(Path(pathckpt), strict=False))
Models.append(M.RARP_NVB_ResNet18(InitWeight, TypeLoss))
case "MobileNetV2":
if configFile["models"][models] is not None:
for pathckpt in configFile["models"][models]:
#Models.append(M.RARP_NVB_MobileNetV2.load_from_checkpoint(Path(pathckpt), strict=False))
Models.append(M.RARP_NVB_MobileNetV2(InitWeight, TypeLoss))
case "EfficientNetV2":
if configFile["models"][models] is not None:
for pathckpt in configFile["models"][models]:
#Models.append(M.RARP_NVB_EfficientNetV2.load_from_checkpoint(Path(pathckpt), strict=False))
Models.append(M.RARP_NVB_EfficientNetV2(InitWeight, TypeLoss))
case "DenseNet169":
if configFile["models"][models] is not None:
for pathckpt in configFile["models"][models]:
#Models.append(M.RARP_NVB_DenseNet169.load_from_checkpoint(Path(pathckpt), strict=False))
Models.append(M.RARP_NVB_DenseNet169(InitWeight, TypeLoss))
case _:
pass
print (f"{len(Models)} models Loaded")
Model = M.RARP_Ensemble(Models, InitWeight, TypeLoss, lr=1e-3)
ModelCAM = None
case 8:
Model = M.RARP_NVB_DaVit(InitWeight, TypeLoss) if Ckpt_File is None else M.RARP_NVB_DaVit.load_from_checkpoint(ckpFile)
ModelCAM = None
case 9:
Model = M.RARP_NVB_ResNet50_Deep(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_ResNet50_Deep.load_from_checkpoint(ckpFile)
ModelCAM = None
case 10:
Model = M.RARP_NVB_EfficientNetV2_Deep(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_EfficientNetV2_Deep.load_from_checkpoint(ckpFile)
ModelCAM = None
case 11:
Model = M.RARP_NVB_ResNet50_V2(InitWeight, TypeLoss, schedulerLR=args.DyLr, InputNeurons=inputNeurons) if Ckpt_File is None else M.RARP_NVB_ResNet50_V2.load_from_checkpoint(ckpFile)
ModelCAM = None
case 12:
Model = M.RARP_NVB_ResNet50_V3(InitWeight, TypeLoss, schedulerLR=args.DyLr, InputNeurons=inputNeurons) if Ckpt_File is None else M.RARP_NVB_ResNet50_V3.load_from_checkpoint(ckpFile)
ModelCAM = None
case 13:
Model = M.RARP_NVB_VAN(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_VAN.load_from_checkpoint(ckpFile)
ModelCAM = None if Ckpt_File is None else M.RARP_NVB_VAN_CAM.load_from_checkpoint(ckpFile, strict=False)
case 14:
TestModel = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
TestModel.fc = torch.nn.Linear(in_features=TestModel.fc.in_features, out_features=4)
Model = M.RARP_NVB_MultiClassModel(
None,
Model=TestModel,
Num_Classes=4,
) if Ckpt_File is None else M.RARP_NVB_MultiClassModel.load_from_checkpoint(ckpFile,Model=TestModel)
ModelCAM = None
case 15:
#if OptConfig.get("lr") is None:
# OptConfig = dict(
# lr = None, #0.00015278, #1.53E-4,
# L1 = None, #0.0000020505, #2.05E-6,
# Alpha = 0.45,
# Gamma = 0.5,
# Thao = 5#2
# )
Model = M.RARP_NVB_ResNet50_VAN(
"./log_ResNet50_X10/lightning_logs/version_8/checkpoints/RARP-epoch=5.ckpt",
#"./log_X10/lightning_logs/version_0/checkpoints/RARP-epoch=39.ckpt",
0.5,
InitWeight,
TypeLoss,
schedulerLR=args.DyLr,
PseudoLables=False,
HParameter=OptConfig
) if Ckpt_File is None else M.RARP_NVB_ResNet50_VAN.load_from_checkpoint(ckpFile)
ModelCAM = None
case 16:
Model = M.RARP_NVB_SSL_RestNet50_Deep("./log_ResNet50_Deep_X10/lightning_logs/version_3/checkpoints/RARP-epoch=9.ckpt", 0.5, InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_SSL_RestNet50_Deep.load_from_checkpoint(ckpFile)
#Model = M.RARP_NVB_SSL_RestNet50_Deep("./log_ResNet50DeepSSL_X10/lightning_logs/version_8/checkpoints/RARP-epoch=42.ckpt", 0.5, InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_SSL_RestNet50_Deep.load_from_checkpoint(ckpFile)
ModelCAM = None
case 17:
Model = M.RARP_NVB_DINO_RestNet50_Deep(
"./log_ResNet50_Deep_X10/lightning_logs/version_3/checkpoints/RARP-epoch=9.ckpt",
threshold=0.5,
TypeLoss=TypeLoss,
L1=1.31E-04,
) if Ckpt_File is None else M.RARP_NVB_DINO_RestNet50_Deep.load_from_checkpoint(ckpFile)
ModelCAM = None
case 18:
Model = M.RARP_NVB_DINO_VAN(
"./log_ResNet50_Deep_X10/lightning_logs/version_3/checkpoints/RARP-epoch=9.ckpt",
threshold=0.5,
TypeLoss=TypeLoss,
#L1=1.31E-04
) if Ckpt_File is None else M.RARP_NVB_DINO_VAN.load_from_checkpoint(ckpFile)
ModelCAM = None
case 19:
Model = M.RARP_NVB_RN50_VAN_V2(
#"./log_X10/lightning_logs/version_0/checkpoints/RARP-epoch=39.ckpt",
"./log_ResNet50_X10/lightning_logs/version_8/checkpoints/RARP-epoch=5.ckpt",
0.5,
InitWeight,
TypeLoss,
schedulerLR=args.DyLr,
PseudoLables=False,
HParameter=OptConfig, std=std, mean=mean
) if Ckpt_File is None else M.RARP_NVB_RN50_VAN_V2.load_from_checkpoint(ckpFile)
ModelCAM = None
case 20:
Model = M.RARP_NVB_DINO_MultiTask(
TypeLoss,
std=std,
mean=mean,
L1= 1.31E-04,
L2= 0
) if Ckpt_File is None else M.RARP_NVB_DINO_MultiTask.load_from_checkpoint(ckpFile)
ModelCAM = None
case _:
raise Exception("Model Not yet Implemented")
return (Model, ModelCAM)
def ViewImg(dataset, std, mean):
    """Show 4 random samples of a two-view dataset in a 2x2 grid.

    Each ``dataset[i]`` must unpack as ``((view_a, view_b), label)``. Only the
    first view is shown. It is converted from CHW to HWC, de-normalised with
    ``std``/``mean``, scaled from 0-255 to 0-1 and given reversed channel
    order for display.
    """
    _, axis = plt.subplots(2, 2, figsize=(9, 9))
    for i in range(2):
        for j in range(2):
            # Sample across the whole dataset. The old hard-coded bound of 44 raised
            # IndexError on smaller datasets and never showed samples past index 43.
            random_index = np.random.randint(0, len(dataset))
            img, label = dataset[random_index]
            img, _ = img
            img = img.numpy().transpose((1, 2, 0))
            img = np.clip((std * img + mean) / 255, 0, 1)
            img = img[..., ::-1].copy()
            axis[i][j].imshow(img)
            axis[i][j].set_title(f"Label:{label}")
def ViewImgDINO(dataset, std, mean):
    """Show every augmented crop of 4 random samples, one sample per grid row.

    Each ``dataset[i]`` unpacks as ``(crops, label)``. Crop ``j`` of a sample
    is drawn in column ``j`` of a 4x7 grid, so at most 7 crops fit per row.
    Crops are converted from CHW to HWC, de-normalised with ``std``/``mean``,
    scaled from 0-255 to 0-1 and given reversed channel order.
    """
    _, axis = plt.subplots(4, 7, figsize=(25, 25))
    for row_axes in axis:
        pick = np.random.randint(0, len(dataset.targets))
        crops, label = dataset[pick]
        for j, crop in enumerate(crops):
            picture = crop.numpy().transpose((1, 2, 0))
            picture = np.clip((std * picture + mean) / 255, 0, 1)[..., ::-1].copy()
            cell = row_axes[j]
            cell.imshow(picture)
            cell.set_title(f"Label:{label}")
            cell.set_axis_off()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
parser.add_argument("--Fold", type=int, default=0)
parser.add_argument("--Workers", type=int, default=0)
parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
parser.add_argument("--Model", type=int, default=0, help="0 = ResNet18, 1 = ResNet50")
parser.add_argument("-lv", "--Log_version", type=int)
parser.add_argument("-le", "--Log_epoch", type=int)
parser.add_argument("-ls", "--Log_step", type=int)
parser.add_argument("--Remove_Blackbar", type=bool, default=True)
parser.add_argument("--BGR2RGB", type=bool, default=False)
parser.add_argument("--CAM", type=bool, default=False)
parser.add_argument("-roi", "--Use_ROI_Dataset", type=int, default=0)
parser.add_argument("-s", "--imgSlice_pct", type=float, default=None)
parser.add_argument("-ns", "--Num_Slices", type=int, default=4)
parser.add_argument("-wl", "--Wloss",type=bool, default=False)
parser.add_argument("--sClass",type=int, default=None)
parser.add_argument("-tl", "--TypeLoss", type=int, default=0)
parser.add_argument("-cs", "--ColorSpace", type=int, default=None)
parser.add_argument("--JIndex", type=bool, default=False)
parser.add_argument("-me", "--maxEpochs", type=int, default=None)
parser.add_argument("-lc", "--LoadChkpt", type=str, default=None)
parser.add_argument("--AddTestSet", type=str, default=None)
parser.add_argument("--Metadata", type=str, default=None)
parser.add_argument("--DyLr", type=bool, default=False)
parser.add_argument("-lr", type=float, default=1e-4)
parser.add_argument("--ExtraNeurons", type=int, default=4)
args = parser.parse_args()
if args.CAM and args.Phase == "train":
raise Exception("Clases Activation Clases only in eval o eval_all")
match args.Use_ROI_Dataset:
case 0:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_main",
FoldSeed=505,
createFile=True,
SavePath="./DataSet",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 1:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Crop",
FoldSeed=505,
createFile=True,
SavePath="./DataSetCrop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 2:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Crop1",
FoldSeed=505,
createFile=True,
SavePath="./DataSetCrop1",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 3:
YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_main",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_YOLO",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace,
ROI_Yolo=YoloModel
)
cropSize = 256
case 4:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_big",
FoldSeed=505,
createFile=True,
SavePath="./DatasetBig",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 6:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Full",
FoldSeed=505,
createFile=True,
SavePath="./DatasetFull",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 7:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_smallBalaced",
FoldSeed=505,
createFile=True,
SavePath="./DatasetSmallBalanced",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 8:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_big_2",
FoldSeed=505,
createFile=True,
SavePath="./DatasetBig2",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 9:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando",
FoldSeed=505,
createFile=True,
SavePath="./DataSetAndo",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 10:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_smallBalacedCrop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_SB_Crop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 11:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_AndoCrop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_SB_Ando_Crop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 12:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All_Crop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Ando_Crop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 13:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Ando_All",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 14:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_AllNewLabels",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_New_labels",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 15:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_ChageLabels",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_C_L",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 16:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_ChageLabels_crop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Ando_Crop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 17:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All_no20",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Ando_All_20",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 5:
YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_big",
FoldSeed=505,
createFile=True,
SavePath="./DatasetBig_YOLO",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace,
ROI_Yolo=YoloModel
)
cropSize = 256
case 20:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_big_Multiclass",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Multiclass",
Fold=5,
Num_Labels=4,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
Dataset.CreateFolds()
Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
setup_seed(2023)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batchSize = 8 #17 #8, 32
numWorkers = args.Workers
InitResize = (256,256)
ImgResize = (224, 224)
checkPtCallback = callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max')
ckpLossBest = callbk.ModelCheckpoint(monitor="val_loss", filename="RARP-{epoch}-{val_loss:.2f}", save_top_k=2, mode='min')
# --- Image transforms --------------------------------------------------------
# Each pipeline is an nn.Sequential of tensor transforms moved to `device`, and
# each ends with Normalize using the fixed Dataset.mean / Dataset.std.

# Default training augmentation: resize, centre crop, mild affine jitter, flip.
traintransform = torch.nn.Sequential(
    transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),  # changed here on 2024/05/10
    #transforms.RandomCrop(ImgResize),
    transforms.RandomAffine(
        degrees=(-5, 5), scale=(0.9, 1.1),
        fill=5  # fill value is applied before Normalize, i.e. in raw pixel units
    ),
    # p=0.5 (torchvision's default). The previous p=1.0 mirrored *every*
    # training image while val/test images are never mirrored, which is a
    # deterministic train/eval mismatch rather than an augmentation.
    transforms.RandomHorizontalFlip(0.5),
    transforms.Normalize(Dataset.mean, Dataset.std),
).to(device)

# Stronger second view; passed as transformT2 to the double-transform training
# dataset used for models 15/16/19.
traintransformT2 = torch.nn.Sequential(
    transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop(224),
    transforms.RandomErasing(0.8, value="random"),
    transforms.RandomAffine(degrees=(-45, 45), scale=(0.8, 1.2), fill=5),
    transforms.GaussianBlur(5),
    transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)

# Deterministic evaluation pipelines (validation and test are currently identical).
# NOTE(review): Resize(256) scales the *shorter* edge to 256 and keeps the aspect
# ratio, whereas training resizes to exactly InitResize=(256, 256). The two
# only agree for square inputs — confirm the stored images are square.
valtransform = torch.nn.Sequential(
    transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)
testtransform = torch.nn.Sequential(
    transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)

# DINO-style multi-crop augmentation: global crops at scale (0.4, 1) and
# 4 local crops at scale (0.05, 0.4). For model 20 the test pipeline is also
# handed over as Tranform_0 (presumably an extra un-augmented view — confirm
# in Loaders.RARP_DINO_Augmentation).
TrainDINOTransforms = Loaders.RARP_DINO_Augmentation(
    GloblaCropsScale = (0.4, 1),
    LocalCropsScale = (0.05, 0.4),
    NumLocalCrops = 4,
    Size = 224,
    device = device,
    mean = Dataset.mean,
    std = Dataset.std,
    Tranform_0 = testtransform if args.Model == 20 else None
)
# Models 17, 18 and 20 train with the DINO multi-crop transform instead of the
# default augmentation pipeline.
if args.Model in (17, 18, 20):
    traintransform = TrainDINOTransforms

# Directory holding the train/val/test sub-folders of the selected fold.
rootFile = Dataset.CVS_File.parent.parent / f"fold_{args.Fold}"

# Optional additional test set with the same fold layout. It is only built when
# no metadata spreadsheet is given; otherwise the loader stays None.
Add_Test_DataLoader = None
if args.AddTestSet is not None and args.Metadata is None:
    Add_TestDataset = torchvision.datasets.DatasetFolder(
        str(Path(args.AddTestSet) / f"fold_{args.Fold}" / "test"),
        extensions="npy",
        loader=defs.load_file_tensor,
        transform=testtransform,
    )
    Add_Test_DataLoader = DataLoader(
        Add_TestDataset,
        shuffle=False,
        pin_memory=True,
        batch_size=batchSize,
        num_workers=numWorkers,
    )
# --- Datasets ----------------------------------------------------------------
# Three variants, selected from the CLI:
#   * --Metadata given: RARP_DatasetFolder_ExtraData joined with spreadsheet data;
#   * models 15/16/19:  RARP_DatasetFolder_DobleTransform (train gets two views);
#   * otherwise:        plain torchvision DatasetFolder per split.
if args.Metadata is not None:
    # Dataset CSV: drop path/statistics columns and derive the .npy file name
    # from the id (presumably what the ExtraData loader matches on — confirm).
    DumpCSV = pd.read_csv(Dataset.CVS_File)
    DumpCSV["raw_name"] = "Img0-" + DumpCSV["id"].astype(str) + ".npy"
    DumpCSV = DumpCSV.drop(columns=["id", "path", "mean_1", "mean_2", "mean_3", "std_1", "std_2", "std_3"])

    # Spreadsheet: the column "列1" ("column 1") holds the image base name; it is
    # suffixed with .tiff to form the `name` key shared with the CSV.
    Extradata = pd.read_excel(Path(args.Metadata))
    Extradata["name"] = Extradata["列1"].astype(str) + ".tiff"
    Extradata = Extradata.drop(columns=["列1"])

    # Inner join on `name`: rows present in only one table are dropped.
    NewData = pd.merge(Extradata, DumpCSV, on="name")

    trainDataset, valDataset, testDataset = (
        Loaders.RARP_DatasetFolder_ExtraData(
            str(rootFile / split),
            loader=defs.load_file_tensor,
            Extra_Data=NewData,
            Extra_Data_leg=args.ExtraNeurons,
            extensions="npy",
            transform=split_tf,
        )
        for split, split_tf in (("train", traintransform), ("val", valtransform), ("test", testtransform))
    )
elif args.Model in (15, 16, 19):
    # Only the training split receives the second augmentation view; model 19
    # additionally gets the test pipeline as `passOriginal` on every split.
    trainDataset = Loaders.RARP_DatasetFolder_DobleTransform(
        str(rootFile / "train"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transformT1=traintransform,
        transformT2=traintransformT2,
        passOriginal=testtransform if args.Model == 19 else None,
    )
    valDataset, testDataset = (
        Loaders.RARP_DatasetFolder_DobleTransform(
            str(rootFile / split),
            loader=defs.load_file_tensor,
            extensions="npy",
            transformT1=split_tf,
            passOriginal=testtransform if args.Model == 19 else None,
        )
        for split, split_tf in (("val", valtransform), ("test", testtransform))
    )
else:
    trainDataset, valDataset, testDataset = (
        torchvision.datasets.DatasetFolder(
            str(rootFile / split),
            loader=defs.load_file_tensor,
            extensions="npy",
            transform=split_tf,
        )
        for split, split_tf in (("train", traintransform), ("val", valtransform), ("test", testtransform))
    )
# --- Data loaders ------------------------------------------------------------
# Settings shared by the train/val/test loaders. Worker processes are kept
# alive across epochs whenever any are used.
_sharedLoaderArgs = dict(
    batch_size=batchSize,
    num_workers=numWorkers,
    pin_memory=True,
    persistent_workers=numWorkers > 0,
)
# Training reshuffles each epoch and drops the incomplete final batch.
Train_DataLoader = DataLoader(trainDataset, shuffle=True, drop_last=True, **_sharedLoaderArgs)
Val_DataLoader = DataLoader(valDataset, shuffle=False, **_sharedLoaderArgs)
Test_DataLoader = DataLoader(testDataset, shuffle=False, **_sharedLoaderArgs)

if args.CAM:
    # Test split for class-activation maps: a plain (non-cropping) resize to
    # 224x224 followed by the usual normalisation.
    _camTransform = torch.nn.Sequential(
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)
    testCAMDataset = torchvision.datasets.DatasetFolder(
        str(rootFile / "test"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=_camTransform,
    )
    TestCAM_DataLoader = DataLoader(
        testCAMDataset,
        shuffle=False,
        pin_memory=True,
        batch_size=batchSize,
        num_workers=numWorkers,
    )
# --- Class balance and optional loss weights ------------------------------------
print(f"Currtent Fold Splits {Dataset.Splits[args.Fold]}")
print(f"Unique Values in sets")
info = np.unique(trainDataset.targets, return_counts=True), np.unique(valDataset.targets, return_counts=True), np.unique(testDataset.targets, return_counts=True)
print(info)

# Pool the labels of all three splits before counting. Previously the per-split
# np.unique counts were summed *by position* (counts[0] -> neg, counts[1] -> pos).
# That silently misattributes counts when a split lacks a class, and it ignores
# every class beyond the first two (e.g. the 4-label dataset).
_allTargets = np.concatenate([np.asarray(ds.targets) for ds in (trainDataset, valDataset, testDataset)])
_, classCounts = np.unique(_allTargets, return_counts=True)
neg = classCounts[0]
pos = classCounts[1]
total = classCounts.sum()

# Inverse-frequency weights, one per class present in the labels (for binary
# data identical to [total/(neg*factor), total/(pos*factor)]). np.unique only
# reports classes that occur, so no count can be zero here.
factor = 2 if args.TypeLoss == 1 else 1
InitWeight = torch.tensor([total / (count * factor) for count in classCounts]).to(device) if args.Wloss else None
if InitWeight is not None:
    print(f"Weights {InitWeight}")
# --- Model and training length ---------------------------------------------------
TypeLoss = M.TypeLossFunction(args.TypeLoss)
Model, ModelCAM = getModel(
    args.Model,
    InitWeight,
    TypeLoss,
    std=Dataset.std,
    mean=Dataset.mean,
)
NameModel = type(Model).__name__
print(f"Model Used: {NameModel}")

LogFileName = f"{args.Log_Name}"
# 150 epochs unless overridden on the command line. A former model-4 special
# case assigned the same value, so it needs no separate branch.
MaxEpochs = 150 if args.maxEpochs is None else args.maxEpochs
#warnings.simplefilter("ignore")
match(args.Phase):
case "train":
trainer = L.Trainer(
deterministic=True,
#gradient_clip_val=2.0,
accelerator='gpu',
devices=1,
logger=CSVLogger(save_dir=f"./{LogFileName}", name="Tune") if args.Phase == "tune" else TensorBoardLogger(save_dir=f"./{LogFileName}"),
log_every_n_steps=5,
#callbacks=[checkPtCallback, StepDropout(5, base_drop_rate=0.2, gamma=0.05, ascending=True)],#if args.Model == 4 else checkPtCallback,
callbacks=[checkPtCallback, callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)], #ckpLossBest, ],
max_epochs=MaxEpochs,
)
print("Train Phase")
trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader, ckpt_path=args.LoadChkpt)
trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")
case "tune":
print("Tuning")
pruner = optuna.pruners.SuccessiveHalvingPruner()#MedianPruner()
sampler = optuna.samplers.GPSampler(seed=2023) if args.Log_step == 1 else optuna.samplers.TPESampler(seed=2023)
study = optuna.create_study(direction="maximize", pruner=pruner, sampler=sampler)
study.optimize(objective, n_trials=100)
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:")
trial = study.best_trial
print(f" Value: {trial.value}")
print(f" Paramas: ")
for key, val in trial.params.items():
print(f" {key}: {val}")
case "eval_all":
print("Evaluation Phase")
rows = []
pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
for ckpFile in pathCkptFile.glob("*.ckpt"):
print(ckpFile.name)
Model, ModelCAM = getModel(args.Model, InitWeight, TypeLoss, ckpFile)
#ViewImgDINO(trainDataset, Dataset.std, Dataset.mean)
if isinstance(Model, (M.RARP_NVB_MultiClassModel, M.RARP_NVB_DINO_MultiTask_v2)):
temp = Calc_EvalMulticlass_table(Model, Test_DataLoader, False, ckpFile.name, NumClasses=4)
else:
temp = Calc_Eval_table(
Model,
Test_DataLoader,
args.JIndex,
ckpFile.name,
Add_TestDataset=Add_Test_DataLoader,
extraData=(args.Metadata is not None),
PseudoLabel=False,
dataSetInfo=Dataset
)
rows += temp
if args.CAM and ModelCAM is not None:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
print("CAM")
ShowCAM(ModelCAM, testCAMDataset, Dataset.mean, Dataset.std, ckpFile.name)
df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])
df.style.highlight_max(color="red", axis=0)
print(df)
plt.show()
case _:
print("Evaluation Phase")
trainLog = [args.Log_version, args.Log_epoch, args.Log_step]
pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{trainLog[0]}/checkpoints/epoch={trainLog[1]}-step={trainLog[2]}.ckpt")
Calc_Eval(Model.load_from_checkpoint(pathCkptFile))
if args.CAM:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ShowCAM(ModelCAM.load_from_checkpoint(pathCkptFile, strict=False), testCAMDataset, Dataset.mean, Dataset.std, pathCkptFile.name)
plt.show()