Newer
Older
RARP / RARP_NVB.py
@delAguila delAguila on 20 May 59 KB Video Extraf frame
import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import torchmetrics
import lightning as L
from lightning.pytorch import seed_everything
from lightning.pytorch.tuner import Tuner
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import Loaders
import defs
import argparse
import seaborn as sn 
import Models as M
import pandas as pd
import warnings
from ultralytics import YOLO
import yaml
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import japanize_matplotlib

# Select the 'high' float32 matmul precision mode (speed/precision trade-off; see the
# torch.set_float32_matmul_precision documentation for what each mode permits).
torch.set_float32_matmul_precision('high')
# Request deterministic cuDNN algorithms so repeated runs are comparable.
torch.backends.cudnn.deterministic = True

def objective(trail: optuna.trial.Trial) -> float:
    """Optuna objective: train one sampled configuration and return its test accuracy.

    Four hyperparameters are sampled from the trial ("lr", "L1", "W_Apha",
    "Thao_KD"), a model is built through ``getModel`` with them, trained with a
    Lightning trainer that prunes and checkpoints on ``val_acc``, and finally
    tested from the best checkpoint.

    Relies on module-level state set up by the script: ``LogFileName``,
    ``MaxEpochs``, ``args``, ``InitWeight``, ``TypeLoss`` and the three
    train/val/test data loaders.

    Args:
        trail: The Optuna trial supplying hyperparameter suggestions.

    Returns:
        The ``test_acc`` metric logged by the test run, as a Python float.
    """
    # Sample the search space; the call order is kept stable for the sampler.
    sampled_lr = trail.suggest_float("lr", 1e-4, 1e-3, log=True)
    sampled_l1 = trail.suggest_float("L1", 1e-6, 1e-3, log=True)
    sampled_alpha = trail.suggest_float("W_Apha", 0, 1, step=0.05)
    sampled_thao = trail.suggest_float("Thao_KD", 1, 7, step=0.25)

    tune_logger = TensorBoardLogger(save_dir=f"./{LogFileName}", name="Tune")
    pruning_cb = PyTorchLightningPruningCallback(trail, monitor="val_acc")
    checkpoint_cb = callbk.ModelCheckpoint(
        monitor='val_acc',
        filename="RARP-{epoch}",
        save_top_k=2,
        mode='max',
    )

    tuning_trainer = L.Trainer(
        logger=tune_logger,
        max_epochs=MaxEpochs,
        accelerator="auto",
        log_every_n_steps=5,
        devices=1,
        callbacks=[pruning_cb, checkpoint_cb],
    )

    # Beta is always the complement of Alpha, so only Alpha is searched.
    hyperparameters = {
        "lr": sampled_lr,
        "L1": sampled_l1,
        "Alpha": sampled_alpha,
        "Beta": 1 - sampled_alpha,
        "Thao": sampled_thao,
    }

    candidate, _ = getModel(
        args.Model,
        InitWeight,
        TypeLoss,
        OptConfig=hyperparameters,
    )

    tuning_trainer.logger.log_hyperparams(hyperparameters)

    tuning_trainer.fit(candidate, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
    tuning_trainer.test(candidate, dataloaders=Test_DataLoader, ckpt_path="best")

    # NOTE(review): pruning/checkpointing monitor val_acc, but the value
    # optimised by the study is the *test* accuracy — confirm this is intended.
    return tuning_trainer.callback_metrics["test_acc"].item()

def Calc_Eval_table_New(TrainModel:M.RARP_NVB_Model):
    """Evaluate ``TrainModel`` on the module-level ``Test_DataLoader``.

    Models that are instances of ``M.RARP_NVB_Model_test2`` are scored as a
    2-class "multiclass" problem on softmax outputs; every other model is
    scored as a "binary" problem on the sigmoid of its squeezed output.

    Relies on the module-level ``device`` and ``Test_DataLoader``.

    Args:
        TrainModel: The trained model to evaluate; it is moved to ``device``
            and switched to eval mode.

    Returns:
        ``[accuracy, precision, recall, f1, auroc]`` as Python floats.
    """
    TrainModel.to(device)
    TrainModel.eval()

    two_class_head = isinstance(TrainModel, M.RARP_NVB_Model_test2)

    Predictions = []
    Labels = []

    with torch.no_grad():
        for img, label in tqdm(iter(Test_DataLoader)):
            img = img.float().to(device)
            label = label.to(device)

            pred = TrainModel(img)
            if two_class_head:
                Predictions.append(torch.softmax(pred, dim=1))
            else:
                Predictions.append(torch.sigmoid(pred.squeeze(1)))
            Labels.append(label)

    Predictions = torch.cat(Predictions)
    # torchmetrics expects integer ground-truth labels; the previous binary
    # path cast them to float, which the AUROC input validation rejects.
    Labels = torch.cat(Labels).long()

    print(Predictions, Labels)

    if two_class_head:
        task, task_kwargs = "multiclass", {"num_classes": 2}
    else:
        task, task_kwargs = "binary", {}

    # Order of this tuple defines the order of the returned list.
    metric_types = (
        torchmetrics.Accuracy,
        torchmetrics.Precision,
        torchmetrics.Recall,
        torchmetrics.F1Score,
        torchmetrics.AUROC,
    )
    return [
        metric_type(task, **task_kwargs).to(device)(Predictions, Labels).item()
        for metric_type in metric_types
    ]

def encode_2labels_4classes (x:torch.Tensor, th:float = 0.5):
    """Collapse a pair of binary labels (r, l) into a single class index.

    The mapping is ``(0, 0) -> 0``, ``(1, 0) -> 1``, ``(0, 1) -> 2`` and
    ``(1, 1) -> 3``; any other pair yields ``-1``.

    Args:
        x: Tensor holding exactly two values ``(r, l)``. Floating-point input
            (of any float dtype) is binarised with ``value > th`` first;
            integer/bool input is used as-is.
        th: Threshold applied to floating-point input.

    Returns:
        The class index as a Python int, or ``-1`` for an unrecognised pair.
    """
    if x.is_floating_point():
        # Honour the caller's threshold (it used to be hard-coded to 0.5).
        x = (x > th) * 1

    r, l = x
    r, l = int(r), int(l)

    if r in (0, 1) and l in (0, 1):
        # r is the low bit, l the high bit of the class index.
        return r + 2 * l
    return -1

def Calc_EvalMulticlass_table(TrainModel:M.RARP_NVB_Model,TestDataLoadre:DataLoader, Youden=False, modelName="", NumClasses:int=2, Num_Label:int=None):
    """Evaluate a multiclass or multilabel model and return a table of metric rows.

    The task is "multilabel" when ``Num_Label`` is given (sigmoid outputs) and
    "multiclass" otherwise (softmax outputs). Each row of the returned table is
    ``[threshold, accuracy, precision, recall, f1, auroc, specificity, name]``
    with the numbers formatted to four decimals.

    When ``Num_Label`` is given, a 4-class confusion matrix is also plotted by
    encoding each label pair with ``encode_2labels_4classes``.

    Relies on the module-level ``device``.

    Args:
        TrainModel: Model to evaluate; moved to ``device`` and set to eval mode.
        TestDataLoadre: Loader yielding ``(data, label)`` batches.
        Youden: If true, append two extra rows computed at alternative thresholds.
        modelName: Label written in the last column of each row.
        NumClasses: Number of classes for the multiclass task. Overridden when
            the model is an ``M.RARP_NVB_DINO_MultiTask`` (4, or ``None`` in
            multilabel mode).
        Num_Label: Number of labels; switches the evaluation to multilabel.

    Returns:
        A list of rows (lists of strings) as described above.
    """
    TrainModel.to(device)
    TrainModel.eval()

    multilabel = Num_Label is not None

    Predictions = []
    Labels = []

    with torch.no_grad():
        for data, label in tqdm(iter(TestDataLoadre)):
            data = data.float().to(device)
            label = label.to(device)

            if isinstance(TrainModel, M.RARP_NVB_DINO_MultiTask):
                # The multi-task model returns extra outputs; only the
                # classification head is scored here.
                pred, _, _ = TrainModel(data)
                NumClasses = None if multilabel else 4
            else:
                pred = TrainModel(data)
            Predictions.append(torch.sigmoid(pred) if multilabel else torch.softmax(pred, dim=1))
            Labels.append(label)

    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels)

    print(Predictions, Labels)

    task = "multilabel" if multilabel else "multiclass"

    def evaluate(metric_type, **extra):
        """Instantiate ``metric_type`` for the current task and apply it to all predictions."""
        metric = metric_type(task, num_classes=NumClasses, num_labels=Num_Label, **extra)
        return metric.to(device)(Predictions, Labels)

    acc = evaluate(torchmetrics.Accuracy)
    precision = evaluate(torchmetrics.Precision)
    recall = evaluate(torchmetrics.Recall)
    auc = evaluate(torchmetrics.AUROC)
    f1Score = evaluate(torchmetrics.F1Score)
    specificty = evaluate(torchmetrics.Specificity)

    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", f"*{modelName}*"]
    ]

    if multilabel:
        # Collapse each (right, left) label pair into one of four classes so a
        # single confusion matrix can be drawn.
        single_labels_pred = [encode_2labels_4classes(p.cpu()) for p in Predictions]
        single_labels_true = [encode_2labels_4classes(p.cpu()) for p in Labels]
        labels_names = ["なし", "右", "左", "右+左"]
        cm = confusion_matrix(single_labels_true, single_labels_pred, labels=[0,1,2,3])
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels_names)
        disp.plot()

    if Youden:
        # NOTE(review): for multiclass/multilabel tasks torchmetrics' ROC can
        # return per-class lists instead of single tensors, in which case the
        # subtraction below will not work — confirm against the installed version.
        fpr, tpr, thhols = evaluate(torchmetrics.ROC)
        roc_threshold = thhols[torch.argmax(tpr - fpr)].item()

        # Second candidate: the Youden J value (recall + specificity - 1) of
        # the default-threshold metrics, falling back to 0.5 when not positive.
        th2 = (recall + specificty - 1).item()
        th2 = 0.5 if th2 <= 0 else th2

        for th1 in (roc_threshold, th2):
            accY = evaluate(torchmetrics.Accuracy, threshold=th1)
            precisionY = evaluate(torchmetrics.Precision, threshold=th1)
            recallY = evaluate(torchmetrics.Recall, threshold=th1)
            # Uses the same task as every other metric (was hard-coded "binary").
            specifictyY = evaluate(torchmetrics.Specificity, threshold=th1)
            f1ScoreY = evaluate(torchmetrics.F1Score, threshold=th1)
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])

    return table

def Calc_Eval_table(
    TrainModel:M.RARP_NVB_Model,
    TestDataLoadre:DataLoader, 
    Youden=False, modelName="", 
    Add_TestDataset:DataLoader=None, 
    extraData:bool=False, 
    PseudoLabel:bool=True,
    dataSetInfo:Loaders.RARP_DatasetCreator = None
):
    """Evaluate a binary classifier and return a table of metric rows.

    Every row is ``[threshold, accuracy, precision, recall, f1, auroc,
    specificity, name]`` with numbers formatted to four decimals. The first
    row uses the default threshold and an empty name; with ``Youden`` two more
    rows at alternative thresholds are appended, labelled ``modelName``.

    Relies on the module-level ``device`` and ``testDataset``.

    Args:
        TrainModel: Model to evaluate; moved to ``device`` and set to eval mode.
        TestDataLoadre: Main loader yielding ``(data, label)`` batches.
        Youden: Append rows computed at two alternative thresholds.
        modelName: Name written in the last column of the Youden rows.
        Add_TestDataset: Optional second loader whose batches are scored with a
            plain ``TrainModel(data).flatten()`` call and appended.
        extraData: Batches are ``(image, extra)`` pairs.
        PseudoLabel: For models that return their own labels, score against
            those instead of the loader's labels.
        dataSetInfo: Currently unused by this function.

    Returns:
        A list of rows (lists of strings) as described above.
    """
    TrainModel.to(device)
    TrainModel.eval()

    def on_device(*tensors):
        """Cast each tensor to float and move it to the evaluation device."""
        return tuple(t.float().to(device) for t in tensors)

    collected_probs = []
    collected_targets = []
    # NOTE(review): this inspects the module-level testDataset rather than the
    # dataset behind TestDataLoadre — confirm both always refer to the same data.
    ban_ExtraImage = isinstance(testDataset, Loaders.RARP_DatasetFolder_DobleTransform)

    with torch.no_grad():
        for batch, target in tqdm(iter(TestDataLoadre)):
            if extraData:
                image, extra = batch
                batch = on_device(image, extra)
            elif ban_ExtraImage:
                if len(batch) == 3:
                    first, second, third = batch
                    batch = on_device(first, second, third)
                else:
                    first, second = batch
                    batch = on_device(first, second)
            else:
                batch = batch.float().to(device)

            target = target.to(device)

            if isinstance(TrainModel, M.RARP_NVB_ResNet50_VAN):
                logits, model_labels, _ = TrainModel(batch)
                logits = logits.flatten()
                if PseudoLabel:
                    target = model_labels.int()
            elif isinstance(TrainModel, (M.RARP_NVB_RN50_VAN_V2, M.RARP_NVB_DINO_MultiTask)):
                # The extra outputs (features, reconstructed images) are not scored.
                logits, _features, _new_img = TrainModel(batch)
                logits = logits.flatten()
            elif isinstance(TrainModel, M.RARP_NVB_DINO_RestNet50_Deep):
                head_outputs, _ = TrainModel(batch)
                # NOTE(review): unlike the other branches, these logits are not flattened.
                logits, model_labels, _ = head_outputs
                if PseudoLabel:
                    target = model_labels.int()
            else:
                logits = TrainModel(batch).flatten()

            collected_probs.append(torch.sigmoid(logits))
            collected_targets.append(target)

    if Add_TestDataset is not None:
        with torch.no_grad():
            for batch, target in tqdm(iter(Add_TestDataset)):
                batch = batch.float().to(device)
                target = target.to(device)
                logits = TrainModel(batch).flatten()
                collected_probs.append(torch.sigmoid(logits))
                collected_targets.append(target)

    Predictions = torch.cat(collected_probs)
    Labels = torch.cat(collected_targets)

    print(Predictions, Labels)

    def binary(metric_type, **extra):
        """Apply a binary torchmetrics metric to all collected predictions."""
        return metric_type('binary', **extra).to(device)(Predictions, Labels)

    def fmt(value):
        """Format a scalar metric tensor with four decimals."""
        return f"{value.item():.4f}"

    acc = binary(torchmetrics.Accuracy)
    precision = binary(torchmetrics.Precision)
    recall = binary(torchmetrics.Recall)
    auc = binary(torchmetrics.AUROC)
    f1Score = binary(torchmetrics.F1Score)
    specificty = binary(torchmetrics.Specificity)

    table = [
        ["0.5000", fmt(acc), fmt(precision), fmt(recall), fmt(f1Score), fmt(auc), fmt(specificty), ""]
    ]

    if Youden:
        # First candidate: the ROC threshold that maximises tpr - fpr.
        fpr, tpr, thhols = binary(torchmetrics.ROC)
        roc_threshold = thhols[torch.argmax(tpr - fpr)].item()

        # Second candidate: the value recall + specificity - 1 measured at the
        # default threshold, replaced by 0.5 when it is not positive.
        # NOTE(review): this is a J *statistic* reused as a threshold — confirm intent.
        j_value = (recall + specificty - 1).item()
        if j_value <= 0:
            j_value = 0.5

        for th1 in (roc_threshold, j_value):
            accY = binary(torchmetrics.Accuracy, threshold=th1)
            precisionY = binary(torchmetrics.Precision, threshold=th1)
            recallY = binary(torchmetrics.Recall, threshold=th1)
            specifictyY = binary(torchmetrics.Specificity, threshold=th1)
            f1ScoreY = binary(torchmetrics.F1Score, threshold=th1)
            table.append([f"{th1:.4f}", fmt(accY), fmt(precisionY), fmt(recallY), fmt(f1ScoreY), fmt(auc), fmt(specifictyY), modelName])

    return table

def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and Lightning with ``seed``.

    Also asks cuDNN for deterministic algorithms.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
   
def CAM(model:M.RARP_NVB_Model, img:torch.Tensor, device):
    """Compute a class-activation map for a single image.

    The model is run on ``img`` (a batch dimension is added) and must return
    ``(pred, feature)``. The feature map is flattened spatially and multiplied
    by the first parameter tensor of the model's final layer (``fc``,
    ``classifier`` or ``head`` depending on the model class); the result is
    min-max normalised and reshaped to the feature map's ``(h, w)``.

    Args:
        model: One of the supported ``*_CAM`` model classes.
        img: A single image tensor without batch dimension.
        device: Device on which the forward pass runs.

    Returns:
        ``(cam_img, prob)`` where ``cam_img`` is the normalised map on the CPU
        and ``prob`` is ``sigmoid(pred)``.

    Raises:
        NotImplementedError: If ``model`` is not one of the supported classes.
    """
    with torch.no_grad():
        img = img.to(device).float().unsqueeze(0)
        if isinstance(model, M.RARP_NVB_VAN_CAM):
            # The VAN variant takes a second tensor argument; a constant 0.0 is supplied.
            pred, feature = model(img, torch.tensor([0.0]).to(device))
        else:
            pred, feature = model(img)

    # The reshape below drops the batch dimension, so it relies on batch size 1.
    _, c, h, w = feature.shape
    feature = feature.reshape((c, h*w))
    if isinstance(model, (M.RARP_NVB_ResNet18_CAM, M.RARP_NVB_ResNet50_CAM)):
        wParams = list(model.model.fc.parameters())
    elif isinstance(model, (M.RARP_NVB_MobileNetV2_CAM, M.RARP_NVB_EfficientNetV2_CAM)):
        wParams = list(model.model.classifier.parameters())
    elif isinstance(model, M.RARP_NVB_VAN_CAM):
        wParams = list(model.model.head.parameters())
    else:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception so the message reaches the caller.
        raise NotImplementedError("Cam Not Implemented")
    weights = wParams[0].detach()
    cam = torch.matmul(weights, feature)

    # Min-max normalise to [0, 1].
    cam = cam - torch.min(cam)
    cam_img =  cam / torch.max(cam)
    cam_img = cam_img.reshape(h, w).cpu()

    return cam_img, torch.sigmoid(pred)

def CAMVisualizer(img, heatmap, pred, label, mean, std, ax, row):
    """Draw one image and its CAM overlay into a grid of matplotlib axes.

    The image is de-normalised with ``std``/``mean``, scaled by 1/255, clipped
    to [0, 1] and its channel order reversed before display. Samples are laid
    out four per column pair: ``row`` 0-3 use columns 0-1, 4-7 use columns 2-3
    and anything above 7 uses columns 4-5; the grid row is ``row % 4``.

    Args:
        img: Image tensor in channel-first layout.
        heatmap: 2-D activation map; resized to the image's height and width.
        pred: Scalar prediction tensor shown in the title.
        label: Ground-truth label shown in the title.
        mean: Mean used to de-normalise the image.
        std: Standard deviation used to de-normalise the image.
        ax: 2-D grid of matplotlib axes.
        row: Running index of the sample being drawn.
    """
    picture = img.numpy().transpose((1, 2, 0))
    target_size = (picture.shape[0], picture.shape[1])
    heatmap = transforms.functional.resize(heatmap.unsqueeze(0), target_size, antialias=True)[0]

    picture = np.clip((std * picture + mean) / 255, 0, 1)
    picture = picture[..., ::-1].copy()

    # Pick the left column of the pair this sample belongs to.
    if row > 7:
        first_col = 4
    elif row > 3:
        first_col = 2
    else:
        first_col = 0
    grid_row = row % 4

    plain_ax = ax[grid_row][first_col + 0]
    plain_ax.imshow(picture)
    plain_ax.axis('off')

    overlay_ax = ax[grid_row][first_col + 1]
    overlay_ax.imshow(picture)
    overlay_ax.imshow(heatmap, alpha=0.5, cmap="jet")
    overlay_ax.axis('off')
    overlay_ax.set_title(f"Pred.: {pred.item():.4f}; Label: {label}")
    
    #plt.title()
    
def ShowCAM(TrainedModel:M.RARP_NVB_Model, DataSet, mean, std, title=""):
    """Plot class-activation maps for up to twelve samples in a 4x6 grid.

    For datasets with more than 12 items, the first six indices of each of the
    two index ranges derived from the per-class counts of ``DataSet.targets``
    are shown; otherwise every sample is shown.

    Relies on the module-level ``device``.

    Args:
        TrainedModel: CAM-capable model; moved to ``device`` and set to eval mode.
        DataSet: Indexable dataset yielding ``(img, label)``; needs ``targets``
            when it holds more than 12 items.
        mean: Mean used to de-normalise images for display.
        std: Standard deviation used to de-normalise images for display.
        title: Figure title.
    """
    TrainedModel.to(device)
    TrainedModel.eval()

    layout = dict(left=0, bottom=0.01, right=1, top=0.914, wspace=0, hspace=0.164)
    fig, axis = plt.subplots(4, 6, gridspec_kw=layout)

    def draw(sample, slot):
        """Compute the CAM for one sample and render it at grid position ``slot``."""
        img, label = sample
        cam, pred = CAM(TrainedModel, img, device)
        CAMVisualizer(img, cam, pred, label, mean, std, axis, slot)

    with torch.no_grad():
        if len(DataSet) > 12:
            class_counts = np.unique(DataSet.targets, return_counts=True)[1]
            # NOTE(review): the ranges below assume samples are stored grouped
            # by class (first class first) — confirm against the dataset.
            NOTNVB_Indexs = np.asarray(range(class_counts[0]))
            NVB_Indexs = np.asarray(range(class_counts[0], class_counts[0] + class_counts[1]))

            chosen = [*NOTNVB_Indexs[:6], *NVB_Indexs[:6]]
            for slot, index in enumerate(chosen):
                draw(DataSet[index], slot)
        else:
            for slot, sample in enumerate(tqdm(DataSet)):
                draw(sample, slot)

    fig.suptitle(title)
    
def Calc_Eval(TrainModel:M.RARP_NVB_Model):
    """Score ``TrainModel`` sample-by-sample on the module-level ``testDataset``.

    Prints accuracy, precision, recall, F1 and AUROC (binary task) and shows
    the confusion matrix as a seaborn heatmap titled with the current fold.

    Relies on the module-level ``device``, ``testDataset`` and ``args``.

    Args:
        TrainModel: Model to evaluate; moved to ``device`` and set to eval mode.
    """
    TrainModel.to(device)
    TrainModel.eval()

    sample_probs = []
    sample_labels = []

    with torch.no_grad():
        for sample, target in tqdm(testDataset):
            batch = sample.to(device).float().unsqueeze(0)
            # Only the first element of the model output is scored.
            first_output = TrainModel(batch)[0]
            sample_probs.append(torch.sigmoid(first_output.cpu()))
            sample_labels.append(target)

    Predictions = torch.cat(sample_probs)
    Labels = torch.tensor(sample_labels).int()

    print(Predictions, Labels)

    def score(metric_type):
        """Apply a binary torchmetrics metric (on CPU) to all predictions."""
        return metric_type('binary')(Predictions, Labels)

    acc = score(torchmetrics.Accuracy)
    precision = score(torchmetrics.Precision)
    recall = score(torchmetrics.Recall)
    cm = score(torchmetrics.ConfusionMatrix)
    auc = score(torchmetrics.AUROC)
    f1Score = score(torchmetrics.F1Score)

    report = (
        ("Val Accuracy", acc),
        ("Val Precision", precision),
        ("Val Recall", recall),
        ("F1 Score", f1Score),
        ("AUROC", auc),
    )
    for caption, value in report:
        print(f"{caption}: {value:.4f}")
    print(testDataset.classes)

    class_names = testDataset.classes
    ax = sn.heatmap(cm, cmap="Greens", cbar=False, annot=True, annot_kws={"size": 18}, fmt='g', xticklabels=class_names, yticklabels=class_names)
    ax.set_title(f"NVB Classifier Split #{args.Fold+1}")
    ax.set_xlabel('Predict')
    ax.set_ylabel('True')
    plt.show()

def getModel (
    Model_ID:int = 0, 
    InitWeight=torch.tensor([1,1]), 
    TypeLoss:M.TypeLossFunction = M.TypeLossFunction.CrossEntropy,
    Ckpt_File:Path = None, 
    OptConfig:dict = {},
    inputNeurons:int = 4,
    mean:float = None,
    std:float = None
):
    Model = None
    ModelCAM = None
    match Model_ID:
        case 0:
            Model = M.RARP_NVB_ResNet50(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_ResNet50.load_from_checkpoint(ckpFile)
            ModelCAM = None if Ckpt_File is None else M.RARP_NVB_ResNet50_CAM.load_from_checkpoint(ckpFile, strict=False)
        case 1:
            Model = M.RARP_NVB_ResNet18(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_ResNet18.load_from_checkpoint(ckpFile)
            ModelCAM = None if Ckpt_File is None else M.RARP_NVB_ResNet18_CAM.load_from_checkpoint(ckpFile, strict=False)
        case 2:
            Model = M.RARP_NVB_MobileNetV2(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_MobileNetV2.load_from_checkpoint(ckpFile)
            ModelCAM = None if Ckpt_File is None else M.RARP_NVB_MobileNetV2_CAM.load_from_checkpoint(ckpFile, strict=False)
        case 3:
            Model = M.RARP_NVB_EfficientNetV2(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_EfficientNetV2.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 4:
            Model = M.RARP_NVB_Vit_b_16(InitWeight, TypeLoss) if Ckpt_File is None else M.RARP_NVB_Vit_b_16.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 5:
            Model = M.RARP_NVB_DenseNet169(InitWeight, TypeLoss) if Ckpt_File is None else M.RARP_NVB_DenseNet169.load_from_checkpoint(ckpFile)
            ModelCAM = None 
        case 6: 
            Model = M.RARP_NVB_ResNet50_V1(InitWeight, TypeLoss, schedulerLR=args.DyLr,  InputNeurons=inputNeurons) if Ckpt_File is None else M.RARP_NVB_ResNet50_V1.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 7:
            with open(f"train-EFold{args.Fold}.yaml") as file:
                configFile = yaml.load(file, Loader=yaml.FullLoader)
            
            Models = []
            
            for models in configFile["models"]:
                match models:
                    case "ResNet50":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_ResNet50.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_ResNet50(InitWeight, TypeLoss))
                    case "ResNet18":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_ResNet18.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_ResNet18(InitWeight, TypeLoss))
                    case "MobileNetV2":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_MobileNetV2.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_MobileNetV2(InitWeight, TypeLoss))
                    case "EfficientNetV2":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_EfficientNetV2.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_EfficientNetV2(InitWeight, TypeLoss))
                    case "DenseNet169":
                        if configFile["models"][models] is not None:
                            for pathckpt in configFile["models"][models]:
                                #Models.append(M.RARP_NVB_DenseNet169.load_from_checkpoint(Path(pathckpt), strict=False))
                                Models.append(M.RARP_NVB_DenseNet169(InitWeight, TypeLoss))
                    case _:
                        pass
                    
            print (f"{len(Models)} models Loaded")
            
            Model = M.RARP_Ensemble(Models, InitWeight, TypeLoss, lr=1e-3)
            ModelCAM = None
        case 8:
            Model = M.RARP_NVB_DaVit(InitWeight, TypeLoss) if Ckpt_File is None else M.RARP_NVB_DaVit.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 9:
            Model = M.RARP_NVB_ResNet50_Deep(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_ResNet50_Deep.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 10:
            Model = M.RARP_NVB_EfficientNetV2_Deep(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_EfficientNetV2_Deep.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 11: 
            Model = M.RARP_NVB_ResNet50_V2(InitWeight, TypeLoss, schedulerLR=args.DyLr,  InputNeurons=inputNeurons) if Ckpt_File is None else M.RARP_NVB_ResNet50_V2.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 12: 
            Model = M.RARP_NVB_ResNet50_V3(InitWeight, TypeLoss, schedulerLR=args.DyLr,  InputNeurons=inputNeurons) if Ckpt_File is None else M.RARP_NVB_ResNet50_V3.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 13: 
            Model = M.RARP_NVB_VAN(InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_VAN.load_from_checkpoint(ckpFile)
            ModelCAM = None if Ckpt_File is None else M.RARP_NVB_VAN_CAM.load_from_checkpoint(ckpFile, strict=False)
        case 14:
            TestModel = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
            TestModel.fc = torch.nn.Linear(in_features=TestModel.fc.in_features, out_features=4)
            Model = M.RARP_NVB_MultiClassModel(
                None, 
                Model=TestModel, 
                Num_Classes=4, 
            ) if Ckpt_File is None else M.RARP_NVB_MultiClassModel.load_from_checkpoint(ckpFile,Model=TestModel)
            ModelCAM = None
        case 15:
            #if OptConfig.get("lr") is None:
            #    OptConfig = dict(
            #        lr = None, #0.00015278, #1.53E-4,
            #        L1 = None, #0.0000020505, #2.05E-6,
            #        Alpha = 0.45,
            #        Gamma = 0.5,
            #        Thao = 5#2
            #    )
            Model = M.RARP_NVB_ResNet50_VAN(
                "./log_ResNet50_X10/lightning_logs/version_8/checkpoints/RARP-epoch=5.ckpt", 
                #"./log_X10/lightning_logs/version_0/checkpoints/RARP-epoch=39.ckpt", 
                0.5,
                InitWeight, 
                TypeLoss, 
                schedulerLR=args.DyLr,
                PseudoLables=False,
                HParameter=OptConfig
            ) if Ckpt_File is None else M.RARP_NVB_ResNet50_VAN.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 16:
            Model = M.RARP_NVB_SSL_RestNet50_Deep("./log_ResNet50_Deep_X10/lightning_logs/version_3/checkpoints/RARP-epoch=9.ckpt", 0.5, InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_SSL_RestNet50_Deep.load_from_checkpoint(ckpFile)
            #Model = M.RARP_NVB_SSL_RestNet50_Deep("./log_ResNet50DeepSSL_X10/lightning_logs/version_8/checkpoints/RARP-epoch=42.ckpt", 0.5, InitWeight, TypeLoss, schedulerLR=args.DyLr) if Ckpt_File is None else M.RARP_NVB_SSL_RestNet50_Deep.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 17:
            Model = M.RARP_NVB_DINO_RestNet50_Deep(
                "./log_ResNet50_Deep_X10/lightning_logs/version_3/checkpoints/RARP-epoch=9.ckpt", 
                threshold=0.5, 
                TypeLoss=TypeLoss,
                L1=1.31E-04,
            ) if Ckpt_File is None else M.RARP_NVB_DINO_RestNet50_Deep.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 18:
            Model = M.RARP_NVB_DINO_VAN(
                "./log_ResNet50_Deep_X10/lightning_logs/version_3/checkpoints/RARP-epoch=9.ckpt", 
                threshold=0.5, 
                TypeLoss=TypeLoss,
                #L1=1.31E-04
            ) if Ckpt_File is None else M.RARP_NVB_DINO_VAN.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 19:
            Model = M.RARP_NVB_RN50_VAN_V2(
                #"./log_X10/lightning_logs/version_0/checkpoints/RARP-epoch=39.ckpt",
                "./log_ResNet50_X10/lightning_logs/version_8/checkpoints/RARP-epoch=5.ckpt", 
                0.5,
                InitWeight, 
                TypeLoss, 
                schedulerLR=args.DyLr,
                PseudoLables=False,
                HParameter=OptConfig, std=std, mean=mean
            ) if Ckpt_File is None else M.RARP_NVB_RN50_VAN_V2.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 20:
            Model = M.RARP_NVB_DINO_MultiTask_Unet(
                TypeLoss,
                std=std,
                mean=mean,
                L1= 1.31E-04,
                L2= 0,
                SoftAdptAlgo=0,
            ) if Ckpt_File is None else M.RARP_NVB_DINO_MultiTask_Unet.load_from_checkpoint(ckpFile)
            ModelCAM = None
        case 21:
            Model = None if Ckpt_File is None else M.RARP_Hybrid_TS_LR(ckpFile, masked=True)
            ModelCAM = None
        case _:
            raise Exception("Model Not yet Implemented")
        
    return (Model, ModelCAM)

def ViewImg(dataset, std, mean, max_index: int = 44):
    """Show four randomly chosen images from ``dataset`` in a 2x2 grid.

    Each dataset item is expected to be ``(img, label)`` where ``img`` is
    itself a pair whose first element is the image tensor to display. The
    image is de-normalised with ``std``/``mean``, scaled by 1/255, clipped to
    [0, 1] and its channel order reversed before display.

    Args:
        dataset: Indexable dataset as described above.
        std: Standard deviation used to de-normalise the image.
        mean: Mean used to de-normalise the image.
        max_index: Exclusive upper bound for the sampled indices. It is capped
            at ``len(dataset)`` when the dataset reports a length, so small
            datasets no longer raise ``IndexError``.
    """
    try:
        upper = min(max_index, len(dataset))
    except TypeError:
        # Dataset does not define a length; fall back to the given bound.
        upper = max_index

    _, axis = plt.subplots(2, 2, figsize=(9, 9))
    for i in range(2):
        for j in range(2):
            random_index = np.random.randint(0, upper)
            img, label = dataset[random_index]
            # Only the first element of the image pair is displayed.
            img, _ = img
            img = img.numpy().transpose((1, 2, 0))
            img = np.clip((std * img + mean) / 255, 0, 1)
            img = img[...,::-1].copy()
            
            axis[i][j].imshow(img)
            axis[i][j].set_title(f"Label:{label}")
            
def ViewImgDINO(dataset, std, mean):
    """Plot the crops of four random samples, one sample per row of a 4x7 grid.

    Each item of ``dataset`` is expected to unpack as ``(crops, label)`` where
    ``crops`` is an iterable of image tensors; crop ``j`` goes into column ``j``.

    Args:
        dataset: indexable dataset exposing a ``targets`` sequence.
        std: per-channel standard deviation used to undo the normalisation.
        mean: per-channel mean used to undo the normalisation.
    """
    _, axis = plt.subplots(4, 7, figsize=(25, 25))
    for rowAxes in axis:
        random_index = np.random.randint(0, len(dataset.targets))
        imgCrops, label = dataset[random_index]
        for col, crop in enumerate(imgCrops):
            # (C, H, W) tensor -> (H, W, C) array, de-normalised and scaled to [0, 1].
            picture = crop.numpy().transpose((1, 2, 0))
            picture = np.clip((std * picture + mean) / 255, 0, 1)
            # Reverse the channel axis before display.
            picture = picture[..., ::-1].copy()

            panel = rowAxes[col]
            panel.imshow(picture)
            panel.set_title(f"Label:{label}")
            panel.set_axis_off()
    
    
if __name__ == "__main__":

    def _str2bool(value):
        """Convert a command-line token into a real boolean.

        ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
        every non-empty token (including ``"False"``) becomes ``True``. This
        converter maps the usual spellings explicitly. The empty string still
        maps to ``False``, as it did under ``bool("")``.

        Raises:
            argparse.ArgumentTypeError: if the token is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("true", "t", "yes", "y", "1", "on"):
            return True
        if text in ("false", "f", "no", "n", "0", "off", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    # Run mode; the script below branches on "train", "tune", "eval_all" and
    # treats anything else as single-checkpoint evaluation.
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("--Workers", type=int, default=0)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("--Model", type=int, default=0, help="0 = ResNet18, 1 = ResNet50")
    # Version / epoch / step are used to build checkpoint paths in the eval phases.
    parser.add_argument("-lv", "--Log_version", type=int)
    parser.add_argument("-le", "--Log_epoch", type=int)
    parser.add_argument("-ls", "--Log_step", type=int)
    parser.add_argument("--Remove_Blackbar", type=_str2bool, default=True)
    parser.add_argument("--BGR2RGB", type=_str2bool, default=False)
    parser.add_argument("--CAM", type=_str2bool, default=False)
    # Integer id selecting one of the dataset configurations below.
    parser.add_argument("-roi", "--Use_ROI_Dataset", type=int, default=0)
    parser.add_argument("-s", "--imgSlice_pct", type=float, default=None)
    parser.add_argument("-ns", "--Num_Slices", type=int, default=4)
    parser.add_argument("-wl", "--Wloss", type=_str2bool, default=False)
    parser.add_argument("--sClass", type=int, default=None)
    parser.add_argument("-tl", "--TypeLoss", type=int, default=0)
    parser.add_argument("-cs", "--ColorSpace", type=int, default=None)
    parser.add_argument("--JIndex", type=_str2bool, default=False)
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-lc", "--LoadChkpt", type=str, default=None)
    parser.add_argument("--AddTestSet", type=str, default=None)
    parser.add_argument("--Metadata", type=str, default=None)
    parser.add_argument("--DyLr", type=_str2bool, default=False)
    parser.add_argument("-lr", type=float, default=1e-4)
    parser.add_argument("--ExtraNeurons", type=int, default=4)
    parser.add_argument("--ExtraLabels", type=str, default=None)
    parser.add_argument("--Roi_Mask_Model", type=str, default=None)

    args = parser.parse_args()

    # CAM visualisation is only wired into the evaluation branches.
    if args.CAM and args.Phase == "train":
        raise Exception("Clases Activation Clases only in eval o eval_all")
    
    match args.Use_ROI_Dataset:
        case 0:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_main",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 1:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Crop",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSetCrop",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 256  
        case 2:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Crop1",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSetCrop1",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 256 
        case 3:
            YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_main",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_YOLO",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace,
                ROI_Yolo=YoloModel
            )
            cropSize = 256
        case 4:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_big",
                FoldSeed=505,
                createFile=True,
                SavePath="./DatasetBig",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 6:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Full",
                FoldSeed=505,
                createFile=True,
                SavePath="./DatasetFull",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 7:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_smallBalaced",
                FoldSeed=505,
                createFile=True,
                SavePath="./DatasetSmallBalanced",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 8:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_big_2",
                FoldSeed=505,
                createFile=True,
                SavePath="./DatasetBig2",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 9:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Ando",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSetAndo",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 10:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_smallBalacedCrop",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_SB_Crop",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 256
        case 11:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_AndoCrop",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_SB_Ando_Crop",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 256
        case 12:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Ando_All_Crop",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_Ando_Crop",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 256
        case 13:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Ando_All",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_Ando_All",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 14:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Ando_AllNewLabels",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_New_labels",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 15:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 16:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_crop",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_Crop",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 256
        case 17:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Ando_All_no20",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_Ando_All_20",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 18:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Ando_All_no20Crop",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_AndoAll20_crop",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 19:
            ROI_model = M.RARP_NVB_ROI_Mask_Unet.load_from_checkpoint(Path(args.Roi_Mask_Model))
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Ando_All_no20",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_AndoAll20_mask",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace,
                ROI_Mask=ROI_model
            )
            cropSize = 720
        case 21:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Kpts",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_Kpts_FullSize",
                Fold=5,
                removeBlackBar=False,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace,
                copyKpoints=True
            )
            cropSize = 720
        case 5:
            YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_big",
                FoldSeed=505,
                createFile=True,
                SavePath="./DatasetBig_YOLO",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace,
                ROI_Yolo=YoloModel
            )
            cropSize = 256
        case 20:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_big_Multiclass",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_Multiclass",
                Fold=5,
                Num_Labels=4,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
            
    # Hard-coded per-channel mean / std; every Normalize(...) built below reads them.
    # NOTE(review): channel order and value scale are not visible here -- confirm against the loader.
    Dataset.mean = [30.38144216, 42.03988769, 97.8896116]
    Dataset.std = [40.63141752, 44.26910074, 50.29294373]

    Dataset.CreateFolds()

    # Seed after the folds exist, before any transform / loader / model is built.
    setup_seed(2023)
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    batchSize = 8  # other values noted by the author: 17, 8, 32
    numWorkers = args.Workers
    InitResize = (256, 256)
    ImgResize = (224, 224)
    # Keeps the ten checkpoints with the highest validation accuracy.
    checkPtCallback = callbk.ModelCheckpoint(
        monitor='val_acc',
        mode='max',
        save_top_k=10,
        filename="RARP-{epoch}",
    )
    # Keeps the two checkpoints with the lowest validation loss.
    ckpLossBest = callbk.ModelCheckpoint(
        monitor="val_loss",
        mode='min',
        save_top_k=2,
        filename="RARP-{epoch}-{val_loss:.2f}",
    )
    
    
    
    _BICUBIC = transforms.InterpolationMode.BICUBIC

    # Default training augmentation: resize, centre crop, mild affine jitter, flip, normalise.
    traintransform = torch.nn.Sequential(
        transforms.Resize(InitResize, interpolation=_BICUBIC, antialias=True),
        # Changed here on 2024/05/10; the commented-out alternative was RandomCrop(ImgResize).
        transforms.CenterCrop(224),
        transforms.RandomAffine(degrees=(-5, 5), scale=(0.9, 1.1), fill=5),
        transforms.RandomHorizontalFlip(1.0),
        transforms.Normalize(Dataset.mean, Dataset.std),
    ).to(device)

    # Stronger second view, passed as `transformT2` to the double-transform dataset.
    traintransformT2 = torch.nn.Sequential(
        transforms.Resize(InitResize, interpolation=_BICUBIC, antialias=True),
        transforms.RandomCrop(224),
        transforms.RandomErasing(0.8, value="random"),
        transforms.RandomAffine(degrees=(-45, 45), scale=(0.8, 1.2), fill=5),
        transforms.GaussianBlur(5),
        transforms.Normalize(Dataset.mean, Dataset.std),
    ).to(device)

    # Plain resize + normalise, used for validation/test on dataset ids 19 and 21.
    Roi_mask_transform = torch.nn.Sequential(
        transforms.Resize((224, 224), interpolation=_BICUBIC, antialias=True),
        transforms.Normalize(Dataset.mean, Dataset.std),
    ).to(device)

    def _evalPipeline():
        """Build a fresh resize(256) -> centre-crop(224) -> normalise pipeline on `device`."""
        return torch.nn.Sequential(
            transforms.Resize(256, interpolation=_BICUBIC, antialias=True),
            transforms.CenterCrop(224),
            transforms.Normalize(Dataset.mean, Dataset.std),
        ).to(device)

    if args.Use_ROI_Dataset in [19, 21]:
        valtransform = Roi_mask_transform
        testtransform = Roi_mask_transform
    else:
        # Validation and test get separate module instances.
        valtransform = _evalPipeline()
        testtransform = _evalPipeline()

    # Multi-crop augmentation; model 20 additionally receives the test transform as `Tranform_0`.
    TrainDINOTransforms = Loaders.RARP_DINO_Augmentation(
        GloblaCropsScale=(0.4, 1),
        LocalCropsScale=(0.05, 0.4),
        NumLocalCrops=4,
        Size=224,
        device=device,
        mean=Dataset.mean,
        std=Dataset.std,
        Tranform_0=testtransform if args.Model == 20 else None,
    )
    
    # Directory of the selected fold, two levels above the dataset CSV file.
    rootFile = Dataset.CVS_File.parent.parent / f"fold_{args.Fold}"

    # Models 17, 18 and 20 train on the multi-crop augmentation instead.
    if args.Model in (17, 18, 20):
        traintransform = TrainDINOTransforms

    # Optional additional test set; only built when no metadata file is used.
    Add_Test_DataLoader = None
    if args.AddTestSet is not None and args.Metadata is None:
        _extraTestDir = Path(args.AddTestSet) / f"fold_{args.Fold}" / "test"
        Add_TestDataset = torchvision.datasets.DatasetFolder(
            str(_extraTestDir),
            loader=defs.load_file_tensor,
            extensions="npy",
            transform=testtransform,
        )
        Add_Test_DataLoader = DataLoader(
            Add_TestDataset,
            batch_size=batchSize,
            num_workers=numWorkers,
            shuffle=False,
            pin_memory=True,
        )
    
    # (split directory, transform) pairs, always in train / val / test order.
    _splitTransforms = (
        ("train", traintransform),
        ("val", valtransform),
        ("test", testtransform),
    )
    _statColumns = [f"{stat}_{k}" for stat in ("mean", "std") for k in (1, 2, 3)]

    if args.Metadata is None:
        if args.Model in (15, 16, 19):
            # Double-transform datasets; model 19 also gets the test transform as `passOriginal`.
            _original = testtransform if args.Model == 19 else None
            trainDataset = Loaders.RARP_DatasetFolder_DobleTransform(
                str(rootFile / "train"),
                loader=defs.load_file_tensor,
                extensions="npy",
                transformT1=traintransform,
                transformT2=traintransformT2,
                passOriginal=_original,
            )
            # Validation and test only receive the first transform.
            valDataset, testDataset = [
                Loaders.RARP_DatasetFolder_DobleTransform(
                    str(rootFile / split),
                    loader=defs.load_file_tensor,
                    extensions="npy",
                    transformT1=tf,
                    passOriginal=_original,
                )
                for split, tf in _splitTransforms[1:]
            ]
        elif args.Use_ROI_Dataset == 21:
            # Key-point dataset: ROI extractor reading key points next to the fold directory.
            trainDataset, valDataset, testDataset = [
                Loaders.RARP_DatasetFolder_ROIExtractor_OnlyROI(
                    str(rootFile / split),
                    loader=defs.load_file,
                    extensions="npy",
                    transform=tf,
                    root_kpts=rootFile / "../../DataSet_Kpts",
                )
                for split, tf in _splitTransforms
            ]
        else:
            # Default: plain folder datasets of .npy tensors.
            trainDataset, valDataset, testDataset = [
                torchvision.datasets.DatasetFolder(
                    str(rootFile / split),
                    loader=defs.load_file_tensor,
                    extensions="npy",
                    transform=tf,
                )
                for split, tf in _splitTransforms
            ]
    else:
        # Join the spreadsheet metadata to the dataset CSV through the image name.
        DumpCSV = pd.read_csv(Dataset.CVS_File)
        Extradata = pd.read_excel(Path(args.Metadata))

        Extradata["name"] = Extradata["列1"].astype(str) + ".tiff"
        Extradata = Extradata.drop(columns=["列1"])

        DumpCSV["raw_name"] = "Img0-" + DumpCSV["id"].astype(str) + ".npy"
        DumpCSV = DumpCSV.drop(columns=["id", "path"] + _statColumns)

        NewData = Extradata.merge(DumpCSV, on="name")

        trainDataset, valDataset, testDataset = [
            Loaders.RARP_DatasetFolder_ExtraData(
                str(rootFile / split),
                loader=defs.load_file_tensor,
                Extra_Data=NewData,
                Extra_Data_leg=args.ExtraNeurons,
                extensions="npy",
                transform=tf,
            )
            for split, tf in _splitTransforms
        ]

    if args.ExtraLabels is not None:
        # Extra labels replace whichever datasets were built above.
        DumpCSV = pd.read_csv(Dataset.CVS_File)
        Extradata = pd.read_excel(Path(args.ExtraLabels))

        DumpCSV["raw_name"] = "Img0-" + DumpCSV["id"].astype(str) + ".npy"
        DumpCSV = DumpCSV.drop(columns=_statColumns + ["path", "class", "label"])

        # Right join: every row of the dataset CSV is kept.
        outPut = Extradata.merge(DumpCSV, on="name", how="right")

        trainDataset, valDataset, testDataset = [
            Loaders.RARP_DatasetFolder_ExtraLabel(
                str(rootFile / split),
                loader=defs.load_file_tensor,
                Extra_Data=outPut,
                extensions="npy",
                transform=tf,
            )
            for split, tf in _splitTransforms
        ]
    
    # Settings shared by the three main loaders; workers persist only when there are any.
    _loaderArgs = dict(
        batch_size=batchSize,
        num_workers=numWorkers,
        pin_memory=True,
        persistent_workers=numWorkers > 0,
    )
    # Only the training loader shuffles and drops the last incomplete batch.
    Train_DataLoader = DataLoader(trainDataset, shuffle=True, drop_last=True, **_loaderArgs)
    Val_DataLoader = DataLoader(valDataset, shuffle=False, **_loaderArgs)
    Test_DataLoader = DataLoader(testDataset, shuffle=False, **_loaderArgs)

    if args.CAM:
        # CAM views the test images resized straight to 224x224 (no centre crop).
        _camTransform = torch.nn.Sequential(
            transforms.Resize((224, 224), antialias=True),
            transforms.Normalize(Dataset.mean, Dataset.std),
        ).to(device)
        testCAMDataset = torchvision.datasets.DatasetFolder(
            str(rootFile / "test"),
            loader=defs.load_file_tensor,
            extensions="npy",
            transform=_camTransform,
        )

        TestCAM_DataLoader = DataLoader(
            testCAMDataset,
            batch_size=batchSize,
            num_workers=numWorkers,
            shuffle=False,
            pin_memory=True,
        )

    print(f"Currtent Fold Splits {Dataset.Splits[args.Fold]}")
    print("Unique Values in sets")
    # One (labels, counts) pair per split, in train / val / test order.
    info = tuple(
        np.unique(split.targets, return_counts=True)
        for split in (trainDataset, valDataset, testDataset)
    )
    print(info)

    # Sum the first and second count of every split.
    # NOTE(review): this assumes each split contains at least two distinct labels.
    neg = sum(counts[0] for _, counts in info)
    pos = sum(counts[1] for _, counts in info)

    total = neg + pos
    factor = 2 if args.TypeLoss == 1 else 1
    # Inverse-frequency class weights, only when --Wloss is set.
    InitWeight = None
    if args.Wloss:
        InitWeight = torch.tensor([total / (neg * factor), total / (pos * factor)]).to(device)
        print(f"Weights {InitWeight}")
    TypeLoss = M.TypeLossFunction(args.TypeLoss)
    Model, ModelCAM = getModel(
        args.Model,
        InitWeight,
        TypeLoss,
        mean=Dataset.mean,
        std=Dataset.std,
    )
    NameModel = type(Model).__name__
    print(f"Model Used: {NameModel}")
    LogFileName = f"{args.Log_Name}"

    # 150 epochs unless --maxEpochs overrides it.
    MaxEpochs = 150 if args.maxEpochs is None else args.maxEpochs
    #warnings.simplefilter("ignore")

    match(args.Phase):
        case "train":
            trainer = L.Trainer(
                deterministic=True,
                #gradient_clip_val=2.0,
                accelerator='gpu', 
                devices=1, 
                logger=CSVLogger(save_dir=f"./{LogFileName}", name="Tune") if args.Phase == "tune" else TensorBoardLogger(save_dir=f"./{LogFileName}"),
                log_every_n_steps=5,  
                #callbacks=[checkPtCallback, StepDropout(5,  base_drop_rate=0.2, gamma=0.05, ascending=True)],#if args.Model == 4 else checkPtCallback, 
                callbacks=[checkPtCallback, callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)],
                max_epochs=MaxEpochs,
            )
            print("Train Phase")
            trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader, ckpt_path=args.LoadChkpt)
            trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")
        case "tune":
            print("Tuning")
            
            pruner = optuna.pruners.SuccessiveHalvingPruner()#MedianPruner()
            sampler = optuna.samplers.GPSampler(seed=2023) if args.Log_step == 1 else optuna.samplers.TPESampler(seed=2023)
            study = optuna.create_study(direction="maximize", pruner=pruner, sampler=sampler)
            study.optimize(objective, n_trials=100)
            
            print("Number of finished trials: {}".format(len(study.trials)))
            print("Best trial:")
            trial = study.best_trial
            print(f"   Value: {trial.value}")
            print(f"   Paramas: ")
            for key, val in trial.params.items():
                print(f"      {key}: {val}")
        case "eval_all":
            print("Evaluation Phase")
            rows = []
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in pathCkptFile.glob("*.ckpt"):
                print(ckpFile.name)
                Model, ModelCAM = getModel(args.Model, InitWeight, TypeLoss, ckpFile)
                
                #ViewImgDINO(trainDataset, Dataset.std, Dataset.mean)
                
                if isinstance(Model, (M.RARP_NVB_MultiClassModel, M.RARP_NVB_DINO_MultiTask_v2, M.RARP_NVB_DINO_MultiTask_MultiLabel, M.RARP_Hybrid_TS_LR)):
                    numClass = 4 if isinstance(Model, M.RARP_NVB_DINO_MultiTask_v2) else 2
                    numLabel = 2 if isinstance(Model, (M.RARP_NVB_DINO_MultiTask_MultiLabel, M.RARP_Hybrid_TS_LR)) else None
                    temp = Calc_EvalMulticlass_table(Model, Test_DataLoader, False, ckpFile.name, NumClasses=numClass, Num_Label=numLabel)
                else:
                    temp = Calc_Eval_table(
                        Model, 
                        Test_DataLoader, 
                        args.JIndex, 
                        ckpFile.name, 
                        Add_TestDataset=Add_Test_DataLoader,  
                        extraData=(args.Metadata is not None), 
                        PseudoLabel=False,
                        dataSetInfo=Dataset
                    )
                rows += temp
                if args.CAM and ModelCAM is not None:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        print("CAM")
                        ShowCAM(ModelCAM, testCAMDataset, Dataset.mean, Dataset.std, ckpFile.name)
            
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])        
            df.style.highlight_max(color="red", axis=0)
            print(df)
            plt.show()
        case _:
            print("Evaluation Phase")
            trainLog = [args.Log_version, args.Log_epoch, args.Log_step] 
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{trainLog[0]}/checkpoints/epoch={trainLog[1]}-step={trainLog[2]}.ckpt")
            Calc_Eval(Model.load_from_checkpoint(pathCkptFile))
            if args.CAM:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    ShowCAM(ModelCAM.load_from_checkpoint(pathCkptFile, strict=False), testCAMDataset, Dataset.mean, Dataset.std, pathCkptFile.name)
                    plt.show()