import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models_video as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score


# Speed/precision trade-offs for Ampere+ GPUs: allow TF32 matmuls and relax
# float32 matmul precision. cuDNN determinism is requested globally here and
# re-asserted in setup_seed().
torch.backends.cuda.matmul.allow_tf32 = True  
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.deterministic = True



def setup_seed(seed):
    """Seed every RNG in play (Lightning, numpy, torch CPU + all GPUs) and
    request deterministic cuDNN kernels so runs are reproducible."""
    seed_everything(seed, workers=True)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

def Calc_Eval_table(
    TrainModel,
    TestDataLoadre:DataLoader, 
    Youden=False, 
    modelName="", 
):
    """Evaluate TrainModel on TestDataLoadre and return a table of metric rows.

    Each row is a list of strings:
    [threshold, acc, precision, recall, f1, auroc, specificity, checkpoint-name].
    Row 0 always uses the default 0.5 probability threshold (empty name field).
    When `Youden` is True, two extra rows are appended: one at the ROC
    Youden-index threshold and one at a fallback threshold (see note below).

    NOTE(review): relies on the module-level `device` defined in the
    __main__ block — calling this from another module without `device`
    in scope raises NameError.
    """
    TrainModel.to(device)
    TrainModel.eval()

    Predictions = []
    Labels = []
     
    with torch.no_grad():
        for data, label, mask, key_frame in tqdm(iter(TestDataLoadre)):
                       
            data = data.to(device, dtype=torch.float32)
            key_frame = key_frame.to(device, dtype=torch.float32)
            mask = mask.to(device)
            label = label.to(device)
            
            #pred, *_ = TrainModel(data)
            # Forward returns (logits, extra); the extra output is unused here.
            pred, _ = TrainModel(data, key_frame, mask)
            pred = pred.flatten()
                
            # Store probabilities so thresholded metrics below are meaningful.
            Predictions.append(torch.sigmoid(pred))
            Labels.append(label)
      
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels).int()

    #print(Predictions, Labels)

    # Metrics at the default 0.5 probability threshold.
    acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
        
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
    ]

    if Youden:
        # Two extra operating points: i == 0 uses the ROC Youden-index
        # threshold (argmax of tpr - fpr); i == 1 uses th2 computed below.
        for i in range(2):
            aucCurve = torchmetrics.ROC("binary").to(device)
            fpr, tpr, thhols = aucCurve(Predictions, Labels)
            index = torch.argmax(tpr - fpr)
            # NOTE(review): th2 is Youden's J statistic (recall + specificity - 1)
            # evaluated at the 0.5 operating point, reused directly as a
            # probability threshold — unusual; confirm this is intentional.
            th2 = (recall + specificty - 1).item()
            th2 = 0.5 if th2 <= 0 else th2
            th1 = thhols[index].item() if i == 0 else th2
            accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
            precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
            recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
            specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
            f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
            #cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
            #cm2.update(Predictions, Labels)
            #_, ax = cm2.plot()
            #ax.set_title(f"NVB Classifier (th={th1:.4f})")
            # AUROC is threshold-free, so the 0.5-row value is reused here.
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
        

    return table
    
def ensure_list(x):
    """Coerce a tensor / list / tuple / scalar into a plain python list."""
    if torch.is_tensor(x):
        return x.detach().cpu().tolist()
    return list(x) if isinstance(x, (list, tuple)) else [x]

def compute_uniform_starts(T: int, L: int, W: int):
    """Return W window-start indices spread evenly over [0, T - L].

    T: total sequence length, L: window length, W: number of windows.
    Starts are rounded to ints and clamped so every window fits in T.
    """
    step = (T - L) / max(W - 1, 1)
    return [min(int(round(i * step)), T - L) for i in range(W)]

def get_topk_nearest_to_centroids(df: pd.DataFrame, X_space: np.ndarray, kmeans, topk: int = 20):
    """
    For every KMeans cluster, select the `topk` rows of `df` whose embeddings
    lie closest (Euclidean) to that cluster's centroid.

    df: rows correspond 1:1 with X_space (same ordering), must have "cluster"
    X_space: embeddings used for KMeans
    kmeans: fitted sklearn KMeans
    Returns: df_top with extra columns: dist_to_centroid, rank_in_cluster
    """
    assert len(df) == X_space.shape[0], "df and X_space must have same #rows"
    cluster_of = df["cluster"].to_numpy()

    pieces = []
    for k, center in enumerate(kmeans.cluster_centers_):  # centers: [K, D]
        members = np.where(cluster_of == k)[0]
        if members.size == 0:
            continue  # empty cluster: nothing to rank
        dists = np.linalg.norm(X_space[members] - center, axis=1)
        keep = np.argsort(dists)[:topk]

        block = df.iloc[members[keep]].copy()
        block["dist_to_centroid"] = dists[keep]
        block["rank_in_cluster"] = np.arange(1, len(block) + 1)
        pieces.append(block)

    return pd.concat(pieces, axis=0).sort_values(["cluster", "rank_in_cluster"])

def save_window_montage(clip_uint8, out_path, title="", n_frames=6):
    """
    clip_uint8: [L, H, W, 3] uint8 (preferred)
    Saves a single-row montage of n_frames frames evenly sampled from the clip.
    """
    total = clip_uint8.shape[0]
    picks = np.linspace(0, total - 1, n_frames).round().astype(int)

    fig, axes = plt.subplots(1, n_frames, figsize=(2.2*n_frames, 2.2))
    if n_frames == 1:
        # subplots() returns a bare Axes for a single panel; normalize to list
        axes = [axes]

    for ax, t in zip(axes, picks):
        frame = clip_uint8[t]
        # safety: accept CHW frames by converting them to HWC
        if frame.ndim == 3 and frame.shape[0] == 3 and frame.shape[-1] != 3:
            frame = np.transpose(frame, (1, 2, 0))
        ax.imshow(frame)
        ax.set_title(f"t={t}")
        ax.axis("off")

    fig.suptitle(title, fontsize=10)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)

def load_window_rgb(arrays, case_to_vidx, case_id, start_sec, L):
    """Slice an L-frame clip ([L, H, W, 3] uint8) for `case_id` out of the
    cached per-video arrays (memmap/ndarray), starting at `start_sec`."""
    video = arrays[case_to_vidx[case_id]]
    return video[start_sec:start_sec + L]

def export_cluster_examples(df_top,
                            arrays,
                            case_to_vidx,
                            L_win: int,
                            out_dir = "./cluster_examples",
                            n_frames_per_montage: int = 6):
    """Save one montage PNG per row of df_top, grouped into per-cluster folders.

    df_top must contain columns: case_id, cluster, win_start_sec, win_idx,
    dist_to_centroid, rank_in_cluster (as produced by
    get_topk_nearest_to_centroids).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for row in tqdm(df_top.itertuples(index=False), desc="Centroids Montage"):
        case_id = row.case_id
        cluster_k = int(row.cluster)
        start_sec = int(row.win_start_sec)
        widx = int(row.win_idx)
        dist = float(row.dist_to_centroid)
        rank = int(row.rank_in_cluster)

        cluster_dir = out_dir / f"cluster_{cluster_k:02d}"
        cluster_dir.mkdir(exist_ok=True)

        clip = load_window_rgb(arrays, case_to_vidx, case_id, start_sec, L_win)

        png_path = cluster_dir / f"rank{rank:02d}_case{case_id}_w{widx:02d}_s{start_sec:04d}_d{dist:.4f}.png"
        caption = f"cluster={cluster_k} rank={rank} case={case_id} win={widx} start={start_sec}s dist={dist:.4f}"
        save_window_montage(clip, str(png_path.resolve()), title=caption, n_frames=n_frames_per_montage)

    print(f"Saved cluster example montages to: {out_dir}")

def extract_and_cluster_windows(
    test_loader,
    encoder,
    device,
    out_dir="./cluster_out",
    n_clusters=4,
    pca_dim=128,
    pca_vis_dim=2,
    random_state=505,
    # fallback params if win_start not in info
    T_total=1200,    # 20 min @ 1 fps
    L_win=None,      # will infer from data if None
    vid_array = None,
    dict_vid_array = None,
    num_samples:int = 1,
    Hybrid_TS=None    
):
    """Embed every video window, cluster the embeddings, and save diagnostics.

    Pipeline:
      1. For each batch from `test_loader` (windows shaped [B, N, L, C, H, W]),
         embed `num_samples` frames per window with `encoder`, or — when
         `Hybrid_TS` is given — with pooled activations cached on that model.
      2. L2-normalize features; if num_samples > 1, average them per window.
      3. PCA(pca_dim) -> KMeans(n_clusters); print the silhouette score.
      4. Write per-window cluster CSVs, centroid-nearest montages, a 2-D PCA
         scatter, and per-case cluster-vs-time plots under out_dir + "K<k>".

    vid_array / dict_vid_array: raw video arrays and case_id -> index map,
    forwarded to export_cluster_examples for montage rendering.

    Returns (df, silhouette): df has one row per window with columns
    case_id, win_idx, win_start_sec, label_nvb, win_start_min, cluster.
    """
    out_dir = Path(out_dir+f"K{n_clusters}")
    out_dir.mkdir(parents=True, exist_ok=True)
    
    features = []
    case_ids = []
    win_idx = []
    win_start = []
    y_video = []

    with torch.no_grad():
        for winds, label, _, info in tqdm(test_loader, desc="Windows Analysis"):
            # winds: [B, N, L, C, H, W]
            B, N, L, C, H, W = winds.shape
            if L_win is None:
                L_win = L

            # middle frame per window
            if num_samples == 1:
                mid = winds[:, :, L // 2]                 # [B, N, C, H, W]
                mid = mid.reshape(B*N, C, H, W)         # [B*N, C, H, W]
            else:
                # evenly spaced frame indices across each window
                idx = np.linspace(0, L - 1, num_samples).round().astype(int)
                mid = winds[:, :, idx]
                K = mid.shape[2]
                mid = mid.reshape(B*N*K, C, H, W)         # [B*N*K, C, H, W]
                
            if Hybrid_TS is None:
                f = encoder(mid.to(device, non_blocking=True))  # [B*N, D] or [B*N*K, D]
            else:
                # NOTE(review): presumably last_conv_output_S/T are activations
                # cached on Hybrid_TS during the forward pass — confirm in the
                # model class.
                _ = Hybrid_TS(mid.to(device, non_blocking=True))
                _fs = torch.cat((Hybrid_TS.last_conv_output_S, Hybrid_TS.last_conv_output_T), dim=1)
                f = torch.nn.functional.adaptive_avg_pool2d(_fs, (1,1)).flatten(1) 
                
            f = torch.nn.functional.normalize(f, dim=1)
            
            if num_samples > 1:
                f = f.view(B, N, K, -1).mean(dim=2) #mean the K dim
                f = f.reshape(B*N, -1)
            
            features.append(f.cpu())

            # labels: video-level (B,) or (B,1). Repeat per window
            if torch.is_tensor(label):
                label_b = label.detach().cpu().view(B).tolist()
            else:
                label_b = list(label)

            # case_ids: list length B
            cids = info.get("case_id", info.get("case", None))
            cids = ensure_list(cids)

            # window indices
            # Preferred: info["win_idx"] is [B, N]
            if "win_idx" in info:
                widx = info["win_idx"]
                if torch.is_tensor(widx):
                    widx = widx.detach().cpu().view(B, N).tolist()
                # flatten by batch item
                for b in range(B):
                    case_ids.extend([cids[b]] * N)
                    win_idx.extend(widx[b])
                    y_video.extend([label_b[b]] * N)
            else:
                # fallback: use 0..N-1
                for b in range(B):
                    case_ids.extend([cids[b]] * N)
                    win_idx.extend(list(range(N)))
                    y_video.extend([label_b[b]] * N)

            # window start seconds
            # Preferred: info["win_start"] is [B, N]
            if "win_start" in info:
                ws = info["win_start"]
                if torch.is_tensor(ws):
                    ws = ws.detach().cpu().view(B, N).tolist()
                for b in range(B):
                    win_start.extend(ws[b])
            else:
                # fallback: compute from uniform starts, same for every case
                starts = compute_uniform_starts(T=T_total, L=L_win, W=N)
                # repeat for B cases
                win_start.extend(starts * B)

    X = torch.cat(features, dim=0).numpy()  # [TotalWindows, D]

    df = pd.DataFrame({
        "case_id": case_ids,
        "win_idx": win_idx,
        "win_start_sec": win_start,
        "label_nvb": y_video
    })
    df["win_start_min"] = round (df["win_start_sec"] / 60.0, 4)

    # PCA for clustering
    X_pca = PCA(n_components=pca_dim, random_state=random_state).fit_transform(X)
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto")
    df["cluster"] = kmeans.fit_predict(X_pca)

    sil = silhouette_score(X_pca, df["cluster"].values)
    print(f"Silhouette (PCA-{pca_dim}, K={n_clusters}): {sil:.6f}")
    
    # Top-5 windows closest to each centroid, for qualitative inspection
    df_top = get_topk_nearest_to_centroids(df, X_pca, kmeans, topk=5)

    # Save table
    csv_path = out_dir / "windows_clusters.csv"
    df.to_csv(csv_path, index=False)
    print(f"Saved: {str(csv_path)}")
    csv_path = out_dir / "cluster_topk_centroid_nearest.csv"
    df_top.to_csv(csv_path, index=False)
    print(f"Saved: {str(csv_path)}")
    
    #Montage of centroids
    export_cluster_examples(df_top, arrays=vid_array, case_to_vidx=dict_vid_array, L_win=L_win, out_dir=(out_dir/"centroids"), n_frames_per_montage=6)

    # Plot 1: PCA-2D scatter colored by cluster
    # (a separate PCA fit on the raw features, independent of the clustering PCA)
    X_vis = PCA(n_components=pca_vis_dim, random_state=random_state).fit_transform(X)
    plt.figure(figsize=(7, 6))
    plt.scatter(X_vis[:, 0], X_vis[:, 1], c=df["cluster"].values, s=6, alpha=0.5)
    plt.title(f"Window Embeddings (PCA-{pca_vis_dim}) — KMeans K={n_clusters}")
    plt.xlabel("PC1"); plt.ylabel("PC2")
    plt.grid(True, alpha=0.2)
    pca_path = out_dir / "pca2d_clusters.png"
    plt.savefig(str(pca_path.resolve()), dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved: {str(pca_path)}")

    # Plot 2: cluster timeline for a few cases
    # (Pick some cases; you can also pass a list explicitly)
    unique_cases = df["case_id"].unique().tolist()
    show_cases = unique_cases[:min(6, len(unique_cases))]

    _, axes = plt.subplots(len(show_cases), 1, figsize=(10, 2.0 * len(show_cases)), sharex=True)
    if len(show_cases) == 1:
        # subplots() returns a bare Axes for a single row; normalize to list
        axes = [axes]

    for ax, cid in zip(axes, show_cases):
        sub = df[df["case_id"] == cid].sort_values("win_start_sec")
        ax.scatter(sub["win_start_min"], sub["cluster"], s=20)
        ax.set_ylabel("Cluster")
        ax.set_title(f"Case {cid} — cluster vs time")
        ax.grid(True)

    axes[-1].set_xlabel("Time (minutes)")
    timeline_path = out_dir / "cluster_timelines.png"
    plt.tight_layout()
    plt.savefig(str(timeline_path.resolve()), dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved: {str(timeline_path)}")

    return df, sil

if __name__ == "__main__":
    # ---------------- CLI ----------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'", required=True)
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("-lv","--Log_version", type=int, default=None)
    parser.add_argument("--Workers", type=int, default=0)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("--CNN_name", type=str, default=None, )
    parser.add_argument("--Temp_Head", type=str, default=None, )
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-b", "--Batch_size", type=int, default=8)
    parser.add_argument("--GPU", type=int, default=0)
    parser.add_argument("--pre_train", type=int, default=0)
    parser.add_argument("-k", "--k_windows", type=int, default=1)
    parser.add_argument("--Window_Size", type=int, default=64)
    parser.add_argument("--Num_Window", type=int, default=8)
    # NOTE(review): argparse `type=bool` is a known pitfall — any non-empty
    # string (even "False") parses as True; only the default is reliable.
    # Consider action="store_true" instead.
    parser.add_argument("--cached_features", type=bool, default=False)
    parser.add_argument("--seed", type=int, default=2023)
    
    args = parser.parse_args()
    
    setup_seed(args.seed)
    device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")

    # Fold assignment table: one column per fold with 'train'/'val'/'test' per case.
    df = pd.read_csv("./Dataset_RARP_video/dataset_videos_folds.csv")

    # ---------------- run configuration ----------------
    FOLD = args.Fold
    WORKERS = args.Workers
    BATCH_SIZE = args.Batch_size
    MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
    PRE_TRAIN = args.pre_train != 0 
    K_WIN = args.k_windows
    # Key frames are skipped in the clustering phase (not needed there).
    KEY_FRAME = True if args.Phase != "cluster" else False
    WIN_LENGTH = args.Window_Size
    NUM_WIN = args.Num_Window
    CACHED_FEATURES  = args.cached_features
    # hybrid_t-s expects un-normalized video (its own transform normalizes).
    NO_NORM_VIDEO = False if args.CNN_name != "hybrid_t-s" else True
    
    # NOTE(review): Mean/Std appear unused in this file — the transforms below
    # inline the same ImageNet statistics.
    Mean = [0.485, 0.456, 0.406]
    Std = [0.229, 0.224, 0.225]
        
    print(f"Fold_{FOLD}")
    
    # Per-fold expert checkpoints, used in 'cache_key_frame' and as
    # Hybrid_TS_weights in 'train'.
    ckpt_paths = [
        Path("./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt"),
    ]

    # Per-split case records for the selected fold (sorted for determinism).
    train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
    val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
    test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
    
    # ---------------- transforms (run on GPU as nn.Sequential) ----------------
    # Train-time video augmentation + ImageNet normalization.
    traintransformT2 = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=0),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.GaussianBlur(kernel_size=3),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(device)

    # Light key-frame augmentation: additive noise + random erasing, applied 30% of the time.
    traintransform_frame = torch.nn.Sequential(
        transforms.RandomApply([
            transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
            transforms.RandomErasing(1.0, value="random")
            ], 0.3) #small noise
    ).to(device)
    
    # Deterministic eval transform (ImageNet normalization).
    testVal_transform = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(device)
    
    # NOTE(review): these normalization stats are on a 0-255 scale — presumably
    # the GSViT pretraining statistics; confirm against that model's recipe.
    testVal_transform_GSViT = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    ).to(device)
    
    # Key-frame transform uses the same 0-255-scale stats as the GSViT branch.
    key_frame_transform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    ).to(device)

    # ---------------- datasets & loaders ----------------
    train_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
        train_set, 
        train_mode=True,
        num_windows=NUM_WIN,
        window_length=WIN_LENGTH, 
        transform=traintransformT2, 
        transform_frame=traintransform_frame, 
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform,
        load_key_frame_cache=CACHED_FEATURES,
        Fold_index=FOLD
    )
    val_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
        val_set, 
        train_mode=False,
        num_windows=NUM_WIN,
        window_length=WIN_LENGTH, 
        transform=testVal_transform,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform,
        load_key_frame_cache=CACHED_FEATURES,
        Fold_index=FOLD
    )
    test_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
        test_set, 
        train_mode=False,
        num_windows=NUM_WIN,
        window_length=WIN_LENGTH, 
        transform=testVal_transform if args.CNN_name != "hybrid_t-s" else testVal_transform_GSViT,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform,
        load_key_frame_cache=CACHED_FEATURES,
        Fold_index=FOLD,
        no_norm_video=NO_NORM_VIDEO
    )    
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    # ---------------- trainer ----------------
    LogFileName = f"{args.Log_Name}" 
    
    checkPtCallback = [
        callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
        #callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ]
    
    trainer = L.Trainer(
        # gsvit needs full fp32; everything else trains in mixed precision.
        precision="32-true" if args.CNN_name == "gsvit" else "16-mixed",
        deterministic=True,
        accelerator="gpu",
        devices=[args.GPU],
        #devices=[0, 1], strategy="ddp",
        # TensorBoard while training; CSV logs for test/eval phases.
        logger=TensorBoardLogger(save_dir=f"./{LogFileName}") if args.Phase == "train" else CSVLogger(save_dir=f"./{LogFileName}/Test", version=args.Log_version),
        log_every_n_steps=5,  
        callbacks=checkPtCallback,
        max_epochs=MAX_EPOCHS
    )
    
    # ---------------- phase dispatch ----------------
    match(args.Phase):
        case "cache_key_frame":
            # Precompute per-case key-frame features and soft labels with the
            # fold's expert model and cache them next to the videos.
            from Models import RARP_NVB_DINO_MultiTask
            
            print (f"Load Export model for the FOLD #{FOLD}")
            Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(ckpt_paths[FOLD], map_location=device)
            Hybrid_TS.eval()
            
            namelist = ["TRAIN", "VAL", "TEST"]
            
            for _i, _s in enumerate([train_set, val_set, test_set]):
                print (f"[{namelist[_i]} Set] of FOLD # {FOLD}")
                key_frame_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
                    _s, 
                    key_frames=True,
                    key_frame_transform=key_frame_transform,
                    key_frame_only=True,
                )
                key_frameloader = DataLoader(
                    key_frame_dataset,
                    batch_size=BATCH_SIZE,
                    shuffle=False,
                    pin_memory=True,
                    num_workers=WORKERS,
                    persistent_workers=WORKERS>0
                )
                
                print (f"[SAVE] caching Image features and Soft lables from Expert Model in FOLD #{FOLD}")
                with torch.no_grad():
                    for img, case_id in tqdm(iter(key_frameloader)):
                        B = img.shape[0]
                        img = img.to(device, dtype=torch.float)
                        
                        # NOTE(review): presumably last_conv_output_S/T are
                        # activations cached on the model during forward —
                        # confirm in the model class.
                        Soft_label, _, _ = Hybrid_TS(img)
                        Img_features = torch.cat((Hybrid_TS.last_conv_output_S, Hybrid_TS.last_conv_output_T), dim=1)
                        Img_features = torch.nn.functional.adaptive_avg_pool2d(Img_features, (1,1)).flatten(1) 
                                            
                        for i in range(B):
                            # Cache file goes into a "chache" dir beside the case's video.
                            parent_path = next((r for r in _s if r.get("case") == case_id[i]), None)
                            parent_path = Path(parent_path["path"]).resolve().parent
                            parent_path = parent_path / "chache"
                            parent_path.mkdir(exist_ok=True)
                            np.savez((parent_path / f"F{FOLD}_{case_id[i]}.npz"), soft_label=Soft_label[i].cpu().numpy(), img_features=Img_features[i].cpu().numpy())
                
            print (f"[DONE] FOLD #{FOLD}")            
                
        case "train":
            # Fit the MIL model on this fold, then test the best checkpoint.
            Model = M.RARP_NVB_Multi_MOD_MIL(
                num_classes=1,
                temporal=args.Temp_Head,
                cnn_name=args.CNN_name,
                dropout=0.3,
                lr=1e-4, #3e-4,
                weight_decay=0.1, #0.05
                epochs=MAX_EPOCHS,
                pre_train=PRE_TRAIN,
                Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve()) if not CACHED_FEATURES else None,
                FOLD=FOLD,
                attn_entropy_target=0.4,
                attn_reg_warmup_epochs=5,
                attn_reg_weight=0.02
            )
            
            print(f"Model Used: {type(Model).__name__}")            
            print("Train Phase")
            trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
            trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
        case "eval_all":
            # Test every checkpoint of one logged version and collect its
            # metric rows into a per-fold Excel sheet.
            print("Evaluation Phase")
            rows = []
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in sorted(pathCkptFile.glob("*.ckpt")):
                print(ckpFile.name)
                #trainer.test(Model, dataloaders=test_loader, ckpt_path=ckpFile)
                #Model = M.RARP_NVB_DINO_MultiTask_A5_Video.load_from_checkpoint(ckpFile)
                
                hp_fiel = pathCkptFile.parent / "hparams.yaml"
                Model = M.RARP_NVB_Multi_MOD_MIL_TESTMode.load_from_checkpoint(ckpFile, map_location=device, hparams_file=hp_fiel)
                trainer.test(Model, dataloaders=test_loader)
                
                #temp = Calc_Eval_table(Model, test_loader, True, ckpFile.name)
                # The test-mode model accumulates its own table rows during test().
                temp = Model._test_results
                rows += temp
            
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])        
            #df.style.highlight_max(color="red", axis=0)
            output_file = Path(f"./{LogFileName}/output.xlsx")             
            if not output_file.exists():
                df.to_excel(output_file, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            else:
                # Append (or replace) this fold's sheet in the existing workbook.
                with pd.ExcelWriter(output_file, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
                    df.to_excel(writer, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            print("[END] File saved ... ")
        case "cluster":
            # Cluster test-set window embeddings using the trained MIL model's
            # internal hybrid backbone as the feature extractor.
            ckpt_paths_MIL = [
                Path("./log_XT6/lightning_logs/version_0/checkpoints/RARP-epoch=23.ckpt"),
                Path("./log_XT6/lightning_logs/version_1/checkpoints/RARP-epoch=27.ckpt"),
                Path("./log_XT6/lightning_logs/version_2/checkpoints/RARP-epoch=20.ckpt"),
                Path("./log_XT6/lightning_logs/version_3/checkpoints/RARP-epoch=30.ckpt"),
                Path("./log_XT6/lightning_logs/version_4/checkpoints/RARP-epoch=29.ckpt"),
            ]
            
            hp_file = ckpt_paths_MIL[FOLD].parent.parent / "hparams.yaml"
            
            Model = M.RARP_NVB_Multi_MOD_MIL.load_from_checkpoint(ckpt_paths_MIL[FOLD], map_location=device, hparams_file=hp_file)
            Model = Model.to(device)
            Model.eval()
            # encoder is unused when Hybrid_TS is supplied below.
            encoder = None
            
            df, sil = extract_and_cluster_windows(
                test_loader, 
                encoder, 
                device, 
                random_state=0, 
                out_dir=f"./{args.CNN_name}_cluster_out_F{FOLD}", 
                vid_array=test_dataset.arrays, 
                dict_vid_array=test_dataset.case_index,
                num_samples=1,
                n_clusters=3,
                Hybrid_TS=Model.Hybrid_TS
            )