# Newer
# Older
# RARP / win_video_rarp.py
# @delAguila delAguila 27 days ago 15 KB Final Commit.
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models_video as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse


# --- Process-wide numeric / determinism settings, applied at import time ---
# Allow TensorFloat-32 for CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
# 'high' float32 matmul precision: per the PyTorch docs this permits faster
# reduced-precision internal computation (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')
# Ask cuDNN for deterministic kernels (setup_seed() sets this flag again).
torch.backends.cudnn.deterministic = True



def setup_seed(seed):
    """Seed every RNG the training run touches and force deterministic cuDNN.

    Args:
        seed (int): Seed applied to torch (CPU and all CUDA devices), NumPy,
            and Lightning's ``seed_everything`` (which also seeds dataloader
            workers because ``workers=True``).
    """
    # torch CPU generator, then every visible CUDA device, then NumPy's global RNG.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
    
def rolling_mean_std(a, w):
    """Compute the sliding-window mean and standard deviation along axis 0.

    Uses cumulative sums, so each window costs O(1) regardless of ``w``.
    Output row ``i`` summarises input rows ``i .. i + w - 1`` (a "valid"
    window with no padding). There are ``len(a) - w + 1`` output rows, or
    zero rows when ``w > len(a)``.

    Args:
        a (array_like): Input of shape ``[T]`` or ``[T, ...]``; windows run
            along the first axis. Any numeric dtype is accepted.
        w (int): Window length in samples; must be >= 1.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(mean, std)`` as float64 arrays of
        shape ``[max(T - w + 1, 0), ...]``. Std is the population std
        (divide by ``w``). The variance is clamped at 1e-12 before the
        square root, so rounding can never produce a negative variance.

    Raises:
        ValueError: If ``w < 1``.
    """
    if w < 1:
        # A negative w would otherwise slice from the wrong ends and return
        # garbage silently; w == 0 would fail with an opaque broadcast error.
        raise ValueError(f"window length w must be >= 1, got {w}")

    # Work in float64: a**2 overflows for small integer dtypes (e.g. uint8),
    # and float32 cumulative sums drift on long sequences, which the
    # E[x^2] - E[x]^2 variance formula below would amplify.
    a = np.asarray(a, dtype=np.float64)

    # Prepend one zero row on axis 0 so csum[k] == a[:k].sum(axis=0); the sum
    # of window [i, i + w) is then csum[i + w] - csum[i]. Works for any rank.
    pad = [(1, 0)] + [(0, 0)] * (a.ndim - 1)
    csum = np.pad(np.cumsum(a, axis=0), pad, mode="constant")
    csum_sq = np.pad(np.cumsum(a * a, axis=0), pad, mode="constant")

    mean = (csum[w:] - csum[:-w]) / float(w)
    var = (csum_sq[w:] - csum_sq[:-w]) / float(w) - mean**2
    std = np.sqrt(np.maximum(var, 1e-12))
    return mean, std

def plot_tensor_analysis(x, fps=30, win=None, out_prefix="tensor_analysis", close_figs=True):
    """
    Visualize a tensor of shape [T, F] and save three PNG figures to ``output/``:
      1) Time series per feature (raw + rolling mean ± std)
      2) Heatmap overview (per-feature min-max normalized to [0,1])
      3) Distribution boxplots per feature

    Args:
        x (torch.Tensor): Input tensor of shape [T, F] with T >= 1.
        fps (int or float): Frames per second (for x-axis in seconds).
        win (int or None): Rolling window size in frames.
            Default = max(3, round(fps)), i.e. roughly one second.
        out_prefix (str): Prefix for saved file names.
        close_figs (bool): Close the three figures after saving (default), so
            repeated calls do not accumulate open matplotlib figures. Pass
            False to keep them open, e.g. for interactive display.

    Raises:
        ValueError: If ``x`` is not a tensor, is not 2-D, or has no time steps.
    """
    # --- check input ---
    if not torch.is_tensor(x):
        raise ValueError("x must be a torch.Tensor")
    if x.ndim != 2:
        raise ValueError("x must have shape [T, F]")

    T, F = x.shape
    if T == 0:
        raise ValueError("x must contain at least one time step (T >= 1)")

    time_sec = np.arange(T) / float(fps)
    arr = x.detach().cpu().numpy()

    # savefig() does not create missing directories.
    out_dir = Path("output")
    out_dir.mkdir(parents=True, exist_ok=True)

    # --- rolling mean/std ---
    if win is None:
        # ~1 second; int() so a non-integer fps still gives a valid slice length.
        win = max(3, int(round(fps)))

    roll_mean, roll_std = rolling_mean_std(arr, win)
    # roll_mean[i] summarises frames i .. i+win-1, so plot it at the window
    # centre i + (win-1)/2. For odd win this is frame i + win//2; for even win
    # it avoids a half-frame shift to the right.
    roll_t = (np.arange(len(roll_mean)) + (win - 1) / 2.0) / float(fps)

    # ---------- 1) Time series ----------
    fig_ts, axes = plt.subplots(F, 1, figsize=(10, 2.5*F), sharex=True)
    if F == 1:
        axes = [axes]

    for f in range(F):
        ax = axes[f]
        ax.plot(time_sec, arr[:, f], alpha=0.35, linewidth=1.0, label=f'Feature {f}')
        ax.plot(roll_t, roll_mean[:, f], linewidth=2.0, label=f'Rolling mean (w={win})')
        ax.fill_between(roll_t,
                        roll_mean[:, f] - roll_std[:, f],
                        roll_mean[:, f] + roll_std[:, f],
                        alpha=0.2, label='±1 std (rolling)')
        ax.set_ylabel(f'Feature {f}')
        ax.grid(True, linestyle='--', alpha=0.3)
    axes[-1].set_xlabel('Time (s)')
    axes[0].legend(loc='upper right')
    fig_ts.suptitle('Per-feature time series with rolling mean ± std', y=1.02)
    fig_ts.tight_layout()
    # bbox_inches="tight" keeps the suptitle (placed above the axes at y=1.02)
    # from being clipped in the saved image.
    fig_ts.savefig(out_dir / f"{out_prefix}_time_series.png", dpi=200, bbox_inches="tight")

    # ---------- 2) Heatmap ----------
    fig_hm, ax = plt.subplots(figsize=(10, 2.8))
    arr_min = arr.min(axis=0, keepdims=True)
    arr_max = arr.max(axis=0, keepdims=True)
    # Epsilon guards constant features (max == min) against division by zero.
    arr_norm = (arr - arr_min) / (arr_max - arr_min + 1e-12)

    im = ax.imshow(arr_norm.T, aspect='auto', interpolation='nearest',
                   extent=[time_sec[0], time_sec[-1], F-0.5, -0.5])
    ax.set_yticks(np.arange(F))
    ax.set_yticklabels([f'Feat {f}' for f in range(F)])
    ax.set_xlabel('Time (s)')
    ax.set_title('Heatmap (per-feature normalized)')
    fig_hm.colorbar(im, ax=ax, fraction=0.025, pad=0.02)
    fig_hm.tight_layout()
    fig_hm.savefig(out_dir / f"{out_prefix}_heatmap.png", dpi=200)

    # ---------- 3) Boxplots ----------
    fig_box, ax = plt.subplots(figsize=(7, 3.5))
    ax.boxplot([arr[:, f] for f in range(F)], showmeans=True)
    ax.set_xticklabels([f'Feat {f}' for f in range(F)])
    ax.set_ylabel('Value')
    ax.set_title('Distribution across time (boxplot per feature)')
    ax.grid(True, axis='y', linestyle='--', alpha=0.3)
    fig_box.tight_layout()
    fig_box.savefig(out_dir / f"{out_prefix}_boxplots.png", dpi=200)

    print(f"Saved: {out_prefix}_time_series.png, {out_prefix}_heatmap.png, {out_prefix}_boxplots.png")

    if close_figs:
        # Release the figures; pyplot otherwise keeps every figure alive.
        for fig in (fig_ts, fig_hm, fig_box):
            plt.close(fig)

def _binary_metrics_at(Predictions, Labels, threshold, device):
    """Return ``(acc, precision, recall, f1, specificity)`` scalar tensors at one threshold."""
    def _score(metric_cls):
        return metric_cls("binary", threshold=threshold).to(device)(Predictions, Labels)

    return (
        _score(torchmetrics.Accuracy),
        _score(torchmetrics.Precision),
        _score(torchmetrics.Recall),
        _score(torchmetrics.F1Score),
        _score(torchmetrics.Specificity),
    )


def Calc_Eval_table(
    TrainModel,
    TestDataLoadre:DataLoader, 
    Youden=False, 
    modelName="", 
    device=None,
):
    """Run ``TrainModel`` over a loader and tabulate binary classification metrics.

    Each batch is unpacked as ``(data, label, mask, _, key_frame)``. The model
    is called as ``TrainModel(data, key_frame, mask)`` and must return a
    2-tuple whose first element holds logits; sigmoid is applied here.

    Args:
        TrainModel: Model to evaluate; it is moved to ``device`` and set to eval mode.
        TestDataLoadre (DataLoader): Evaluation loader; must yield at least one batch.
        Youden (bool): If True, append two extra rows at alternative thresholds:
            (1) the ROC threshold maximizing TPR - FPR (Youden's J optimum);
            (2) the value J = recall + specificity - 1 measured at 0.5, used
            directly as a threshold (0.5 when J <= 0).
        modelName (str): Label written in the last column of the Youden rows.
        device (torch.device | str | None): Device for the model and metrics.
            Defaults to the module-level ``device`` (set by the ``__main__``
            block); otherwise CUDA if available, else CPU.

    Returns:
        list[list[str]]: Rows of ``[threshold, acc, precision, recall, f1,
        auroc, specificity, name]`` with numbers formatted to 4 decimals. The
        first row uses threshold 0.5 and an empty name. AUROC is
        threshold-free and is repeated in every row.

    Raises:
        ValueError: If the loader yields no batches.
    """
    if device is None:
        # Preserve the original behaviour of using the script-level `device`,
        # but don't NameError when this function is imported elsewhere.
        device = globals().get("device")
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    TrainModel.to(device)
    TrainModel.eval()

    probs, targets = [], []
    with torch.no_grad():
        for data, label, mask, _, key_frame in tqdm(TestDataLoadre):
            logits, _ = TrainModel(
                data.float().to(device),
                key_frame.float().to(device),
                mask.to(device),
            )
            probs.append(torch.sigmoid(logits.flatten()))
            targets.append(label.to(device))

    if not probs:
        raise ValueError("TestDataLoadre yielded no batches; cannot compute metrics")

    Predictions = torch.cat(probs)
    Labels = torch.cat(targets).int()

    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    auc_txt = f"{auc.item():.4f}"

    def _row(th_txt, metrics, name):
        acc, precision, recall, f1Score, specificty = metrics
        return [th_txt, f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}",
                f"{f1Score.item():.4f}", auc_txt, f"{specificty.item():.4f}", name]

    base_metrics = _binary_metrics_at(Predictions, Labels, 0.5, device)
    table = [_row("0.5000", base_metrics, "")]

    if Youden:
        # ROC is threshold-independent: compute it once, not once per row.
        fpr, tpr, thhols = torchmetrics.ROC("binary").to(device)(Predictions, Labels)
        roc_th = thhols[torch.argmax(tpr - fpr)].item()

        _, _, recall, _, specificty = base_metrics
        # NOTE(review): this uses the J *statistic* at 0.5 as a probability
        # threshold, which is not a standard Youden cut-off; kept as-is so
        # previously exported tables remain comparable.
        j_stat = (recall + specificty - 1).item()
        j_th = 0.5 if j_stat <= 0 else j_stat

        for th in (roc_th, j_th):
            table.append(_row(f"{th:.4f}", _binary_metrics_at(Predictions, Labels, th, device), modelName))

    return table
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'", required=True)
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("-lv","--Log_version", type=int, default=None)
    parser.add_argument("--Workers", type=int, default=0)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("--CNN_name", type=str, default=None, )
    parser.add_argument("--Temp_Head", type=str, default=None, )
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-b", "--Batch_size", type=int, default=8)
    parser.add_argument("--GPU", type=int, default=0)
    parser.add_argument("--pre_train", type=int, default=0)
    parser.add_argument("-k", "--k_windows", type=int, default=1)
    
    args = parser.parse_args()
    
    setup_seed(2023)
    device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")

    df = pd.read_csv("../dataset/Dataset_RARP_video/dataset_videos_folds.csv")

    FOLD = args.Fold
    WORKERS = args.Workers
    BATCH_SIZE = args.Batch_size
    MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
    PRE_TRAIN = args.pre_train != 0 
    K_WIN = args.k_windows
    KEY_FRAME = True
    
    Mean = [0.485, 0.456, 0.406]
    Std = [0.229, 0.224, 0.225]
        
    print(f"Fold_{FOLD}")
    
    ckpt_paths = [
        Path("./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt"),
    ]

    train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
    val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
    test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
    
    traintransformT2 = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=0),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.GaussianBlur(kernel_size=3),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(device)

    traintransform_frame = torch.nn.Sequential(
        transforms.RandomApply([
            transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
            transforms.RandomErasing(1.0, value=0)
            ], 0.3) #small noise
    ).to(device)
    
    testVal_transform = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(device)
    
    key_frame_transform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    ).to(device)

    train_dataset = Loaders.RARP_Windowed_Video_frames_Dataset(
        train_set, 
        train_mode=True, 
        window_length=64, 
        transform=traintransformT2, 
        transform_frame=traintransform_frame, 
        k_windows=K_WIN,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform
    )
    val_dataset = Loaders.RARP_Windowed_Video_frames_Dataset(
        val_set, 
        train_mode=False, 
        window_length=64, 
        stride=32, 
        transform=testVal_transform,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform
    )
    test_dataset = Loaders.RARP_Windowed_Video_frames_Dataset(
        test_set, 
        train_mode=False, 
        window_length=64, 
        stride=32, 
        transform=testVal_transform,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform
    )    
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    LogFileName = f"{args.Log_Name}" 
    
    match(args.Phase):
        case "train":
            if not KEY_FRAME:
                Model = M.RARP_NVB_Wind_video(
                    num_classes=1,
                    temporal=args.Temp_Head,
                    cnn_name=args.CNN_name,
                    dropout=0.2,
                    lr=1e-4, #3e-4,
                    weight_decay=0.1, #0.05
                    epochs=MAX_EPOCHS,
                    warmup_epochs=3,
                    pre_train=PRE_TRAIN
                )
            else:
                Model = M.RARP_NVB_Multi_MOD(
                    num_classes=1,
                    temporal=args.Temp_Head,
                    cnn_name=args.CNN_name,
                    dropout=0.2,
                    lr=1e-4, #3e-4,
                    weight_decay=0.1, #0.05
                    epochs=MAX_EPOCHS,
                    warmup_epochs=3,
                    pre_train=PRE_TRAIN,
                    Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve())
                )
            
            print(f"Model Used: {type(Model).__name__}")            
            
            checkPtCallback = [
                callbk.ModelCheckpoint(monitor='val_video_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
                #callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
            ]
            
            trainer = L.Trainer(
                deterministic=True,
                accelerator="gpu",
                devices=[args.GPU],
                logger=TensorBoardLogger(save_dir=f"./{LogFileName}") if args.Phase == "train" else CSVLogger(save_dir=f"./{LogFileName}/Test", version=args.Log_version),
                log_every_n_steps=5,  
                callbacks=checkPtCallback,
                max_epochs=MAX_EPOCHS
            )
            
            print("Train Phase")
            trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
            trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
        case "eval_all":
            print("Evaluation Phase")
            rows = []
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in sorted(pathCkptFile.glob("*.ckpt")):
                print(ckpFile.name)
                #trainer.test(Model, dataloaders=test_loader, ckpt_path=ckpFile)
                #Model = M.RARP_NVB_DINO_MultiTask_A5_Video.load_from_checkpoint(ckpFile)
                if KEY_FRAME:
                    hp_fiel = pathCkptFile.parent / "hparams.yaml"
                    Model = M.RARP_NVB_Multi_MOD.load_from_checkpoint(ckpFile, map_location=device, hparams_file=hp_fiel)
                
                temp = Calc_Eval_table(
                        Model, 
                        test_loader, 
                        True, 
                        ckpFile.name
                    )
                rows += temp
            
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])        
            #df.style.highlight_max(color="red", axis=0)
            output_file = Path(f"./{LogFileName}/output.xlsx")             
            if not output_file.exists():
                df.to_excel(output_file, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            else:
                with pd.ExcelWriter(output_file, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
                    df.to_excel(writer, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            print("[END] File saved ... ")