# Source: RARP_server / video_rarp.py
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse
import decord

#decord._ffi.base.set_num_threads(4) 

# Allow TF32 for CUDA matmuls and request the 'high' float32 matmul precision
# setting (speed-over-precision knobs; see the PyTorch numerics notes).
torch.backends.cuda.matmul.allow_tf32 = True  
torch.set_float32_matmul_precision('high')
#torch.backends.cudnn.deterministic = True

# Default per-channel normalization statistics (these match the commonly used
# ImageNet mean/std values).
# NOTE: the ``__main__`` block below rebinds ``Mean`` and ``Std`` to different
# values stored as tensors.
Mean = [0.485, 0.456, 0.406]
Std = [0.229, 0.224, 0.225]

def setup_seed(seed):
    """Seed the random number generators used by this script.

    PyTorch (CPU and every CUDA device) and NumPy are seeded explicitly, then
    Lightning's ``seed_everything`` is called with ``workers=True``.

    Args:
        seed (int): Seed value handed to every generator.
    """
    explicit_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in explicit_seeders:
        apply_seed(seed)

    seed_everything(seed, workers=True)
    #torch.backends.cudnn.deterministic = True
    
def rolling_mean_std(a, w):
    """Compute the rolling mean and standard deviation along axis 0.

    Window sums are obtained from cumulative sums, so the cost is linear in
    the number of samples regardless of the window size.

    Args:
        a (array-like): Samples with time along axis 0, e.g. shape [T, F].
            Any rank >= 1 is accepted.
        w (int): Window length in samples; must be >= 1.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(mean, std)`` as float64 arrays.
        Axis 0 has length ``T - w + 1`` (zero when ``w > T``); the remaining
        axes match ``a``. ``std`` is the population standard deviation with
        the variance floored at 1e-12.

    Raises:
        ValueError: If ``w`` is smaller than 1.
    """
    if w < 1:
        raise ValueError("w must be a positive window length")

    # Accumulate in float64: E[x^2] - E[x]^2 over long cumulative sums loses
    # most of its precision in float32 and can overflow for integer inputs.
    values = np.asarray(a, dtype=np.float64)
    zero_row = np.zeros((1,) + values.shape[1:], dtype=np.float64)

    def window_sums(v):
        # A leading zero row makes csum[i + w] - csum[i] the sum of v[i:i + w].
        csum = np.concatenate((zero_row, np.cumsum(v, axis=0)), axis=0)
        return csum[w:] - csum[:-w]

    mean = window_sums(values) / float(w)
    var = window_sums(values ** 2) / float(w) - mean ** 2
    # Rounding can push the variance slightly below zero; floor it before sqrt.
    std = np.sqrt(np.maximum(var, 1e-12))
    return mean, std

def plot_tensor_analysis(x, fps=30, win=None, out_prefix="tensor_analysis", out_dir="output"):
    """
    Visualize a tensor of shape [T, F] with:
      1) Time series per feature (raw + rolling mean ± std)
      2) Heatmap overview (per-feature normalized to [0,1])
      3) Distribution boxplots per feature

    Each figure is saved as a PNG under ``out_dir`` (created if missing) and
    closed afterwards, so repeated calls do not accumulate open figures.

    Args:
        x (torch.Tensor): Input tensor of shape [T, F], with T >= 1 and F >= 1.
        fps (int or float): Frames per second (for x-axis in seconds).
        win (int or None): Rolling window size in frames.
            Default = max(3, round(fps)), i.e. roughly one second.
        out_prefix (str): Prefix for saved file names.
        out_dir (str or Path): Directory that receives the PNG files.

    Raises:
        ValueError: If ``x`` is not a non-empty 2-D tensor, or if ``fps`` or
            ``win`` is not positive.
    """
    # --- check input ---
    if not torch.is_tensor(x):
        raise ValueError("x must be a torch.Tensor")
    if x.ndim != 2:
        raise ValueError("x must have shape [T, F]")
    if x.numel() == 0:
        # Fail before any figure is written instead of midway through.
        raise ValueError("x must contain at least one frame and one feature")
    if fps <= 0:
        raise ValueError("fps must be positive")

    # The window is used as a slice bound, so it has to be an integer even
    # when fps is fractional (e.g. 29.97).
    if win is None:
        win = max(3, int(round(fps)))  # default = ~1 second
    else:
        win = int(win)
    if win < 1:
        raise ValueError("win must be at least 1 frame")

    n_frames, n_feats = x.shape
    time_sec = np.arange(n_frames) / float(fps)
    arr = x.detach().cpu().numpy()

    # --- rolling mean/std ---
    roll_mean, roll_std = rolling_mean_std(arr, win)
    half = win // 2
    # Shift the rolling statistics by half a window along the time axis.
    roll_t = time_sec[half:half + len(roll_mean)]

    def save_and_close(fig, suffix):
        # Create the destination on demand and always release the figure.
        target = Path(out_dir) / f"{out_prefix}_{suffix}.png"
        target.parent.mkdir(parents=True, exist_ok=True)
        try:
            fig.savefig(target, dpi=200)
        finally:
            plt.close(fig)

    # ---------- 1) Time series ----------
    fig_ts, axes = plt.subplots(n_feats, 1, figsize=(10, 2.5 * n_feats), sharex=True)
    # plt.subplots returns a bare Axes when there is a single row.
    axes = np.atleast_1d(axes)

    for f, ax in enumerate(axes):
        centre = roll_mean[:, f]
        spread = roll_std[:, f]
        ax.plot(time_sec, arr[:, f], alpha=0.35, linewidth=1.0, label=f'Feature {f}')
        ax.plot(roll_t, centre, linewidth=2.0, label=f'Rolling mean (w={win})')
        ax.fill_between(roll_t, centre - spread, centre + spread,
                        alpha=0.2, label='±1 std (rolling)')
        ax.set_ylabel(f'Feature {f}')
        ax.grid(True, linestyle='--', alpha=0.3)
    axes[-1].set_xlabel('Time (s)')
    axes[0].legend(loc='upper right')
    fig_ts.suptitle('Per-feature time series with rolling mean ± std', y=1.02)
    fig_ts.tight_layout()
    save_and_close(fig_ts, "time_series")

    # ---------- 2) Heatmap ----------
    fig_hm, ax = plt.subplots(figsize=(10, 2.8))
    arr_min = arr.min(axis=0, keepdims=True)
    arr_max = arr.max(axis=0, keepdims=True)
    # The epsilon keeps constant features from dividing by zero.
    arr_norm = (arr - arr_min) / (arr_max - arr_min + 1e-12)

    im = ax.imshow(arr_norm.T, aspect='auto', interpolation='nearest',
                   extent=[time_sec[0], time_sec[-1], n_feats - 0.5, -0.5])
    ax.set_yticks(np.arange(n_feats))
    ax.set_yticklabels([f'Feat {f}' for f in range(n_feats)])
    ax.set_xlabel('Time (s)')
    ax.set_title('Heatmap (per-feature normalized)')
    fig_hm.colorbar(im, ax=ax, fraction=0.025, pad=0.02)
    fig_hm.tight_layout()
    save_and_close(fig_hm, "heatmap")

    # ---------- 3) Boxplots ----------
    fig_box, ax = plt.subplots(figsize=(7, 3.5))
    ax.boxplot([arr[:, f] for f in range(n_feats)], showmeans=True)
    ax.set_xticklabels([f'Feat {f}' for f in range(n_feats)])
    ax.set_ylabel('Value')
    ax.set_title('Distribution across time (boxplot per feature)')
    ax.grid(True, axis='y', linestyle='--', alpha=0.3)
    fig_box.tight_layout()
    save_and_close(fig_box, "boxplots")

    print(f"Saved: {out_prefix}_time_series.png, {out_prefix}_heatmap.png, {out_prefix}_boxplots.png")

def Calc_Eval_table(
    TrainModel:M.RARP_NVB_Model,
    TestDataLoadre:DataLoader, 
    Youden=False, 
    modelName="", 
):
    """Evaluate a binary classifier and return its metrics as table rows.

    The model output for each batch is flattened and passed through a sigmoid;
    the resulting probabilities are scored against the labels with
    torchmetrics.

    NOTE: relies on the module-level ``device``, which is assigned in the
    ``if __name__ == "__main__":`` section of this file.

    Args:
        TrainModel: Model to evaluate; it is moved to ``device`` and switched
            to eval mode.
        TestDataLoadre: DataLoader yielding ``(data, label)`` batches.
        Youden: When true, append two extra rows evaluated at alternative
            thresholds (see the inline notes below).
        modelName: Text stored in the last column of the extra rows.

    Returns:
        list[list[str]]: Rows of the form ``[threshold, accuracy, precision,
        recall, F1, AUROC, specificity, name]``. The first row uses the 0.5
        threshold and has an empty name.
    """
    TrainModel.to(device)
    TrainModel.eval()

    batch_probs = []
    batch_labels = []

    with torch.no_grad():
        # Iterate the loader itself (not iter(loader)) so tqdm can read its
        # length and display a total.
        for data, label in tqdm(TestDataLoadre):
            data = data.float().to(device)

            #pred, *_ = TrainModel(data)
            logits = TrainModel(data).flatten()

            batch_probs.append(torch.sigmoid(logits))
            batch_labels.append(label.to(device))

    Predictions = torch.cat(batch_probs)
    Labels = torch.cat(batch_labels).int()

    def metrics_at(threshold):
        # Accuracy, precision, recall, F1 and specificity at one threshold.
        metric_classes = (
            torchmetrics.Accuracy,
            torchmetrics.Precision,
            torchmetrics.Recall,
            torchmetrics.F1Score,
            torchmetrics.Specificity,
        )
        return [
            metric_cls("binary", threshold=threshold).to(device)(Predictions, Labels)
            for metric_cls in metric_classes
        ]

    def as_row(threshold_text, metrics, name):
        acc, precision, recall, f1_score, specificity = metrics
        ordered = (acc, precision, recall, f1_score, auc, specificity)
        return [threshold_text, *(f"{value.item():.4f}" for value in ordered), name]

    # AUROC does not depend on the decision threshold, so it is computed once.
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)

    base_metrics = metrics_at(0.5)
    table = [as_row("0.5000", base_metrics, "")]

    if Youden:
        # The ROC curve and the statistics at 0.5 are the same for both extra
        # rows, so they are computed once up front.
        fpr, tpr, thresholds = torchmetrics.ROC("binary").to(device)(Predictions, Labels)
        roc_threshold = thresholds[torch.argmax(tpr - fpr)].item()

        _, _, recall, _, specificity = base_metrics
        j_at_half = (recall + specificity - 1).item()
        # NOTE(review): the second row uses the *value* of recall +
        # specificity - 1 at the 0.5 threshold as a decision threshold
        # (falling back to 0.5 when it is not positive). Behavior preserved
        # as-is; confirm this is intended.
        j_threshold = 0.5 if j_at_half <= 0 else j_at_half

        for threshold in (roc_threshold, j_threshold):
            table.append(as_row(f"{threshold:.4f}", metrics_at(threshold), modelName))

    return table
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("-lv","--Log_version", type=int, default=None)
    parser.add_argument("--Workers", type=int, default=0)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("--Head", type=int, default=None)
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-b", "--Batch_size", type=int, default=8)
    parser.add_argument("--Video_chunks", type=int, default=50)
    parser.add_argument("--GPU", type=int, default=0)
    
    
    args = parser.parse_args()
    
    setup_seed(2023)
    device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")

    Mean = torch.tensor([30.38144216, 42.03988769, 97.8896116]).view(1,3,1,1)
    Std = torch.tensor([40.63141752, 44.26910074, 50.29294373]).view(1,3,1,1)

    df = pd.read_csv("./Dataset_RARP_video/dataset_videos_folds.csv")

    FOLD = args.Fold
    WORKERS = args.Workers
    BATCH_SIZE = args.Batch_size
    MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
        
    print(f"Fold_{FOLD}")

    train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
    val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
    test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
    
    traintransformT2 = torch.nn.Sequential(
        transforms.RandomErasing(0.6, value="random"),
        transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=5),
        transforms.RandomApply([transforms.GaussianBlur(5)], 0.5),
        transforms.RandomHorizontalFlip(0.3),
    ).to(device)

    
    #train_dataset = Loaders.RARP_Video_Dataset(train_set, (224, 224), (139, 0, 360, 360), decode_resize=(640, 360), mean=Mean, std=Std, transform=traintransformT2)
    train_dataset = Loaders.RARP_Video_Dataset(train_set, (224, 224), (139, 0, 360, 360), decode_resize=(640, 360), mean=Mean, std=Std, transform=None)
    #train_dataset = Loaders.RARP_Video_Dataset(train_set, (224, 224), (139, 0, 360, 360), decode_resize=(640, 360), mean=Mean, std=Std, transform=traintransformT2, transform_frame=False)
    val_dataset = Loaders.RARP_Video_Dataset(val_set, (224, 224), (139, 0, 360, 360), decode_resize=(640, 360), mean=Mean, std=Std)
    test_dataset = Loaders.RARP_Video_Dataset(test_set, (224, 224), (139, 0, 360, 360), decode_resize=(640, 360), mean=Mean, std=Std)
    

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    ckpt_paths = [
        Path("./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt"),
    ]

    if args.Head in [1, 2, 3, 4]:
        Model = M.RARP_NVB_DINO_MultiTask_A5_Video(
            base_model_path=str(ckpt_paths[FOLD].resolve()),
            lr=3e-4,
            wd=1e-4,
            head_type=args.Head,
            chunks_loading=args.Video_chunks
        )
    elif args.Head == 5:
        Model = M.RARP_NVB_DINO_MultiTask_A6_Video(head_type=1,chunks_loading=args.Video_chunks)
    else:
        Model = M.RARP_NVB_VIDEO_3D_ResNet(chunks_loading=args.Video_chunks, str_path="Pre-trainRN50.pth")
    
    print(f"Model Used: {type(Model).__name__}")
    LogFileName = f"{args.Log_Name}" 
    
    checkPtCallback = [
        callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
        callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ]
    
    trainer = L.Trainer(
        precision="16-mixed",
        deterministic="warn",
        accelerator="gpu",
        devices=[args.GPU],
        logger=TensorBoardLogger(save_dir=f"./{LogFileName}") if args.Phase == "train" else CSVLogger(save_dir=f"./{LogFileName}/Test", version=args.Log_version),
        log_every_n_steps=5,  
        callbacks=checkPtCallback,
        max_epochs=MAX_EPOCHS
    )
    
    match(args.Phase):
        case "train":
            print("Train Phase")
            trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
            trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
        case "eval_all":
            print("Evaluation Phase")
            rows = []
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in sorted(pathCkptFile.glob("*.ckpt")):
                print(ckpFile.name)
                #trainer.test(Model, dataloaders=test_loader, ckpt_path=ckpFile)
                #Model = M.RARP_NVB_DINO_MultiTask_A5_Video.load_from_checkpoint(ckpFile)
                Model = M.RARP_NVB_VIDEO_3D_ResNet.load_from_checkpoint(ckpFile)
                
                temp = Calc_Eval_table(
                        Model, 
                        test_loader, 
                        True, 
                        ckpFile.name
                    )
                rows += temp
            
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])        
            df.style.highlight_max(color="red", axis=0)
            print(df)