# RARP / MIL_video_rarp.py
# Training / evaluation script for windowed video MIL classification on RARP data.
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models_video as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse


# Allow TF32 on Ampere+ GPUs for faster matmul at slightly reduced precision.
torch.backends.cuda.matmul.allow_tf32 = True  
torch.set_float32_matmul_precision('high')
# Force deterministic cuDNN kernels for reproducibility (may be slower).
torch.backends.cudnn.deterministic = True



def setup_seed(seed):
    """Seed every RNG source (lightning, numpy, torch CPU and all GPUs).

    Re-seeding the same generator with the same value is idempotent, so the
    call order here is interchangeable with any other ordering of these calls.
    """
    seed_everything(seed, workers=True)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
def rolling_mean_std(a, w):
    """Rolling mean and std over axis 0 of a [T, F] array, in O(T) per feature.

    Uses the prefix-sum trick: the sum over window i..i+w-1 equals
    P[i+w] - P[i] where P is the zero-padded cumulative sum.

    Args:
        a: array of shape [T, F].
        w: window length in samples.

    Returns:
        (mean, std), each of shape [T - w + 1, F].
    """
    def _window_sums(values):
        # Zero-padded prefix sums -> per-window sums via a single subtraction.
        prefix = np.pad(np.cumsum(values, axis=0), ((1, 0), (0, 0)), mode="constant")
        return prefix[w:] - prefix[:-w]

    mean = _window_sums(a) / float(w)
    # E[x^2] - E[x]^2, floored at 1e-12 to absorb negative round-off.
    var = _window_sums(a ** 2) / float(w) - mean ** 2
    std = np.sqrt(np.maximum(var, 1e-12))
    return mean, std

def plot_tensor_analysis(x, fps=30, win=None, out_prefix="tensor_analysis"):
    """
    Visualize a tensor of shape [T, F] with:
      1) Time series per feature (raw + rolling mean ± std)
      2) Heatmap overview (per-feature normalized to [0,1])
      3) Distribution boxplots per feature

    All figures are saved as PNGs under ./output/ (created if missing) and
    closed after saving so repeated calls do not accumulate open figures.

    Args:
        x (torch.Tensor): Input tensor of shape [T, F].
        fps (int): Frames per second (for x-axis in seconds).
        win (int or None): Rolling window size in frames. Default = max(3, fps).
        out_prefix (str): Prefix for saved file names.

    Raises:
        ValueError: if x is not a 2-D torch.Tensor.
    """
    # --- check input ---
    if not torch.is_tensor(x):
        raise ValueError("x must be a torch.Tensor")
    if x.ndim != 2:
        raise ValueError("x must have shape [T, F]")

    # Fix: savefig raises FileNotFoundError when the target directory is absent.
    Path("output").mkdir(exist_ok=True)

    T, F = x.shape
    time_idx = np.arange(T)
    time_sec = time_idx / float(fps)

    arr = x.detach().cpu().numpy()

    # --- rolling mean/std ---
    if win is None:
        win = max(3, fps)  # default = ~1 second
    half = win // 2

    roll_mean, roll_std = rolling_mean_std(arr, win)
    # Center the (shorter) rolling curves on the original time axis.
    roll_t = time_sec[half:half+len(roll_mean)]

    # ---------- 1) Time series ----------
    fig_ts, axes = plt.subplots(F, 1, figsize=(10, 2.5*F), sharex=True)
    if F == 1:
        axes = [axes]

    for f in range(F):
        ax = axes[f]
        ax.plot(time_sec, arr[:, f], alpha=0.35, linewidth=1.0, label=f'Feature {f}')
        ax.plot(roll_t, roll_mean[:, f], linewidth=2.0, label=f'Rolling mean (w={win})')
        ax.fill_between(roll_t,
                        roll_mean[:, f] - roll_std[:, f],
                        roll_mean[:, f] + roll_std[:, f],
                        alpha=0.2, label='±1 std (rolling)')
        ax.set_ylabel(f'Feature {f}')
        ax.grid(True, linestyle='--', alpha=0.3)
    axes[-1].set_xlabel('Time (s)')
    axes[0].legend(loc='upper right')
    fig_ts.suptitle('Per-feature time series with rolling mean ± std', y=1.02)
    fig_ts.tight_layout()
    fig_ts.savefig(f"output/{out_prefix}_time_series.png", dpi=200)
    plt.close(fig_ts)  # fix: release figure memory after saving

    # ---------- 2) Heatmap ----------
    fig_hm, ax = plt.subplots(figsize=(10, 2.8))
    # Per-feature min-max normalization; epsilon guards constant features.
    arr_min = arr.min(axis=0, keepdims=True)
    arr_max = arr.max(axis=0, keepdims=True)
    arr_norm = (arr - arr_min) / (arr_max - arr_min + 1e-12)

    im = ax.imshow(arr_norm.T, aspect='auto', interpolation='nearest',
                   extent=[time_sec[0], time_sec[-1], F-0.5, -0.5])
    ax.set_yticks(np.arange(F))
    ax.set_yticklabels([f'Feat {f}' for f in range(F)])
    ax.set_xlabel('Time (s)')
    ax.set_title('Heatmap (per-feature normalized)')
    fig_hm.colorbar(im, ax=ax, fraction=0.025, pad=0.02)
    fig_hm.tight_layout()
    fig_hm.savefig(f"output/{out_prefix}_heatmap.png", dpi=200)
    plt.close(fig_hm)  # fix: release figure memory after saving

    # ---------- 3) Boxplots ----------
    fig_box, ax = plt.subplots(figsize=(7, 3.5))
    ax.boxplot([arr[:, f] for f in range(F)], showmeans=True)
    ax.set_xticklabels([f'Feat {f}' for f in range(F)])
    ax.set_ylabel('Value')
    ax.set_title('Distribution across time (boxplot per feature)')
    ax.grid(True, axis='y', linestyle='--', alpha=0.3)
    fig_box.tight_layout()
    fig_box.savefig(f"output/{out_prefix}_boxplots.png", dpi=200)
    plt.close(fig_box)  # fix: release figure memory after saving

    # Fix: report the actual saved paths (files live under output/).
    print(f"Saved: output/{out_prefix}_time_series.png, output/{out_prefix}_heatmap.png, output/{out_prefix}_boxplots.png")

def Calc_Eval_table(
    TrainModel,
    TestDataLoadre: DataLoader,
    Youden=False,
    modelName="",
):
    """Run the model over a test loader and tabulate binary metrics.

    Each row is [threshold, accuracy, precision, recall, F1, AUROC,
    specificity, checkpoint name]. Row 0 always uses the default 0.5
    cut-off; when ``Youden`` is True two extra rows are appended: one at the
    Youden-J optimal ROC threshold and one at a heuristic
    (recall + specificity - 1) threshold derived from the 0.5 evaluation.

    NOTE(review): relies on the module-level ``device`` set in __main__.
    """
    TrainModel.to(device)
    TrainModel.eval()

    probs = []
    targets = []

    with torch.no_grad():
        for data, label, mask, key_frame in tqdm(TestDataLoadre):
            data = data.to(device, dtype=torch.float32)
            key_frame = key_frame.to(device, dtype=torch.float32)
            mask = mask.to(device)
            label = label.to(device)

            logits, _ = TrainModel(data, key_frame, mask)
            probs.append(torch.sigmoid(logits.flatten()))
            targets.append(label)

    probs = torch.cat(probs)
    targets = torch.cat(targets).int()

    # AUROC is threshold-free; compute it once and reuse in every row.
    auc = torchmetrics.AUROC('binary').to(device)(probs, targets)

    def _evaluate(th, tag):
        # Thresholded metrics plus one formatted table row.
        acc = torchmetrics.Accuracy('binary', threshold=th).to(device)(probs, targets)
        prec = torchmetrics.Precision('binary', threshold=th).to(device)(probs, targets)
        rec = torchmetrics.Recall('binary', threshold=th).to(device)(probs, targets)
        f1 = torchmetrics.F1Score('binary', threshold=th).to(device)(probs, targets)
        spec = torchmetrics.Specificity('binary', threshold=th).to(device)(probs, targets)
        row = [f"{th:.4f}", f"{acc.item():.4f}", f"{prec.item():.4f}",
               f"{rec.item():.4f}", f"{f1.item():.4f}", f"{auc.item():.4f}",
               f"{spec.item():.4f}", tag]
        return rec, spec, row

    base_recall, base_spec, base_row = _evaluate(0.5, "")
    table = [base_row]

    if Youden:
        # ROC is deterministic for fixed predictions, so computing it once
        # is equivalent to recomputing it per iteration.
        fpr, tpr, cutoffs = torchmetrics.ROC("binary").to(device)(probs, targets)
        youden_th = cutoffs[torch.argmax(tpr - fpr)].item()
        heuristic_th = (base_recall + base_spec - 1).item()
        if heuristic_th <= 0:
            heuristic_th = 0.5
        for th in (youden_th, heuristic_th):
            table.append(_evaluate(th, modelName)[2])

    return table
    
if __name__ == "__main__":
    # ---- CLI arguments ----
    parser = argparse.ArgumentParser()
    # NOTE(review): required=True makes the default "train" dead code.
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'", required=True)
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("-lv","--Log_version", type=int, default=None)
    parser.add_argument("--Workers", type=int, default=0)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("--CNN_name", type=str, default=None, )
    parser.add_argument("--Temp_Head", type=str, default=None, )
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-b", "--Batch_size", type=int, default=8)
    parser.add_argument("--GPU", type=int, default=0)
    parser.add_argument("--pre_train", type=int, default=0)
    parser.add_argument("-k", "--k_windows", type=int, default=1)
    parser.add_argument("--Window_Size", type=int, default=64)
    parser.add_argument("--Num_Window", type=int, default=8)
    # NOTE(review): argparse type=bool is truthy for any non-empty string,
    # so "--cached_features False" still yields True. Confirm intended usage.
    parser.add_argument("--cached_features", type=bool, default=False)
    
    args = parser.parse_args()
    
    # Fixed seed across torch / numpy / lightning for reproducible runs.
    setup_seed(2023)
    device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")

    # CSV listing video cases with per-fold split columns (Fold_<k> = train/val/test).
    df = pd.read_csv("../dataset/Dataset_RARP_video/dataset_videos_folds.csv")

    FOLD = args.Fold
    WORKERS = args.Workers
    BATCH_SIZE = args.Batch_size
    MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
    PRE_TRAIN = args.pre_train != 0 
    K_WIN = args.k_windows
    KEY_FRAME = True
    WIN_LENGTH = args.Window_Size
    NUM_WIN = args.Num_Window
    CACHED_FEATURES  = args.cached_features
    
    # NOTE(review): Mean/Std appear unused below — the transforms pass the
    # ImageNet literals directly to transforms.Normalize.
    Mean = [0.485, 0.456, 0.406]
    Std = [0.229, 0.224, 0.225]
        
    print(f"Fold_{FOLD}")
    
    # Per-fold expert ("Hybrid_TS") checkpoints, used for key-frame feature
    # caching and as pretrained weights when features are not cached.
    ckpt_paths = [
        Path("./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt"),
    ]

    # Split the CSV into per-fold record lists (sorted for deterministic order).
    train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
    val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
    test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
    
    # Training-time video augmentation (GPU-resident nn.Sequential of transforms).
    traintransformT2 = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=0),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.GaussianBlur(kernel_size=3),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(device)

    # Extra per-frame augmentation applied with probability 0.3.
    traintransform_frame = torch.nn.Sequential(
        transforms.RandomApply([
            transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
            transforms.RandomErasing(1.0, value="random")
            ], 0.3) #small noise
    ).to(device)
    
    # Deterministic preprocessing for validation/test.
    testVal_transform = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(device)
    
    # NOTE(review): key-frame normalization stats look like 0-255 pixel scale —
    # confirm the key-frame loader does not rescale to [0,1] first.
    key_frame_transform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    ).to(device)

    # ---- Datasets: windowed video MIL bags (+ optional key-frame branch) ----
    train_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
        train_set, 
        train_mode=True,
        num_windows=NUM_WIN,
        window_length=WIN_LENGTH, 
        transform=traintransformT2, 
        transform_frame=traintransform_frame, 
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform,
        load_key_frame_cache=CACHED_FEATURES,
        Fold_index=FOLD
    )
    val_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
        val_set, 
        train_mode=False,
        num_windows=NUM_WIN,
        window_length=WIN_LENGTH, 
        transform=testVal_transform,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform,
        load_key_frame_cache=CACHED_FEATURES,
        Fold_index=FOLD
    )
    test_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
        test_set, 
        train_mode=False,
        num_windows=NUM_WIN,
        window_length=WIN_LENGTH, 
        transform=testVal_transform,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform,
        load_key_frame_cache=CACHED_FEATURES,
        Fold_index=FOLD
    )    
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        # prefetch_factor is only valid with worker processes.
        prefetch_factor=1 if WORKERS>0 else None,
        #pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        #pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        #pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    LogFileName = f"{args.Log_Name}" 
    
    # Keep the 10 best checkpoints by validation window accuracy.
    checkPtCallback = [
        callbk.ModelCheckpoint(monitor='val_wind_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
        #callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ]
    
    trainer = L.Trainer(
        # gsvit is trained in full fp32; everything else uses mixed precision.
        precision="32-true" if args.CNN_name == "gsvit" else "16-mixed",
        deterministic=True,
        accelerator="gpu",
        devices=[args.GPU],
        #devices=[0, 1], strategy="ddp",
        # TensorBoard for training; CSV logger (under <log>/Test) for evaluation.
        logger=TensorBoardLogger(save_dir=f"./{LogFileName}") if args.Phase == "train" else CSVLogger(save_dir=f"./{LogFileName}/Test", version=args.Log_version),
        log_every_n_steps=5,  
        callbacks=checkPtCallback,
        max_epochs=MAX_EPOCHS
    )
    
    match(args.Phase):
        case "cache_key_frame":
            # Pre-compute and cache key-frame image features and soft labels
            # from the per-fold expert model, so later training can skip it.
            from Models import RARP_NVB_DINO_MultiTask
            
            print (f"Load Export model for the FOLD #{FOLD}")
            Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(ckpt_paths[FOLD], map_location=device)
            Hybrid_TS.eval()
            
            namelist = ["TRAIN", "VAL", "TEST"]
            
            for _i, _s in enumerate([train_set, val_set, test_set]):
                print (f"[{namelist[_i]} Set] of FOLD # {FOLD}")
                # key_frame_only=True: the dataset yields (key_frame_image, case_id).
                key_frame_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
                    _s, 
                    key_frames=True,
                    key_frame_transform=key_frame_transform,
                    key_frame_only=True,
                )
                key_frameloader = DataLoader(
                    key_frame_dataset,
                    batch_size=BATCH_SIZE,
                    shuffle=False,
                    pin_memory=True,
                    num_workers=WORKERS,
                    persistent_workers=WORKERS>0
                )
                
                print (f"[SAVE] caching Image features and Soft lables from Expert Model in FOLD #{FOLD}")
                with torch.no_grad():
                    for img, case_id in tqdm(iter(key_frameloader)):
                        B = img.shape[0]
                        img = img.to(device, dtype=torch.float)
                        
                        # Expert forward pass; features come from its spatial and
                        # temporal conv outputs, global-average-pooled and flattened.
                        Soft_label, _, _ = Hybrid_TS(img)
                        Img_features = torch.cat((Hybrid_TS.last_conv_output_S, Hybrid_TS.last_conv_output_T), dim=1)
                        Img_features = torch.nn.functional.adaptive_avg_pool2d(Img_features, (1,1)).flatten(1) 
                                            
                        for i in range(B):
                            # Find the record for this case to locate its video path,
                            # then cache next to the video under "chache"
                            # (NOTE(review): "chache" (sic) — readers expect this name).
                            parent_path = next((r for r in _s if r.get("case") == case_id[i]), None)
                            parent_path = Path(parent_path["path"]).resolve().parent
                            parent_path = parent_path / "chache"
                            parent_path.mkdir(exist_ok=True)
                            np.savez((parent_path / f"F{FOLD}_{case_id[i]}.npz"), soft_label=Soft_label[i].cpu().numpy(), img_features=Img_features[i].cpu().numpy())
                
            print (f"[DONE] FOLD #{FOLD}")            
                
        case "train":
            # Build the MIL model; expert weights are only loaded when key-frame
            # features have not been pre-cached.
            Model = M.RARP_NVB_Multi_MOD_MIL(
                num_classes=1,
                temporal=args.Temp_Head,
                cnn_name=args.CNN_name,
                dropout=0.3,
                lr=1e-4, #3e-4,
                weight_decay=0.1, #0.05
                epochs=MAX_EPOCHS,
                pre_train=PRE_TRAIN,
                Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve()) if not CACHED_FEATURES else None
            )
            
            print(f"Model Used: {type(Model).__name__}")            
            print("Train Phase")
            trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
            # Evaluate the best checkpoint (per val_wind_acc) on the test split.
            trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
        case "eval_all":
            # Test every checkpoint of a given log version and export the
            # collected metric rows to one Excel sheet per fold/version.
            print("Evaluation Phase")
            rows = []
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in sorted(pathCkptFile.glob("*.ckpt")):
                print(ckpFile.name)
                #trainer.test(Model, dataloaders=test_loader, ckpt_path=ckpFile)
                #Model = M.RARP_NVB_DINO_MultiTask_A5_Video.load_from_checkpoint(ckpFile)
                
                hp_fiel = pathCkptFile.parent / "hparams.yaml"
                Model = M.RARP_NVB_Multi_MOD_MIL_TESTMode.load_from_checkpoint(ckpFile, map_location=device, hparams_file=hp_fiel)
                trainer.test(Model, dataloaders=test_loader)
                
                #temp = Calc_Eval_table(Model, test_loader, True, ckpFile.name)
                # The test-mode model accumulates its own metric rows.
                temp = Model._test_results
                rows += temp
            
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])        
            #df.style.highlight_max(color="red", axis=0)
            output_file = Path(f"./{LogFileName}/output.xlsx")             
            # Create the workbook on first use; afterwards append/replace the
            # sheet for this fold/version.
            if not output_file.exists():
                df.to_excel(output_file, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            else:
                with pd.ExcelWriter(output_file, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
                    df.to_excel(writer, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            print("[END] File saved ... ")