# RARP / MIL_video_rarp.py
# Author: delAguila — "Final Commit".
import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as PL
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models_video as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse


# Global PyTorch numerics/determinism configuration (applies process-wide).
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on supported GPUs
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.deterministic = True  # reproducible cuDNN kernels (may cost speed)



def setup_seed(seed):
    """Seed all RNG sources (numpy, torch CPU/CUDA, Lightning) for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Lightning's helper also seeds dataloader worker processes.
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
    
def rolling_mean_std(a, w):
    """Rolling mean and std along axis 0 of a [T, F] array with window size w.

    Uses prefix sums so the cost is O(T) regardless of the window size.

    Returns:
        tuple: (mean, std), each of shape [T - w + 1, F].
    """
    prefix = np.pad(np.cumsum(a, axis=0), ((1, 0), (0, 0)), mode="constant")
    mean = (prefix[w:] - prefix[:-w]) / float(w)

    prefix_sq = np.pad(np.cumsum(a**2, axis=0), ((1, 0), (0, 0)), mode="constant")
    var = (prefix_sq[w:] - prefix_sq[:-w]) / float(w) - mean**2

    # Clamp tiny negative variances caused by floating-point cancellation.
    return mean, np.sqrt(np.maximum(var, 1e-12))

def plot_tensor_analysis(x, fps=30, win=None, out_prefix="tensor_analysis"):
    """
    Visualize a tensor of shape [T, F] with:
      1) Time series per feature (raw + rolling mean ± std)
      2) Heatmap overview (per-feature normalized to [0,1])
      3) Distribution boxplots per feature

    Figures are saved under ./output/ (created if missing) and closed after
    saving so repeated calls do not accumulate open matplotlib figures.

    Args:
        x (torch.Tensor): Input tensor of shape [T, F].
        fps (int): Frames per second (for x-axis in seconds).
        win (int or None): Rolling window size in frames. Default = fps.
        out_prefix (str): Prefix for saved file names.

    Raises:
        ValueError: If x is not a torch.Tensor or is not 2-D.
    """
    # --- check input ---
    if not torch.is_tensor(x):
        raise ValueError("x must be a torch.Tensor")
    if x.ndim != 2:
        raise ValueError("x must have shape [T, F]")

    # Fix: the savefig calls below write into "output/"; create it up front so
    # a fresh checkout does not crash with FileNotFoundError.
    os.makedirs("output", exist_ok=True)

    T, F = x.shape
    time_idx = np.arange(T)
    time_sec = time_idx / float(fps)

    arr = x.detach().cpu().numpy()

    # --- rolling mean/std ---
    if win is None:
        win = max(3, fps)  # default = ~1 second
    half = win // 2

    roll_mean, roll_std = rolling_mean_std(arr, win)
    # Plot each window statistic at its center frame (offset by win // 2).
    roll_t = time_sec[half:half+len(roll_mean)]

    # ---------- 1) Time series ----------
    fig_ts, axes = plt.subplots(F, 1, figsize=(10, 2.5*F), sharex=True)
    if F == 1:
        axes = [axes]  # plt.subplots returns a bare Axes when F == 1

    for f in range(F):
        ax = axes[f]
        ax.plot(time_sec, arr[:, f], alpha=0.35, linewidth=1.0, label=f'Feature {f}')
        ax.plot(roll_t, roll_mean[:, f], linewidth=2.0, label=f'Rolling mean (w={win})')
        ax.fill_between(roll_t,
                        roll_mean[:, f] - roll_std[:, f],
                        roll_mean[:, f] + roll_std[:, f],
                        alpha=0.2, label='±1 std (rolling)')
        ax.set_ylabel(f'Feature {f}')
        ax.grid(True, linestyle='--', alpha=0.3)
    axes[-1].set_xlabel('Time (s)')
    axes[0].legend(loc='upper right')
    fig_ts.suptitle('Per-feature time series with rolling mean ± std', y=1.02)
    fig_ts.tight_layout()
    fig_ts.savefig(f"output/{out_prefix}_time_series.png", dpi=200)
    plt.close(fig_ts)  # fix: release the figure instead of leaking it

    # ---------- 2) Heatmap ----------
    fig_hm, ax = plt.subplots(figsize=(10, 2.8))
    # Min-max normalize each feature independently; epsilon guards constant features.
    arr_min = arr.min(axis=0, keepdims=True)
    arr_max = arr.max(axis=0, keepdims=True)
    arr_norm = (arr - arr_min) / (arr_max - arr_min + 1e-12)

    im = ax.imshow(arr_norm.T, aspect='auto', interpolation='nearest',
                   extent=[time_sec[0], time_sec[-1], F-0.5, -0.5])
    ax.set_yticks(np.arange(F))
    ax.set_yticklabels([f'Feat {f}' for f in range(F)])
    ax.set_xlabel('Time (s)')
    ax.set_title('Heatmap (per-feature normalized)')
    fig_hm.colorbar(im, ax=ax, fraction=0.025, pad=0.02)
    fig_hm.tight_layout()
    fig_hm.savefig(f"output/{out_prefix}_heatmap.png", dpi=200)
    plt.close(fig_hm)

    # ---------- 3) Boxplots ----------
    fig_box, ax = plt.subplots(figsize=(7, 3.5))
    ax.boxplot([arr[:, f] for f in range(F)], showmeans=True)
    ax.set_xticklabels([f'Feat {f}' for f in range(F)])
    ax.set_ylabel('Value')
    ax.set_title('Distribution across time (boxplot per feature)')
    ax.grid(True, axis='y', linestyle='--', alpha=0.3)
    fig_box.tight_layout()
    fig_box.savefig(f"output/{out_prefix}_boxplots.png", dpi=200)
    plt.close(fig_box)

    # Fix: report the actual file paths (files are written under output/).
    print(f"Saved: output/{out_prefix}_time_series.png, output/{out_prefix}_heatmap.png, output/{out_prefix}_boxplots.png")

def Calc_Eval_table(
    TrainModel,
    TestDataLoadre:DataLoader, 
    Youden=False, 
    modelName="", 
):
    """Run inference over a test loader and build a binary-metric table.

    Expects batches of (data, label, mask, key_frame) and a model whose
    forward(data, key_frame, mask) returns (logits, ...). Sigmoid
    probabilities are accumulated over the whole loader, then binary
    Accuracy/Precision/Recall/F1/AUROC/Specificity are computed at the
    default 0.5 threshold. With Youden=True, two extra rows are appended:
    one at the ROC Youden-index threshold (argmax of TPR - FPR) and one
    using the Youden J value itself (recall + specificity - 1, replaced by
    0.5 when non-positive) as the threshold.

    NOTE(review): relies on the module-level `device` created in __main__.

    Returns:
        list[list[str]]: rows of [threshold, acc, precision, recall, f1,
        auroc, specificity, model/checkpoint name].
    """
    TrainModel.to(device)
    TrainModel.eval()

    Predictions = []
    Labels = []
     
    with torch.no_grad():
        for data, label, mask, key_frame in tqdm(iter(TestDataLoadre)):
                       
            data = data.to(device, dtype=torch.float32)
            key_frame = key_frame.to(device, dtype=torch.float32)
            mask = mask.to(device)
            label = label.to(device)
            
            #pred, *_ = TrainModel(data)
            pred, _ = TrainModel(data, key_frame, mask)
            pred = pred.flatten()
                
            Predictions.append(torch.sigmoid(pred))
            Labels.append(label)
      
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels).int()

    #print(Predictions, Labels)

    # Metrics at the default 0.5 decision threshold, over the full test set.
    acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
        
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
    ]

    if Youden:
        # i == 0: threshold taken from the ROC curve's Youden index (argmax TPR - FPR);
        # i == 1: threshold = recall + specificity - 1 at 0.5, floored to 0.5.
        for i in range(2):
            aucCurve = torchmetrics.ROC("binary").to(device)
            fpr, tpr, thhols = aucCurve(Predictions, Labels)
            index = torch.argmax(tpr - fpr)
            th2 = (recall + specificty - 1).item()
            th2 = 0.5 if th2 <= 0 else th2
            th1 = thhols[index].item() if i == 0 else th2
            accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
            precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
            recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
            specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
            f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
            #cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
            #cm2.update(Predictions, Labels)
            #_, ax = cm2.plot()
            #ax.set_title(f"NVB Classifier (th={th1:.4f})")
            # AUROC is threshold-free, so the 0.5-row value is reused here.
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
        

    return table

def ensure_list(x):
    # Converts tensor/list/tuple to python list
    if torch.is_tensor(x):
        return x.detach().cpu().tolist()
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x]

def compute_uniform_starts(T: int, L: int, W: int):
    """Evenly spaced start indices for W windows of length L over T frames.

    The first window starts at 0 and the last ends at T; intermediate starts
    are rounded to the nearest integer and clamped into [0, T - L].

    Args:
        T: Total sequence length in frames.
        L: Window length in frames.
        W: Number of windows to place.

    Returns:
        list[int]: W non-decreasing start indices.
    """
    stride = (T - L) / max(W - 1, 1)
    starts = []
    for i in range(W):
        s = int(round(i * stride))
        # Fix: clamp to >= 0 as well — when T < L the original code produced
        # negative start indices (min(s, T - L) with negative T - L).
        s = max(0, min(s, T - L))
        starts.append(s)
    return starts

def extract_and_cluster_windows(
    test_loader,
    encoder,
    device,
    out_dir="./cluster_out",
    n_clusters=4,
    pca_dim=128,
    pca_vis_dim=2,
    random_state=505,
    # fallback params if win_start not in info
    T_total=1200,    # 20 min @ 1 fps
    L_win=None,      # will infer from data if None
):
    """Embed each video window, cluster the embeddings, and plot the results.

    For every batch [B, N, L, C, H, W] the middle frame of each of the N
    windows is encoded with `encoder`, L2-normalized, then all window
    features are PCA-reduced and clustered with KMeans. Written to out_dir:
    a CSV of per-window cluster assignments, a 2-D PCA scatter, and
    cluster-vs-time timelines for up to 6 cases.

    NOTE(review): PCA, KMeans and silhouette_score are resolved as module
    globals — they are only imported in the "cluster" phase of __main__, so
    this function must be called from that phase.

    Args:
        test_loader: yields (windows, label, _, info); info is a dict that
            may carry "case_id"/"case", "win_idx" [B, N] and "win_start" [B, N].
        encoder: callable mapping [B*N, C, H, W] frames to [B*N, D] features.
        device: torch device for the encoder forward pass.
        out_dir (str): output directory (created if missing).
        n_clusters (int): KMeans cluster count.
        pca_dim (int): PCA dimensionality used for clustering.
        pca_vis_dim (int): PCA dimensionality for the scatter plot.
        random_state (int): seed for PCA/KMeans.
        T_total (int): fallback video length (frames) when "win_start" is absent.
        L_win (int or None): fallback window length; inferred from data if None.

    Returns:
        tuple: (per-window pandas.DataFrame, silhouette score float).
    """
    os.makedirs(out_dir, exist_ok=True)

    encoder = encoder.to(device).eval()

    features = []
    case_ids = []
    win_idx = []
    win_start = []
    y_video = []

    with torch.no_grad():
        for winds, label, _, info in test_loader:
            # winds: [B, N, L, C, H, W]
            B, N, L, C, H, W = winds.shape
            if L_win is None:
                L_win = L

            # middle frame per window
            mid = winds[:, :, L // 2]                 # [B, N, C, H, W]
            mid = mid.reshape(B * N, C, H, W)         # [B*N, C, H, W]

            # L2-normalize embeddings so clustering distances are cosine-like.
            f = encoder(mid.to(device, non_blocking=True))  # [B*N, D]
            f = torch.nn.functional.normalize(f, dim=1).cpu()
            features.append(f)

            # labels: video-level (B,) or (B,1). Repeat per window
            if torch.is_tensor(label):
                label_b = label.detach().cpu().view(B).tolist()
            else:
                label_b = list(label)

            # case_ids: list length B
            cids = info.get("case_id", info.get("case", None))
            cids = ensure_list(cids)

            # window indices
            # Preferred: info["win_idx"] is [B, N]
            if "win_idx" in info:
                widx = info["win_idx"]
                if torch.is_tensor(widx):
                    widx = widx.detach().cpu().view(B, N).tolist()
                # flatten by batch item
                for b in range(B):
                    case_ids.extend([cids[b]] * N)
                    win_idx.extend(widx[b])
                    y_video.extend([label_b[b]] * N)
            else:
                # fallback: use 0..N-1
                for b in range(B):
                    case_ids.extend([cids[b]] * N)
                    win_idx.extend(list(range(N)))
                    y_video.extend([label_b[b]] * N)

            # window start seconds
            # Preferred: info["win_start"] is [B, N]
            if "win_start" in info:
                ws = info["win_start"]
                if torch.is_tensor(ws):
                    ws = ws.detach().cpu().view(B, N).tolist()
                for b in range(B):
                    win_start.extend(ws[b])
            else:
                # fallback: compute from uniform starts, same for every case
                starts = compute_uniform_starts(T=T_total, L=L_win, W=N)
                # repeat for B cases
                win_start.extend(starts * B)

    X = torch.cat(features, dim=0).numpy()  # [TotalWindows, D]

    # One row per window; label_nvb is the video-level label repeated per window.
    df = pd.DataFrame({
        "case_id": case_ids,
        "win_idx": win_idx,
        "win_start_sec": win_start,
        "label_nvb": y_video
    })
    # round() on a Series dispatches to Series.__round__ (element-wise).
    df["win_start_min"] = round (df["win_start_sec"] / 60.0, 4)

    # PCA for clustering
    X_pca = PCA(n_components=min(pca_dim, X.shape[1]), random_state=random_state).fit_transform(X)

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto")
    df["cluster"] = kmeans.fit_predict(X_pca)

    sil = silhouette_score(X_pca, df["cluster"].values)
    print(f"Silhouette (PCA-{min(pca_dim, X.shape[1])}, K={n_clusters}): {sil:.6f}")

    # Save table
    csv_path = os.path.join(out_dir, "windows_clusters.csv")
    df.to_csv(csv_path, index=False)
    print(f"Saved: {csv_path}")

    # Plot 1: PCA-2D scatter colored by cluster
    X_vis = PCA(n_components=pca_vis_dim, random_state=random_state).fit_transform(X)
    plt.figure(figsize=(7, 6))
    plt.scatter(X_vis[:, 0], X_vis[:, 1], c=df["cluster"].values, s=6, alpha=0.5)
    plt.title(f"Window Embeddings (PCA-{pca_vis_dim}) — KMeans K={n_clusters}")
    plt.xlabel("PC1"); plt.ylabel("PC2")
    plt.grid(True, alpha=0.2)
    pca_path = os.path.join(out_dir, "pca2d_clusters.png")
    plt.savefig(pca_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved: {pca_path}")

    # Plot 2: cluster timeline for a few cases
    # (Pick some cases; you can also pass a list explicitly)
    unique_cases = df["case_id"].unique().tolist()
    show_cases = unique_cases[:min(6, len(unique_cases))]

    _, axes = plt.subplots(len(show_cases), 1, figsize=(10, 2.0 * len(show_cases)), sharex=True)
    if len(show_cases) == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single row

    for ax, cid in zip(axes, show_cases):
        sub = df[df["case_id"] == cid].sort_values("win_start_sec")
        ax.scatter(sub["win_start_min"], sub["cluster"], s=20)
        ax.set_ylabel("Cluster")
        ax.set_title(f"Case {cid} — cluster vs time")
        ax.grid(True)

    axes[-1].set_xlabel("Time (minutes)")
    timeline_path = os.path.join(out_dir, "cluster_timelines.png")
    plt.tight_layout()
    plt.savefig(timeline_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved: {timeline_path}")

    return df, sil
    
if __name__ == "__main__":
    # ---- CLI arguments ----
    parser = argparse.ArgumentParser()
    # Phase also accepts "cache_key_frame", "eval_all", "cluster" (see match below);
    # default is moot since required=True.
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'", required=True)
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("-lv","--Log_version", type=int, default=None)
    parser.add_argument("--Workers", type=int, default=0)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("--CNN_name", type=str, default=None, )
    parser.add_argument("--Temp_Head", type=str, default=None, )
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-b", "--Batch_size", type=int, default=8)
    parser.add_argument("--GPU", type=int, default=0)
    parser.add_argument("--pre_train", type=int, default=0)
    parser.add_argument("-k", "--k_windows", type=int, default=1)
    parser.add_argument("--Window_Size", type=int, default=64)
    parser.add_argument("--Num_Window", type=int, default=8)
    # NOTE(review): type=bool parses ANY non-empty string as True, so
    # "--cached_features False" still enables caching; consider action="store_true".
    parser.add_argument("--cached_features", type=bool, default=False)
    
    args = parser.parse_args()
    
    # Fixed seed for reproducibility across runs.
    setup_seed(2023)
    device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")

    # Per-case fold table; the Fold_i columns hold "train"/"val"/"test" assignments.
    df = pd.read_csv("../dataset/Dataset_RARP_video/dataset_videos_folds.csv")

    FOLD = args.Fold
    WORKERS = args.Workers
    BATCH_SIZE = args.Batch_size
    MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
    PRE_TRAIN = args.pre_train != 0 
    K_WIN = args.k_windows
    # Key frames are used in every phase except clustering.
    KEY_FRAME = True if args.Phase != "cluster" else False
    WIN_LENGTH = args.Window_Size
    NUM_WIN = args.Num_Window
    CACHED_FEATURES  = args.cached_features
    
    # ImageNet normalization stats (the transforms below inline the same values).
    Mean = [0.485, 0.456, 0.406]
    Std = [0.229, 0.224, 0.225]
        
    print(f"Fold_{FOLD}")
    
    # Pre-trained expert checkpoints (DINO multi-task model), one per fold;
    # loaded in "cache_key_frame" and passed as Hybrid_TS_weights elsewhere.
    ckpt_paths = [
        Path("./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt"),
    ]

    # Split the case table by the fold column; sort for a deterministic order.
    train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
    val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
    test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
    
    # Training-time spatial/photometric augmentation for window frames,
    # applied on-device as an nn.Sequential of tensor transforms.
    traintransformT2 = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=0),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.GaussianBlur(kernel_size=3),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet stats
    ).to(device)

    # Extra per-frame corruption (additive noise + random erasing), applied with p=0.3.
    traintransform_frame = torch.nn.Sequential(
        transforms.RandomApply([
            transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
            transforms.RandomErasing(1.0, value="random")
            ], 0.3) #small noise
    ).to(device)
    
    # Deterministic preprocessing for val/test (no augmentation).
    testVal_transform = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(device)
    
    # Key-frame preprocessing for the expert model.
    # NOTE(review): these mean/std values are on a 0-255 scale, unlike the 0-1
    # ImageNet stats above — presumably key frames arrive unscaled; confirm.
    key_frame_transform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    ).to(device)

    # Windowed MIL datasets: each case yields NUM_WIN windows of WIN_LENGTH frames.
    # Only the train split uses augmentation + per-frame corruption.
    train_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
        train_set, 
        train_mode=True,
        num_windows=NUM_WIN,
        window_length=WIN_LENGTH, 
        transform=traintransformT2, 
        transform_frame=traintransform_frame, 
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform,
        load_key_frame_cache=CACHED_FEATURES,
        Fold_index=FOLD
    )
    val_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
        val_set, 
        train_mode=False,
        num_windows=NUM_WIN,
        window_length=WIN_LENGTH, 
        transform=testVal_transform,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform,
        load_key_frame_cache=CACHED_FEATURES,
        Fold_index=FOLD
    )
    test_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
        test_set, 
        train_mode=False,
        num_windows=NUM_WIN,
        window_length=WIN_LENGTH, 
        transform=testVal_transform,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform,
        load_key_frame_cache=CACHED_FEATURES,
        Fold_index=FOLD
    )    
    
    # Only the train loader shuffles and drops the last partial batch; workers
    # persist between epochs whenever num_workers > 0.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    LogFileName = f"{args.Log_Name}" 
    
    # Keep the 10 best checkpoints ranked by validation accuracy.
    checkPtCallback = [
        callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
        #callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ]
    
    trainer = PL.Trainer(
        # Full fp32 for the "gsvit" backbone; mixed precision otherwise.
        precision="32-true" if args.CNN_name == "gsvit" else "16-mixed",
        deterministic=True,
        accelerator="gpu",
        devices=[args.GPU],
        #devices=[0, 1], strategy="ddp",
        # TensorBoard logs while training; CSV logs for test/eval phases.
        logger=TensorBoardLogger(save_dir=f"./{LogFileName}") if args.Phase == "train" else CSVLogger(save_dir=f"./{LogFileName}/Test", version=args.Log_version),
        log_every_n_steps=5,  
        callbacks=checkPtCallback,
        max_epochs=MAX_EPOCHS
    )
    
    match(args.Phase):
        case "cache_key_frame":
            # Pre-compute and cache the expert model's per-key-frame features and
            # soft labels so later training can skip the expert forward pass.
            from Models import RARP_NVB_DINO_MultiTask
            
            print (f"Load Export model for the FOLD #{FOLD}")
            Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(ckpt_paths[FOLD], map_location=device)
            Hybrid_TS.eval()
            
            namelist = ["TRAIN", "VAL", "TEST"]
            
            for _i, _s in enumerate([train_set, val_set, test_set]):
                print (f"[{namelist[_i]} Set] of FOLD # {FOLD}")
                # key_frame_only: the dataset yields (image, case_id) pairs only.
                key_frame_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
                    _s, 
                    key_frames=True,
                    key_frame_transform=key_frame_transform,
                    key_frame_only=True,
                )
                key_frameloader = DataLoader(
                    key_frame_dataset,
                    batch_size=BATCH_SIZE,
                    shuffle=False,
                    pin_memory=True,
                    num_workers=WORKERS,
                    persistent_workers=WORKERS>0
                )
                
                print (f"[SAVE] caching Image features and Soft lables from Expert Model in FOLD #{FOLD}")
                with torch.no_grad():
                    for img, case_id in tqdm(iter(key_frameloader)):
                        B = img.shape[0]
                        img = img.to(device, dtype=torch.float)
                        
                        Soft_label, _, _ = Hybrid_TS(img)
                        # Concatenate the expert's S and T conv feature maps, then
                        # global-average-pool to a flat per-image feature vector.
                        Img_features = torch.cat((Hybrid_TS.last_conv_output_S, Hybrid_TS.last_conv_output_T), dim=1)
                        Img_features = torch.nn.functional.adaptive_avg_pool2d(Img_features, (1,1)).flatten(1) 
                                            
                        for i in range(B):
                            # Find this case's record to derive the cache dir next to its video.
                            parent_path = next((r for r in _s if r.get("case") == case_id[i]), None)
                            parent_path = Path(parent_path["path"]).resolve().parent
                            # NOTE(review): "chache" (sic) is the on-disk cache dir name; the
                            # loader presumably reads the same path — do not fix the spelling here alone.
                            parent_path = parent_path / "chache"
                            parent_path.mkdir(exist_ok=True)
                            np.savez((parent_path / f"F{FOLD}_{case_id[i]}.npz"), soft_label=Soft_label[i].cpu().numpy(), img_features=Img_features[i].cpu().numpy())
                
            print (f"[DONE] FOLD #{FOLD}")            
                
        case "train":
            # Multi-modal MIL classifier; distills from the fold's expert checkpoint
            # unless cached features are used instead (Hybrid_TS_weights=None then).
            Model = M.RARP_NVB_Multi_MOD_MIL(
                num_classes=1,
                temporal=args.Temp_Head,
                cnn_name=args.CNN_name,
                dropout=0.3,
                lr=1e-4, #3e-4,
                weight_decay=0.1, #0.05
                epochs=MAX_EPOCHS,
                pre_train=PRE_TRAIN,
                Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve()) if not CACHED_FEATURES else None,
                FOLD=FOLD,
                attn_entropy_target=0.4,
                attn_reg_warmup_epochs=5,
                attn_reg_weight=0.02
            )
            
            print(f"Model Used: {type(Model).__name__}")            
            print("Train Phase")
            trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
            # Evaluate the best (highest val_acc) checkpoint on the test split.
            trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
        case "eval_all":
            # Test every checkpoint of a given log version and collect the metric
            # rows into one Excel sheet per fold/version.
            print("Evaluation Phase")
            rows = []
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in sorted(pathCkptFile.glob("*.ckpt")):
                print(ckpFile.name)
                #trainer.test(Model, dataloaders=test_loader, ckpt_path=ckpFile)
                #Model = M.RARP_NVB_DINO_MultiTask_A5_Video.load_from_checkpoint(ckpFile)
                
                hp_fiel = pathCkptFile.parent / "hparams.yaml"
                Model = M.RARP_NVB_Multi_MOD_MIL_TESTMode.load_from_checkpoint(ckpFile, map_location=device, hparams_file=hp_fiel)
                trainer.test(Model, dataloaders=test_loader)
                
                #temp = Calc_Eval_table(Model, test_loader, True, ckpFile.name)
                # The test-mode model stores its metric rows on _test_results during trainer.test.
                temp = Model._test_results
                rows += temp
            
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])        
            #df.style.highlight_max(color="red", axis=0)
            output_file = Path(f"./{LogFileName}/output.xlsx")             
            if not output_file.exists():
                df.to_excel(output_file, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            else:
                # Append (or replace) this fold/version's sheet in the existing workbook.
                with pd.ExcelWriter(output_file, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
                    df.to_excel(writer, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            print("[END] File saved ... ")
        case "cluster":
            # Unsupervised analysis: embed test windows with the model's CNN
            # backbone and cluster them over time.
            os.environ["OMP_NUM_THREADS"] = "2"
            
            # Imported here (not at file top) so sklearn is required only for this
            # phase; extract_and_cluster_windows uses these names as module globals.
            from sklearn.decomposition import PCA
            from sklearn.cluster import KMeans
            from sklearn.metrics import silhouette_score
            
            Model = M.RARP_NVB_Multi_MOD_MIL(
                num_classes=1,
                temporal=args.Temp_Head,
                cnn_name=args.CNN_name,
                dropout=0.3,
                lr=1e-4, #3e-4,
                weight_decay=0.1, #0.05
                epochs=MAX_EPOCHS,
                pre_train=PRE_TRAIN,
                Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve()) if not CACHED_FEATURES else None,
                FOLD=FOLD,
                attn_entropy_target=0.4,
                attn_reg_warmup_epochs=5,
                attn_reg_weight=0.02
            )
            Model = Model.to(device)
            Model.eval()
            encoder = Model.cnn  # feature extractor only; the MIL head is unused here
            
            df, sil = extract_and_cluster_windows(test_loader, encoder, device, random_state=505)