import os

# Silence Albumentations' "new version available" check; must be set before
# any downstream import (e.g. ultralytics) pulls the library in.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import torchmetrics
import lightning as L
from lightning.pytorch import seed_everything
from lightning.pytorch.tuner import Tuner
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import Loaders
import defs
import argparse
import seaborn as sn
import Models as M
import pandas as pd
import warnings
from ultralytics import YOLO
import yaml
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import japanize_matplotlib
# Trade a little float32 matmul precision for TensorCore speed-ups.
torch.set_float32_matmul_precision('high')
# Force deterministic cuDNN kernel selection for reproducibility.
torch.backends.cudnn.deterministic = True
def objective(trail: optuna.trial.Trial) -> float:
    """Optuna objective: sample hyper-parameters, train one model, return test accuracy.

    Relies on module-level globals: LogFileName, MaxEpochs, args, InitWeight,
    TypeLoss, getModel and the Train/Val/Test data loaders.
    """
    # --- search space ---
    sampled_lr = trail.suggest_float("lr", 1e-4, 1e-3, log=True)
    sampled_l1 = trail.suggest_float("L1", 1e-6, 1e-3, log=True)
    sampled_alpha = trail.suggest_float("W_Apha", 0, 1, step=0.05)
    sampled_tau = trail.suggest_float("Thao_KD", 1, 7, step=0.25)
    hparams = {
        "lr": sampled_lr,
        "L1": sampled_l1,
        "Alpha": sampled_alpha,
        "Beta": 1 - sampled_alpha,  # Beta is the complement of Alpha
        "Thao": sampled_tau,
    }
    # --- per-trial trainer with pruning + top-k checkpointing ---
    trial_trainer = L.Trainer(
        logger=TensorBoardLogger(save_dir=f"./{LogFileName}", name="Tune"),
        max_epochs=MaxEpochs,
        accelerator="auto",
        log_every_n_steps=5,
        devices=1,
        callbacks=[
            PyTorchLightningPruningCallback(trail, monitor="val_acc"),
            callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=2, mode='max'),
        ],
    )
    candidate, _ = getModel(
        args.Model,
        InitWeight,
        TypeLoss,
        OptConfig=hparams
    )
    trial_trainer.logger.log_hyperparams(hparams)
    trial_trainer.fit(candidate, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
    # Evaluate the best checkpoint of this trial on the test split.
    trial_trainer.test(candidate, dataloaders=Test_DataLoader, ckpt_path="best")
    return trial_trainer.callback_metrics["test_acc"].item()
def Calc_Eval_table_New(TrainModel:M.RARP_NVB_Model):
    """Evaluate TrainModel on the global Test_DataLoader.

    Two-class models of type RARP_NVB_Model_test2 are scored with softmax /
    multiclass metrics; every other model with sigmoid / binary metrics.

    Returns:
        [accuracy, precision, recall, f1, auroc] as Python floats.
    """
    TrainModel.to(device)
    TrainModel.eval()
    two_headed = isinstance(TrainModel, M.RARP_NVB_Model_test2)
    scores, targets = [], []
    with torch.no_grad():
        for img, label in tqdm(iter(Test_DataLoader)):
            img = img.float().to(device)
            if two_headed:
                targets.append(label.to(device))
                scores.append(torch.softmax(TrainModel(img), dim=1))
            else:
                targets.append(label.float().to(device))
                scores.append(torch.sigmoid(TrainModel(img).squeeze(1)))
    scores = torch.cat(scores)
    targets = torch.cat(targets)
    print(scores, targets)
    if two_headed:
        def evaluate(metric_cls):
            return metric_cls("multiclass", num_classes=2).to(device)(scores, targets)
    else:
        def evaluate(metric_cls):
            return metric_cls('binary').to(device)(scores, targets)
    # Same metric ordering as the rest of the evaluation helpers.
    return [
        evaluate(torchmetrics.Accuracy).item(),
        evaluate(torchmetrics.Precision).item(),
        evaluate(torchmetrics.Recall).item(),
        evaluate(torchmetrics.F1Score).item(),
        evaluate(torchmetrics.AUROC).item(),
    ]
def encode_2labels_4classes(x: torch.Tensor, th: float = 0.5) -> int:
    """Collapse a 2-element (right, left) label pair into a single class id.

    Float inputs are binarized first; integer inputs are used as-is.

    Args:
        x: tensor of two values, interpreted as (right, left).
        th: binarization threshold for float inputs.

    Returns:
        0 = neither, 1 = right only, 2 = left only, 3 = both,
        -1 for any pair outside {0, 1} (only possible for non-float input).
    """
    if x.dtype == torch.float:
        # BUG FIX: the threshold was hard-coded to 0.5, silently ignoring
        # the `th` parameter; honor the caller-supplied value.
        x = (x > th) * 1
    r, l = x
    if r == 0 and l == 0:
        return 0
    elif r == 1 and l == 0:
        return 1
    elif r == 0 and l == 1:
        return 2
    elif r == 1 and l == 1:
        return 3
    else:
        return -1
def Calc_EvalMulticlass_table(TrainModel:M.RARP_NVB_Model,TestDataLoadre:DataLoader, Youden=False, modelName="", NumClasses:int=2, Num_Label:int=None):
    """Evaluate a multiclass (Num_Label is None) or multilabel model.

    Runs TrainModel over TestDataLoadre on the module-level `device` and
    returns a list of rows of formatted metric strings:
    [threshold, acc, precision, recall, f1, auroc, specificity, model name].
    With Youden=True two extra rows are appended: one at the ROC Youden-J
    threshold and one at (recall + specificity - 1).
    """
    TrainModel.to(device)
    TrainModel.eval()
    Predictions = []
    Labels = []
    with torch.no_grad():
        for data, label in tqdm(iter(TestDataLoadre)):
            data = data.float().to(device)
            label = label.to(device)
            if isinstance(TrainModel, M.RARP_NVB_DINO_MultiTask):
                # Multitask model returns (logits, features, image); only logits used here.
                pred, _, _ = TrainModel(data)
                # NOTE(review): overwrites the NumClasses argument inside the
                # loop (4 for multiclass, None for multilabel) — confirm intended.
                NumClasses = 4 if Num_Label is None else None
            else:
                pred = TrainModel(data)
            # Softmax scores for multiclass, per-label sigmoid for multilabel.
            Predictions.append(torch.softmax(pred, dim=1) if Num_Label is None else torch.sigmoid(pred))
            Labels.append(label)
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels)
    print(Predictions, Labels)
    # Metrics at the default 0.5 threshold; task switches on Num_Label.
    acc = torchmetrics.Accuracy("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels)
    precision = torchmetrics.Precision("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels)
    recall = torchmetrics.Recall("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels)
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", f"*{modelName}*"] #.item():.4f
    ]
    if Num_Label is not None:
        # Collapse the 2-label multilabel output into 4 pseudo-classes for a
        # human-readable confusion matrix (labels: none / right / left / both).
        single_labels_pred = [encode_2labels_4classes(p.cpu()) for p in Predictions]
        single_labels_true = [encode_2labels_4classes(p.cpu()) for p in Labels]
        labels_names = ["なし", "右", "左", "右+左"]
        cm = confusion_matrix(single_labels_true, single_labels_pred, labels=[0,1,2,3])
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels_names)
        disp.plot()
    #cm2 = torchmetrics.ConfusionMatrix("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)
    #cm2.update(Predictions, Labels)
    #_, ax = cm2.plot()
    #ax.set_title(f"NVB Classifier {modelName}")
    if Youden:
        for i in range(2):
            aucCurve = torchmetrics.ROC("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)
            fpr, tpr, thhols = aucCurve(Predictions, Labels)
            # Youden's J statistic: threshold maximizing tpr - fpr.
            index = torch.argmax(tpr - fpr)
            th2 = (recall + specificty - 1).item()
            th2 = 0.5 if th2 <= 0 else th2
            # First pass uses the ROC-derived threshold, second pass uses th2.
            th1 = thhols[index].item() if i == 0 else th2
            accY = torchmetrics.Accuracy("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label, threshold=th1).to(device)(Predictions, Labels)
            precisionY = torchmetrics.Precision("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label, threshold=th1).to(device)(Predictions, Labels)
            recallY = torchmetrics.Recall("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label, threshold=th1).to(device)(Predictions, Labels)
            # NOTE(review): uses the "binary" task here unlike the sibling
            # metrics above — confirm this is intentional.
            specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
            f1ScoreY = torchmetrics.F1Score("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label, threshold=th1).to(device)(Predictions, Labels)
            #cm2 = torchmetrics.ConfusionMatrix("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, threshold=th1).to(device)
            #cm2.update(Predictions, Labels)
            #_, ax = cm2.plot()
            #ax.set_title(f"NVB Classifier (th={th1:.4f})")
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
    return table
def Calc_Eval_table(
    TrainModel:M.RARP_NVB_Model,
    TestDataLoadre:DataLoader,
    Youden=False, modelName="",
    Add_TestDataset:DataLoader=None,
    extraData:bool=False,
    PseudoLabel:bool=True,
    dataSetInfo:Loaders.RARP_DatasetCreator = None
):
    """Binary evaluation over TestDataLoadre (plus optional Add_TestDataset).

    Handles several model families (VAN / multitask / DINO variants) that
    return extra outputs, and datasets that yield multiple image tensors per
    sample. Returns rows of formatted metric strings:
    [threshold, acc, precision, recall, f1, auroc, specificity, model name].
    """
    TrainModel.to(device)
    TrainModel.eval()
    Predictions = []
    Labels = []
    PseudoLabelTest = None
    # NOTE(review): this checks the module-level global `testDataset`, not
    # TestDataLoadre.dataset — confirm the two always match.
    ban_ExtraImage = isinstance(testDataset, Loaders.RARP_DatasetFolder_DobleTransform)
    with torch.no_grad():
        for data, label in tqdm(iter(TestDataLoadre)):
            if extraData:
                # Batch is (image, extra-feature tensor).
                img, extra = data
                img = img.float().to(device)
                extra = extra.float().to(device)
                data = (img, extra)
            elif ban_ExtraImage:
                # Dataset yields 2 or 3 image tensors per sample.
                if len(data) == 3:
                    TData, Sdata, OData = data
                    data = (TData.float().to(device), Sdata.float().to(device), OData.float().to(device))
                else:
                    TData, Sdata = data
                    data = (TData.float().to(device), Sdata.float().to(device))
            else:
                data = data.float().to(device)
            label = label.to(device)
            if isinstance(TrainModel, M.RARP_NVB_ResNet50_VAN):
                pred, Plabel, _ = TrainModel(data)
                pred = pred.flatten()
                # Optionally score against the model's own pseudo-labels.
                label = Plabel.int() if PseudoLabel else label
            elif isinstance(TrainModel, (M.RARP_NVB_RN50_VAN_V2, M.RARP_NVB_DINO_MultiTask)):
                pred, features, new_img = TrainModel(data)
                pred = pred.flatten()
                #_, axis = plt.subplots(2, 2, figsize=(9, 9))
                #for i in range(2):
                #    for j in range(2):
                #        random_index = np.random.randint(0, len(new_img))
                #        img = new_img[random_index].cpu()
                #        img = img.numpy().transpose((1, 2, 0))
                #        img = np.clip((dataSetInfo.std * img + dataSetInfo.mean) / 255, 0, 1)
                #        img = img[...,::-1].copy()
                #        axis[i][j].imshow(img)
            elif isinstance(TrainModel, M.RARP_NVB_DINO_RestNet50_Deep):
                # Returns ((logits, pseudo-labels, _), _).
                DK_PredLabels, _ = TrainModel(data)
                pred, Plabel, _ = DK_PredLabels
                label = Plabel.int() if PseudoLabel else label
            else:
                pred = TrainModel(data).flatten()
            Predictions.append(torch.sigmoid(pred))
            Labels.append(label)
    if Add_TestDataset is not None:
        # Extra test split is always plain (image, label) batches.
        with torch.no_grad():
            for data, label in tqdm(iter(Add_TestDataset)):
                data = data.float().to(device)
                label = label.to(device)
                pred = TrainModel(data).flatten()
                Predictions.append(torch.sigmoid(pred))
                Labels.append(label)
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels)
    print(Predictions, Labels)
    # Binary metrics at the default 0.5 threshold.
    acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
    ]
    if Youden:
        for i in range(2):
            aucCurve = torchmetrics.ROC("binary").to(device)
            fpr, tpr, thhols = aucCurve(Predictions, Labels)
            # Youden's J statistic: threshold maximizing tpr - fpr.
            index = torch.argmax(tpr - fpr)
            th2 = (recall + specificty - 1).item()
            th2 = 0.5 if th2 <= 0 else th2
            # First pass uses the ROC-derived threshold, second pass uses th2.
            th1 = thhols[index].item() if i == 0 else th2
            accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
            precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
            recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
            specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
            f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
            #cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
            #cm2.update(Predictions, Labels)
            #_, ax = cm2.plot()
            #ax.set_title(f"NVB Classifier (th={th1:.4f})")
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
    return table
def setup_seed(seed):
    """Seed every RNG source (numpy, torch CPU/CUDA, Lightning) for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Lightning also seeds worker processes when workers=True.
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
def CAM(model:M.RARP_NVB_Model, img:torch.Tensor, device):
    """Compute a Class Activation Map for a single image.

    Args:
        model: a *_CAM model variant exposing (prediction, feature-map) outputs.
        img: single image tensor (unbatched; batched internally).
        device: torch device to run on.

    Returns:
        (cam_img, probability): the h x w activation map normalized to [0, 1]
        on CPU, and the sigmoid of the model prediction.

    Raises:
        NotImplementedError: if the model type has no CAM weight mapping.
    """
    with torch.no_grad():
        img = img.to(device).float().unsqueeze(0)
        if isinstance(model, M.RARP_NVB_VAN_CAM):
            # VAN variant additionally takes an extra scalar input.
            pred, feature = model(img, torch.tensor([0.0]).to(device))
        else:
            pred, feature = model(img)
        _, c, h, w = feature.shape
        feature = feature.reshape((c, h*w))
        # Pick the classifier-head weights matching the architecture.
        if isinstance(model, (M.RARP_NVB_ResNet18_CAM, M.RARP_NVB_ResNet50_CAM)):
            wParams = list(model.model.fc.parameters())
        elif isinstance(model, (M.RARP_NVB_MobileNetV2_CAM, M.RARP_NVB_EfficientNetV2_CAM)):
            wParams = list(model.model.classifier.parameters())
        elif isinstance(model, M.RARP_NVB_VAN_CAM):
            wParams = list(model.model.head.parameters())
        else:
            # BUG FIX: `raise "..."` raises a plain string, which is itself a
            # TypeError in Python 3; raise a proper exception type instead.
            raise NotImplementedError("Cam Not Implemented")
        pesos = wParams[0].detach()
        # CAM = weights @ features, then min-max normalized to [0, 1].
        cam = torch.matmul(pesos, feature)
        cam = cam - torch.min(cam)
        cam_img = cam / torch.max(cam)
        cam_img = cam_img.reshape(h, w).cpu()
    return cam_img, torch.sigmoid(pred)
def CAMVisualizer(img, heatmap, pred, label, mean, std, ax, row):
    """Render one (image, CAM overlay) pair into the 4-row axes grid.

    Rows 0-3 fill columns 0-1, rows 4-7 columns 2-3, rows 8+ columns 4-5.
    """
    rgb = img.numpy().transpose((1, 2, 0))
    # Upscale the CAM to the image resolution.
    heatmap = transforms.functional.resize(heatmap.unsqueeze(0), (rgb.shape[0], rgb.shape[1]), antialias=True)[0]
    # Undo Normalize(mean, std), rescale to [0, 1], BGR -> RGB.
    rgb = np.clip((std * rgb + mean) / 255, 0, 1)
    rgb = rgb[...,::-1].copy()
    col = 4 if row > 7 else (2 if row > 3 else 0)
    r = row % 4
    ax[r][col].imshow(rgb)
    ax[r][col].axis('off')
    ax[r][col + 1].imshow(rgb)
    ax[r][col + 1].imshow(heatmap, alpha=0.5, cmap="jet")
    ax[r][col + 1].axis('off')
    ax[r][col + 1].set_title(f"Pred.: {pred.item():.4f}; Label: {label}")
def ShowCAM(TrainedModel:M.RARP_NVB_Model, DataSet, mean, std, title=""):
    """Plot a 4x6 grid of CAM overlays for samples from DataSet.

    For datasets larger than 12 samples, shows the first 6 samples of each of
    the two classes; otherwise shows every sample. Uses the module-level
    `device`.
    """
    TrainedModel.to(device)
    TrainedModel.eval()
    i = 0
    # Tight gridspec so each image/overlay pair sits flush together.
    params = {
        "left":0,
        "bottom":0.01,
        "right":1,
        "top":0.914,
        "wspace":0,
        "hspace":0.164
    }
    fig, axis = plt.subplots(4, 6, gridspec_kw=params)
    with torch.no_grad():
        if len(DataSet) > 12:
            # Per-class sample counts. NOTE(review): assumes targets are
            # sorted so class 0 occupies the first ix[0] indices — confirm.
            ix = np.unique(DataSet.targets, return_counts=True)[1]
            NOTNVB_Indexs = np.asarray(range(ix[0]))
            NVB_Indexs = np.asarray(range(ix[0], ix[0] + ix[1]))
            #np.random.shuffle(NOTNVB_Indexs)
            #np.random.shuffle(NVB_Indexs)
            for j, index in enumerate(NOTNVB_Indexs):
                if j == 6:
                    break
                img, label = DataSet[index]
                cam, pred = CAM(TrainedModel, img, device)
                CAMVisualizer(img, cam, pred, label, mean, std, axis, i)
                i += 1
            for j, index in enumerate(NVB_Indexs):
                if j == 6:
                    break
                img, label = DataSet[index]
                cam, pred = CAM(TrainedModel, img, device)
                CAMVisualizer(img, cam, pred, label, mean, std, axis, i)
                i += 1
        else:
            for img, label in tqdm(DataSet):
                cam, pred = CAM(TrainedModel, img, device)
                CAMVisualizer(img, cam, pred, label, mean, std, axis, i)
                i += 1
    fig.suptitle(title)
def Calc_Eval(TrainModel:M.RARP_NVB_Model):
    """Evaluate TrainModel sample-by-sample on the global testDataset.

    Prints binary metrics and draws a confusion-matrix heatmap titled with
    the current fold number (from the global `args`).
    """
    TrainModel.to(device)
    TrainModel.eval()
    scores, targets = [], []
    with torch.no_grad():
        for data, label in tqdm(testDataset):
            # Add a batch dimension for the single sample.
            batch = data.to(device).float().unsqueeze(0)
            scores.append(torch.sigmoid(TrainModel(batch)[0].cpu()))
            targets.append(label)
    scores = torch.cat(scores)
    targets = torch.tensor(targets).int()
    print(scores, targets)
    acc = torchmetrics.Accuracy('binary')(scores, targets)
    precision = torchmetrics.Precision('binary')(scores, targets)
    recall = torchmetrics.Recall('binary')(scores, targets)
    cm = torchmetrics.ConfusionMatrix('binary')(scores, targets)
    auc = torchmetrics.AUROC('binary')(scores, targets)
    f1Score = torchmetrics.F1Score('binary')(scores, targets)
    print(f"Val Accuracy: {acc:.4f}")
    print(f"Val Precision: {precision:.4f}")
    print(f"Val Recall: {recall:.4f}")
    print(f"F1 Score: {f1Score:.4f}")
    print(f"AUROC: {auc:.4f}")
    print(testDataset.classes)
    ax = sn.heatmap(cm, cmap="Greens", cbar=False, annot=True, annot_kws={"size": 18}, fmt='g', xticklabels=testDataset.classes, yticklabels=testDataset.classes)
    ax.set_title(f"NVB Classifier Split #{args.Fold+1}")
    ax.set_xlabel('Predict')
    ax.set_ylabel('True')
    plt.show()
def getModel (
    Model_ID:int = 0,
    InitWeight=torch.tensor([1,1]),
    TypeLoss:M.TypeLossFunction = M.TypeLossFunction.CrossEntropy,
    Ckpt_File:Path = None,
    OptConfig:dict = {},
    inputNeurons:int = 4,
    mean:float = None,
    std:float = None
):
    """Model factory: build (or load from checkpoint) the net selected by Model_ID.

    Returns:
        (Model, ModelCAM) — ModelCAM is a CAM-enabled twin where one exists
        for the architecture, otherwise None.

    NOTE(review): whenever a checkpoint is loaded, the branches read the
    module-level global `ckpFile` rather than the `Ckpt_File` parameter that
    is tested for None — confirm the two always hold the same path.
    NOTE(review): `OptConfig:dict = {}` is a mutable default argument; it is
    only read here, but callers must not rely on mutating it.
    """
    Model = None
    ModelCAM = None
    match Model_ID:
        case 0:
            Model = M.RARP_NVB_ResNet50(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_ResNet50.load_from_checkpoint(ckpFile)
            ModelCAM = None if Ckpt_File is None else M.RARP_NVB_ResNet50_CAM.load_from_checkpoint(ckpFile, strict=False)
        case 1:
            Model = M.RARP_NVB_ResNet18(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_ResNet18.load_from_checkpoint(ckpFile)
            ModelCAM = None if Ckpt_File is None else M.RARP_NVB_ResNet18_CAM.load_from_checkpoint(ckpFile, strict=False)
        case 2:
            Model = M.RARP_NVB_MobileNetV2(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_MobileNetV2.load_from_checkpoint(ckpFile)
            ModelCAM = None if Ckpt_File is None else M.RARP_NVB_MobileNetV2_CAM.load_from_checkpoint(ckpFile, strict=False)
        case 3:
            Model = M.RARP_NVB_EfficientNetV2(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_EfficientNetV2.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 4:
            Model = M.RARP_NVB_Vit_b_16(InitWeight, TypeLoss) if Ckpt_File is None else M.RARP_NVB_Vit_b_16.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 5:
            Model = M.RARP_NVB_DenseNet169(InitWeight, TypeLoss) if Ckpt_File is None else M.RARP_NVB_DenseNet169.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 6:
            Model = M.RARP_NVB_ResNet50_V1(InitWeight, TypeLoss, schedulerLR=args.DyLr, InputNeurons=inputNeurons) if Ckpt_File is None else M.RARP_NVB_ResNet50_V1.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 7:
            # Ensemble: member architectures are listed in a per-fold YAML file.
            # NOTE(review): checkpoint loading for members is commented out —
            # members are instantiated fresh (untrained) here.
            with open(f"train-EFold{args.Fold}.yaml") as file:
                configFile = yaml.load(file, Loader=yaml.FullLoader)
            Models = []
            for models in configFile["models"]:
                match models:
                    case "ResNet50":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_ResNet50.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_ResNet50(InitWeight, TypeLoss))
                    case "ResNet18":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_ResNet18.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_ResNet18(InitWeight, TypeLoss))
                    case "MobileNetV2":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_MobileNetV2.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_MobileNetV2(InitWeight, TypeLoss))
                    case "EfficientNetV2":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_EfficientNetV2.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_EfficientNetV2(InitWeight, TypeLoss))
                    case "DenseNet169":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_DenseNet169.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_DenseNet169(InitWeight, TypeLoss))
                    case _:
                        pass
            print (f"{len(Models)} models Loaded")
            Model = M.RARP_Ensemble(Models, InitWeight, TypeLoss, lr=1e-3)
            ModelCAM = None
        case 8:
            Model = M.RARP_NVB_DaVit(InitWeight, TypeLoss) if Ckpt_File is None else M.RARP_NVB_DaVit.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 9:
            Model = M.RARP_NVB_ResNet50_Deep(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_ResNet50_Deep.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 10:
            Model = M.RARP_NVB_EfficientNetV2_Deep(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_EfficientNetV2_Deep.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 11:
            Model = M.RARP_NVB_ResNet50_V2(InitWeight, TypeLoss, schedulerLR=args.DyLr, InputNeurons=inputNeurons) if Ckpt_File is None else M.RARP_NVB_ResNet50_V2.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 12:
            Model = M.RARP_NVB_ResNet50_V3(InitWeight, TypeLoss, schedulerLR=args.DyLr, InputNeurons=inputNeurons) if Ckpt_File is None else M.RARP_NVB_ResNet50_V3.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 13:
            Model = M.RARP_NVB_VAN(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_VAN.load_from_checkpoint(ckpFile)
            ModelCAM = None if Ckpt_File is None else M.RARP_NVB_VAN_CAM.load_from_checkpoint(ckpFile, strict=False)
        case 14:
            # 4-class head on a torchvision ResNet50 backbone.
            TestModel = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
            TestModel.fc = torch.nn.Linear(in_features=TestModel.fc.in_features, out_features=4)
            Model = M.RARP_NVB_MultiClassModel(
                None,
                Model=TestModel,
                Num_Classes=4,
            ) if Ckpt_File is None else M.RARP_NVB_MultiClassModel.load_from_checkpoint(ckpFile,Model=TestModel)
            ModelCAM = None
        case 15:
            #if OptConfig.get("lr") is None:
            #    OptConfig = dict(
            #        lr = None, #0.00015278, #1.53E-4,
            #        L1 = None, #0.0000020505, #2.05E-6,
            #        Alpha = 0.45,
            #        Gamma = 0.5,
            #        Thao = 5#2
            #    )
            # VAN student bootstrapped from a fixed ResNet50 teacher checkpoint.
            Model = M.RARP_NVB_ResNet50_VAN(
                "./log_ResNet50_X10/lightning_logs/version_8/checkpoints/RARP-epoch=5.ckpt",
                #"./log_X10/lightning_logs/version_0/checkpoints/RARP-epoch=39.ckpt",
                0.5,
                InitWeight,
                TypeLoss,
                schedulerLR=args.DyLr,
                PseudoLables=False,
                HParameter=OptConfig
            ) if Ckpt_File is None else M.RARP_NVB_ResNet50_VAN.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 16:
            Model = M.RARP_NVB_SSL_RestNet50_Deep("./log_ResNet50_Deep_X10/lightning_logs/version_3/checkpoints/RARP-epoch=9.ckpt", 0.5, InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_SSL_RestNet50_Deep.load_from_checkpoint(ckpFile)
            #Model = M.RARP_NVB_SSL_RestNet50_Deep("./log_ResNet50DeepSSL_X10/lightning_logs/version_8/checkpoints/RARP-epoch=42.ckpt", 0.5, InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_SSL_RestNet50_Deep.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 17:
            Model = M.RARP_NVB_DINO_RestNet50_Deep(
                "./log_ResNet50_Deep_X10/lightning_logs/version_3/checkpoints/RARP-epoch=9.ckpt",
                threshold=0.5,
                TypeLoss=TypeLoss,
                L1=1.31E-04,
            ) if Ckpt_File is None else M.RARP_NVB_DINO_RestNet50_Deep.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 18:
            Model = M.RARP_NVB_DINO_VAN(
                "./log_ResNet50_Deep_X10/lightning_logs/version_3/checkpoints/RARP-epoch=9.ckpt",
                threshold=0.5,
                TypeLoss=TypeLoss,
                #L1=1.31E-04
            ) if Ckpt_File is None else M.RARP_NVB_DINO_VAN.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 19:
            Model = M.RARP_NVB_RN50_VAN_V2(
                #"./log_X10/lightning_logs/version_0/checkpoints/RARP-epoch=39.ckpt",
                "./log_ResNet50_X10/lightning_logs/version_8/checkpoints/RARP-epoch=5.ckpt",
                0.5,
                InitWeight,
                TypeLoss,
                schedulerLR=args.DyLr,
                PseudoLables=False,
                HParameter=OptConfig, std=std, mean=mean
            ) if Ckpt_File is None else M.RARP_NVB_RN50_VAN_V2.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 20:
            Model = M.RARP_NVB_DINO_MultiTask_Unet(
                TypeLoss,
                std=std,
                mean=mean,
                L1= 1.31E-04,
                L2= 0,
                SoftAdptAlgo=0,
            ) if Ckpt_File is None else M.RARP_NVB_DINO_MultiTask_Unet.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 21:
            # Hybrid two-stage model: only constructible from a checkpoint.
            Model = None if Ckpt_File is None else M.RARP_Hybrid_TS_LR(ckpFile, masked=True)
            ModelCAM = None
        case _:
            raise Exception("Model Not yet Implemented")
    return (Model, ModelCAM)
def ViewImg(dataset, std, mean):
    """Show a 2x2 grid of random de-normalized samples from *dataset*.

    Args:
        dataset: indexable dataset yielding ((image, extra), label) pairs.
        std, mean: per-channel normalization stats used to undo Normalize.
    """
    _, axis = plt.subplots(2, 2, figsize=(9, 9))
    for i in range(2):
        for j in range(2):
            # Generalized: sample over the whole dataset instead of the
            # previously hard-coded bound of 44 indices.
            random_index = np.random.randint(0, len(dataset))
            img, label = dataset[random_index]
            # Dataset yields (image, extra); keep only the image tensor.
            img, _ = img
            img = img.numpy().transpose((1, 2, 0))
            # Undo Normalize(mean, std) and rescale to [0, 1].
            img = np.clip((std * img + mean) / 255, 0, 1)
            # BGR -> RGB for display.
            img = img[...,::-1].copy()
            axis[i][j].imshow(img)
            axis[i][j].set_title(f"Label:{label}")
def ViewImgDINO(dataset, std, mean):
    """Visualize DINO multi-crop augmentations: 4 random samples, one row of crops each."""
    _, axis = plt.subplots(4, 7, figsize=(25, 25))
    for row in range(4):
        idx = np.random.randint(0, len(dataset.targets))
        crops, label = dataset[idx]
        for col, crop in enumerate(crops):
            # Undo Normalize(mean, std), rescale to [0, 1], BGR -> RGB.
            picture = crop.numpy().transpose((1, 2, 0))
            picture = np.clip((std * picture + mean) / 255, 0, 1)
            picture = picture[...,::-1].copy()
            cell = axis[row][col]
            cell.imshow(picture)
            cell.set_title(f"Label:{label}")
            cell.set_axis_off()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
parser.add_argument("--Fold", type=int, default=0)
parser.add_argument("--Workers", type=int, default=0)
parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
parser.add_argument("--Model", type=int, default=0, help="0 = ResNet18, 1 = ResNet50")
parser.add_argument("-lv", "--Log_version", type=int)
parser.add_argument("-le", "--Log_epoch", type=int)
parser.add_argument("-ls", "--Log_step", type=int)
parser.add_argument("--Remove_Blackbar", type=bool, default=True)
parser.add_argument("--BGR2RGB", type=bool, default=False)
parser.add_argument("--CAM", type=bool, default=False)
parser.add_argument("-roi", "--Use_ROI_Dataset", type=int, default=0)
parser.add_argument("-s", "--imgSlice_pct", type=float, default=None)
parser.add_argument("-ns", "--Num_Slices", type=int, default=4)
parser.add_argument("-wl", "--Wloss",type=bool, default=False)
parser.add_argument("--sClass",type=int, default=None)
parser.add_argument("-tl", "--TypeLoss", type=int, default=0)
parser.add_argument("-cs", "--ColorSpace", type=int, default=None)
parser.add_argument("--JIndex", type=bool, default=False)
parser.add_argument("-me", "--maxEpochs", type=int, default=None)
parser.add_argument("-lc", "--LoadChkpt", type=str, default=None)
parser.add_argument("--AddTestSet", type=str, default=None)
parser.add_argument("--Metadata", type=str, default=None)
parser.add_argument("--DyLr", type=bool, default=False)
parser.add_argument("-lr", type=float, default=1e-4)
parser.add_argument("--ExtraNeurons", type=int, default=4)
parser.add_argument("--ExtraLabels", type=str, default=None)
parser.add_argument("--Roi_Mask_Model", type=str, default=None)
args = parser.parse_args()
if args.CAM and args.Phase == "train":
raise Exception("Clases Activation Clases only in eval o eval_all")
match args.Use_ROI_Dataset:
case 0:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_main",
FoldSeed=505,
createFile=True,
SavePath="./DataSet",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 1:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Crop",
FoldSeed=505,
createFile=True,
SavePath="./DataSetCrop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 2:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Crop1",
FoldSeed=505,
createFile=True,
SavePath="./DataSetCrop1",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 3:
YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_main",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_YOLO",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace,
ROI_Yolo=YoloModel
)
cropSize = 256
case 4:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_big",
FoldSeed=505,
createFile=True,
SavePath="./DatasetBig",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 6:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Full",
FoldSeed=505,
createFile=True,
SavePath="./DatasetFull",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 7:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_smallBalaced",
FoldSeed=505,
createFile=True,
SavePath="./DatasetSmallBalanced",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 8:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_big_2",
FoldSeed=505,
createFile=True,
SavePath="./DatasetBig2",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 9:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando",
FoldSeed=505,
createFile=True,
SavePath="./DataSetAndo",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 10:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_smallBalacedCrop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_SB_Crop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 11:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_AndoCrop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_SB_Ando_Crop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 12:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All_Crop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Ando_Crop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 13:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Ando_All",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 14:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_AllNewLabels",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_New_labels",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 15:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet",
FoldSeed=505,
createFile=True,
SavePath="./DataSet",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 16:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_crop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Crop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 256
case 17:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All_no20",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Ando_All_20",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 18:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All_no20Crop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_AndoAll20_crop",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
case 19:
ROI_model = M.RARP_NVB_ROI_Mask_Unet.load_from_checkpoint(Path(args.Roi_Mask_Model))
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All_no20",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_AndoAll20_mask",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace,
ROI_Mask=ROI_model
)
cropSize = 720
case 21:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Kpts",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Kpts_FullSize",
Fold=5,
removeBlackBar=False,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace,
copyKpoints=True
)
cropSize = 720
case 5:
YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_big",
FoldSeed=505,
createFile=True,
SavePath="./DatasetBig_YOLO",
Fold=5,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace,
ROI_Yolo=YoloModel
)
cropSize = 256
case 20:
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_big_Multiclass",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_Multiclass",
Fold=5,
Num_Labels=4,
removeBlackBar=args.Remove_Blackbar,
RGBGama=args.BGR2RGB,
SegImage=args.imgSlice_pct,
Num_Img_Slices=args.Num_Slices,
SegmentClass=args.sClass,
colorSpace=args.ColorSpace
)
cropSize = 720
Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
Dataset.CreateFolds()
setup_seed(2023)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batchSize = 8 #17 #8, 32
numWorkers = args.Workers
InitResize = (256,256)
ImgResize = (224, 224)
checkPtCallback = callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max')
ckpLossBest = callbk.ModelCheckpoint(monitor="val_loss", filename="RARP-{epoch}-{val_loss:.2f}", save_top_k=2, mode='min')
# GPU-resident training augmentation (nn.Sequential of v1 transforms so the
# whole pipeline runs on `device`).
traintransform = torch.nn.Sequential(
    transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224), # changed here 2024/05/10
    #transforms.RandomCrop(ImgResize),
    transforms.RandomAffine(
        degrees=(-5, 5), scale=(0.9, 1.1),
        fill=5  # constant fill for the area exposed by the affine warp
    ),
    # NOTE(review): p=1.0 flips EVERY training image deterministically —
    # confirm this is intended rather than the usual p=0.5 augmentation.
    transforms.RandomHorizontalFlip(1.0),
    transforms.Normalize(Dataset.mean, Dataset.std),
).to(device)
# Stronger second augmentation (T2) consumed by the double-transform datasets
# (models 15/16/19): random crop, random erasing, large affine range, blur.
traintransformT2 = torch.nn.Sequential(
    transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop(224),
    transforms.RandomErasing(0.8, value="random"),  # erase a random patch 80% of the time
    transforms.RandomAffine(degrees=(-45, 45), scale=(0.8, 1.2), fill=5),
    transforms.GaussianBlur(5),
    transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)
# Minimal eval pipeline for ROI-cropped inputs: resize straight to the model
# input size (no center crop — the ROI is already tight) and normalize.
Roi_mask_transform = torch.nn.Sequential(
    transforms.Resize((224, 224), antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)
# Deterministic validation pipeline: resize shorter side to 256, center-crop
# 224, normalize. For the ROI datasets (19, 21) the frames are already
# cropped to the region of interest, so the ROI pipeline is used instead.
# (Idiom fix: `x not in (...)` replaces the old `not x in [...]` — same
# semantics, clearer precedence.)
valtransform = torch.nn.Sequential(
    transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.Normalize(Dataset.mean, Dataset.std)
).to(device) if args.Use_ROI_Dataset not in (19, 21) else Roi_mask_transform
# Deterministic test pipeline, mirroring valtransform: resize → center-crop
# 224 → normalize, with the ROI pipeline substituted for ROI datasets (19, 21).
# (Idiom fix: `x not in (...)` replaces the old `not x in [...]` — same
# semantics, clearer precedence.)
testtransform = torch.nn.Sequential(
    transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.Normalize(Dataset.mean, Dataset.std)
).to(device) if args.Use_ROI_Dataset not in (19, 21) else Roi_mask_transform
# DINO-style multi-crop augmentation: two global crops plus four local crops.
# (Keyword spellings "GloblaCropsScale"/"Tranform_0" follow the Loaders API
# and must not be changed here.)
TrainDINOTransforms = Loaders.RARP_DINO_Augmentation(
    GloblaCropsScale = (0.4, 1),    # area-scale range of the global crops
    LocalCropsScale = (0.05, 0.4),  # area-scale range of the local crops
    NumLocalCrops = 4,
    Size = 224,
    device = device,
    mean = Dataset.mean,
    std = Dataset.std,
    # Model 20 additionally receives the deterministic eval view per sample.
    Tranform_0 = testtransform if args.Model == 20 else None
)
# Root of the current fold's train/val/test split on disk.
rootFile = Dataset.CVS_File.parent.parent/f"fold_{args.Fold}"
Add_Test_DataLoader = None  # stays None unless --AddTestSet is supplied below
# DINO-style models (17, 18, 20) train on the multi-crop augmentation rather
# than the standard pipeline built above.
traintransform = TrainDINOTransforms if args.Model in (17, 18, 20) else traintransform
# Optional second test set: when --AddTestSet points at another dataset root
# (and no tabular metadata is in play), load its test split with the same
# deterministic test transform so the two test sets are directly comparable.
if args.AddTestSet is not None and args.Metadata is None:
    extra_test_root = Path(args.AddTestSet) / f"fold_{args.Fold}" / "test"
    Add_TestDataset = torchvision.datasets.DatasetFolder(
        str(extra_test_root),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=testtransform,
    )
    Add_Test_DataLoader = DataLoader(
        Add_TestDataset,
        batch_size=batchSize,
        shuffle=False,
        num_workers=numWorkers,
        pin_memory=True,
    )
if args.Metadata is None:
    # Standard image-folder datasets (no tabular metadata merged in).
    if args.Model in (15, 16, 19):
        # Double-transform datasets: each sample is returned under two
        # augmentations (T1 and the stronger T2); model 19 additionally
        # receives the un-augmented eval view via passOriginal.
        trainDataset = Loaders.RARP_DatasetFolder_DobleTransform(
            str (rootFile/"train"),
            loader=defs.load_file_tensor,
            extensions="npy",
            transformT1=traintransform,
            transformT2=traintransformT2,
            passOriginal= testtransform if args.Model == 19 else None
        )
        valDataset = Loaders.RARP_DatasetFolder_DobleTransform(
            str (rootFile/"val"),
            loader=defs.load_file_tensor,
            extensions="npy",
            transformT1=valtransform,
            passOriginal= testtransform if args.Model == 19 else None
        )
        testDataset = Loaders.RARP_DatasetFolder_DobleTransform(
            str (rootFile/"test"),
            loader=defs.load_file_tensor,
            extensions="npy",
            transformT1=testtransform,
            passOriginal= testtransform if args.Model == 19 else None
        )
    elif args.Use_ROI_Dataset == 21:
        # Keypoint-ROI datasets: crop to the ROI using the keypoint files.
        # NOTE(review): uses defs.load_file (not load_file_tensor) — presumably
        # the ROI extractor needs the raw array; confirm against Loaders.
        trainDataset = Loaders.RARP_DatasetFolder_ROIExtractor_OnlyROI(
            str (rootFile/"train"),
            loader=defs.load_file,
            extensions="npy",
            transform=traintransform,
            root_kpts= rootFile / "../../DataSet_Kpts"
        )
        valDataset = Loaders.RARP_DatasetFolder_ROIExtractor_OnlyROI(
            str (rootFile/"val"),
            loader=defs.load_file,
            extensions="npy",
            transform=valtransform,
            root_kpts= rootFile / "../../DataSet_Kpts"
        )
        testDataset = Loaders.RARP_DatasetFolder_ROIExtractor_OnlyROI(
            str (rootFile/"test"),
            loader=defs.load_file,
            extensions="npy",
            transform=testtransform,
            root_kpts= rootFile / "../../DataSet_Kpts"
        )
    else:
        # Plain torchvision DatasetFolder per split.
        trainDataset = torchvision.datasets.DatasetFolder(
            str (rootFile/"train"),
            loader=defs.load_file_tensor,
            extensions="npy",
            transform=traintransform
        )
        valDataset = torchvision.datasets.DatasetFolder(
            str (rootFile/"val"),
            loader=defs.load_file_tensor,
            extensions="npy",
            transform=valtransform
        )
        testDataset = torchvision.datasets.DatasetFolder(
            str (rootFile/"test"),
            loader=defs.load_file_tensor,
            extensions="npy",
            transform=testtransform
        )
else:
    # Metadata path: join per-image rows from the dataset CSV with an Excel
    # sheet of extra per-sample features, then feed them to the model as
    # additional input neurons.
    DumpCSV = pd.read_csv(Dataset.CVS_File)
    Extradata = pd.read_excel(Path(args.Metadata))
    # "列1" is the Excel sheet's auto-named first column holding the image id;
    # rebuild the ".tiff" filename from it so the two tables share a "name" key.
    Extradata["name"] = Extradata["列1"].astype(str) + ".tiff"
    Extradata = Extradata.drop(columns=["列1"])
    # Reconstruct the on-disk .npy filename from the numeric id.
    DumpCSV["raw_name"] = "Img0-" + DumpCSV["id"].astype(str) + ".npy"
    DumpCSV = DumpCSV.drop(columns=["id", "path", "mean_1", "mean_2", "mean_3", "std_1", "std_2", "std_3"])
    NewData = pd.merge(Extradata, DumpCSV, on="name")
    trainDataset = Loaders.RARP_DatasetFolder_ExtraData(
        str (rootFile/"train"),
        loader=defs.load_file_tensor,
        Extra_Data=NewData,
        Extra_Data_leg = args.ExtraNeurons,  # number of extra feature columns fed to the model
        extensions="npy",
        transform=traintransform
    )
    valDataset = Loaders.RARP_DatasetFolder_ExtraData(
        str (rootFile/"val"),
        loader=defs.load_file_tensor,
        Extra_Data=NewData,
        Extra_Data_leg = args.ExtraNeurons,
        extensions="npy",
        transform=valtransform
    )
    testDataset = Loaders.RARP_DatasetFolder_ExtraData(
        str (rootFile/"test"),
        loader=defs.load_file_tensor,
        Extra_Data=NewData,
        Extra_Data_leg = args.ExtraNeurons,
        extensions="npy",
        transform=testtransform
    )
# When an extra-label spreadsheet is supplied, the datasets built above are
# replaced by variants that attach a second label to each sample.
if args.ExtraLabels is not None:
    DumpCSV = pd.read_csv(Dataset.CVS_File)
    Extradata = pd.read_excel(Path(args.ExtraLabels))
    # Rebuild the on-disk .npy filename from the numeric id for the join below.
    DumpCSV["raw_name"] = "Img0-" + DumpCSV["id"].astype(str) + ".npy"
    DumpCSV = DumpCSV.drop(columns=["mean_1", "mean_2", "mean_3", "std_1", "std_2", "std_3", "path", "class", "label"])
    # Right join keeps every CSV sample even if it has no extra-label row.
    outPut = pd.merge(Extradata, DumpCSV, on="name", how="right")
    trainDataset = Loaders.RARP_DatasetFolder_ExtraLabel(
        str (rootFile/"train"),
        loader=defs.load_file_tensor,
        Extra_Data=outPut,
        extensions="npy",
        transform=traintransform
    )
    valDataset = Loaders.RARP_DatasetFolder_ExtraLabel(
        str (rootFile/"val"),
        loader=defs.load_file_tensor,
        Extra_Data=outPut,
        extensions="npy",
        transform=valtransform
    )
    testDataset = Loaders.RARP_DatasetFolder_ExtraLabel(
        str (rootFile/"test"),
        loader=defs.load_file_tensor,
        Extra_Data=outPut,
        extensions="npy",
        transform=testtransform
    )
# Wrap the three splits in DataLoaders. The settings shared by every loader
# are factored into one dict; only the train loader shuffles and drops the
# last incomplete batch.
_loader_common = dict(
    batch_size=batchSize,
    num_workers=numWorkers,
    pin_memory=True,
    # persistent workers only make sense when worker processes exist at all
    persistent_workers=numWorkers > 0,
)
Train_DataLoader = DataLoader(trainDataset, shuffle=True, drop_last=True, **_loader_common)
Val_DataLoader = DataLoader(valDataset, shuffle=False, **_loader_common)
Test_DataLoader = DataLoader(testDataset, shuffle=False, **_loader_common)
if args.CAM:
    # Dataset/loader for Grad-CAM visualisation: resize only (no center crop)
    # so the heatmap covers the whole frame.
    testCAMDataset = torchvision.datasets.DatasetFolder(
        str (rootFile/"test"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=torch.nn.Sequential(
            transforms.Resize((224, 224), antialias=True),
            transforms.Normalize(Dataset.mean, Dataset.std)
        ).to(device)
    )
    TestCAM_DataLoader = DataLoader(
        testCAMDataset,
        batch_size=batchSize,
        num_workers=numWorkers,
        shuffle=False,
        pin_memory=True
    )
# Report the split composition and derive inverse-frequency class weights for
# the loss ("Currtent" typo fixed in the printed message).
print(f"Current Fold Splits {Dataset.Splits[args.Fold]}")
print("Unique Values in sets")
info = (
    np.unique(trainDataset.targets, return_counts=True),
    np.unique(valDataset.targets, return_counts=True),
    np.unique(testDataset.targets, return_counts=True),
)
print(info)
# Accumulate per-class totals keyed by the actual label value rather than by
# position: the old `counts[0]`/`counts[1]` indexing would mis-attribute the
# counts (or raise IndexError) whenever a split is missing one of the classes.
neg = 0
pos = 0
for values, counts in info:
    per_class = dict(zip(values, counts))
    neg += per_class.get(0, 0)
    pos += per_class.get(1, 0)
total = neg + pos
# TypeLoss == 1 halves both weights via the shared factor of 2.
factor = 2 if args.TypeLoss == 1 else 1
# Inverse-frequency weights [w_neg, w_pos], or None when weighting is disabled.
InitWeight = torch.tensor([total/(neg * factor), total/(pos * factor)]).to(device) if args.Wloss else None
if InitWeight is not None:
    print(f"Weights {InitWeight}")
# Resolve the loss configuration and build the model (plus its CAM wrapper).
TypeLoss = M.TypeLossFunction(args.TypeLoss)
Model, ModelCAM = getModel(
    args.Model,
    InitWeight,
    TypeLoss,
    mean=Dataset.mean,
    std=Dataset.std
)
NameModel = type(Model).__name__
print(f"Model Used: {NameModel}")
LogFileName = f"{args.Log_Name}"
# Default epoch budget. (A previous `if args.Model == 4: MaxEpochs = 150`
# special case assigned the identical value and was a no-op, so it has been
# folded into the default.) An explicit --maxEpochs always wins.
MaxEpochs = 150
if args.maxEpochs is not None:
    MaxEpochs = args.maxEpochs
#warnings.simplefilter("ignore")
match(args.Phase):
case "train":
trainer = L.Trainer(
deterministic=True,
#gradient_clip_val=2.0,
accelerator='gpu',
devices=1,
logger=CSVLogger(save_dir=f"./{LogFileName}", name="Tune") if args.Phase == "tune" else TensorBoardLogger(save_dir=f"./{LogFileName}"),
log_every_n_steps=5,
#callbacks=[checkPtCallback, StepDropout(5, base_drop_rate=0.2, gamma=0.05, ascending=True)],#if args.Model == 4 else checkPtCallback,
callbacks=[checkPtCallback, callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)],
max_epochs=MaxEpochs,
)
print("Train Phase")
trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader, ckpt_path=args.LoadChkpt)
trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")
case "tune":
print("Tuning")
pruner = optuna.pruners.SuccessiveHalvingPruner()#MedianPruner()
sampler = optuna.samplers.GPSampler(seed=2023) if args.Log_step == 1 else optuna.samplers.TPESampler(seed=2023)
study = optuna.create_study(direction="maximize", pruner=pruner, sampler=sampler)
study.optimize(objective, n_trials=100)
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:")
trial = study.best_trial
print(f" Value: {trial.value}")
print(f" Paramas: ")
for key, val in trial.params.items():
print(f" {key}: {val}")
case "eval_all":
print("Evaluation Phase")
rows = []
pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
for ckpFile in pathCkptFile.glob("*.ckpt"):
print(ckpFile.name)
Model, ModelCAM = getModel(args.Model, InitWeight, TypeLoss, ckpFile)
#ViewImgDINO(trainDataset, Dataset.std, Dataset.mean)
if isinstance(Model, (M.RARP_NVB_MultiClassModel, M.RARP_NVB_DINO_MultiTask_v2, M.RARP_NVB_DINO_MultiTask_MultiLabel, M.RARP_Hybrid_TS_LR)):
numClass = 4 if isinstance(Model, M.RARP_NVB_DINO_MultiTask_v2) else 2
numLabel = 2 if isinstance(Model, (M.RARP_NVB_DINO_MultiTask_MultiLabel, M.RARP_Hybrid_TS_LR)) else None
temp = Calc_EvalMulticlass_table(Model, Test_DataLoader, False, ckpFile.name, NumClasses=numClass, Num_Label=numLabel)
else:
temp = Calc_Eval_table(
Model,
Test_DataLoader,
args.JIndex,
ckpFile.name,
Add_TestDataset=Add_Test_DataLoader,
extraData=(args.Metadata is not None),
PseudoLabel=False,
dataSetInfo=Dataset
)
rows += temp
if args.CAM and ModelCAM is not None:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
print("CAM")
ShowCAM(ModelCAM, testCAMDataset, Dataset.mean, Dataset.std, ckpFile.name)
df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])
df.style.highlight_max(color="red", axis=0)
print(df)
plt.show()
case _:
print("Evaluation Phase")
trainLog = [args.Log_version, args.Log_epoch, args.Log_step]
pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{trainLog[0]}/checkpoints/epoch={trainLog[1]}-step={trainLog[2]}.ckpt")
Calc_Eval(Model.load_from_checkpoint(pathCkptFile))
if args.CAM:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ShowCAM(ModelCAM.load_from_checkpoint(pathCkptFile, strict=False), testCAMDataset, Dataset.mean, Dataset.std, pathCkptFile.name)
plt.show()