import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models_video as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse
# Allow TensorFloat-32 matmuls (Ampere+ GPUs) and select the faster "high"
# float32 matmul path — trades a little precision for throughput.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
# Force deterministic cuDNN algorithm selection for reproducibility.
torch.backends.cudnn.deterministic = True
def setup_seed(seed):
    """Seed every RNG source (Lightning, numpy, torch CPU and CUDA) with *seed*."""
    seed_everything(seed, workers=True)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels (also set at module import; repeated for safety).
    torch.backends.cudnn.deterministic = True
def rolling_mean_std(a, w):
    """Rolling mean and std along axis 0 of a [T, F] array with window size w.

    Uses the cumulative-sum trick, so the whole computation is O(T*F).
    Returns a pair (mean, std), each of shape [T - w + 1, F].
    """
    def windowed_average(values):
        # Prepend a zero row so csum[i] is the sum of the first i rows,
        # then difference at lag w to obtain every window's sum.
        csum = np.pad(np.cumsum(values, axis=0), ((1, 0), (0, 0)), mode="constant")
        return (csum[w:] - csum[:-w]) / float(w)

    mean = windowed_average(a)
    mean_of_squares = windowed_average(a**2)
    # Var = E[x^2] - E[x]^2; clamp tiny negatives caused by floating-point error.
    var = mean_of_squares - mean**2
    return mean, np.sqrt(np.maximum(var, 1e-12))
def plot_tensor_analysis(x, fps=30, win=None, out_prefix="tensor_analysis"):
    """
    Visualize a tensor of shape [T, F] with:
        1) Time series per feature (raw + rolling mean ± std)
        2) Heatmap overview (per-feature normalized to [0,1])
        3) Distribution boxplots per feature
    Figures are written under ./output/ (created if missing) and closed
    after saving so repeated calls do not accumulate open figures.
    Args:
        x (torch.Tensor): Input tensor of shape [T, F].
        fps (int): Frames per second (for x-axis in seconds).
        win (int or None): Rolling window size in frames. Default = fps.
        out_prefix (str): Prefix for saved file names.
    Raises:
        ValueError: If x is not a 2-D torch.Tensor.
    """
    # --- check input ---
    if not torch.is_tensor(x):
        raise ValueError("x must be a torch.Tensor")
    if x.ndim != 2:
        raise ValueError("x must have shape [T, F]")
    T, F = x.shape
    time_idx = np.arange(T)
    time_sec = time_idx / float(fps)
    arr = x.detach().cpu().numpy()
    # Ensure the target directory exists — the original savefig calls raised
    # FileNotFoundError when ./output was missing.
    Path("output").mkdir(parents=True, exist_ok=True)
    # --- rolling mean/std ---
    if win is None:
        win = max(3, fps)  # default = ~1 second
    half = win // 2
    roll_mean, roll_std = rolling_mean_std(arr, win)
    # Center each rolling-window statistic on its window along the time axis.
    roll_t = time_sec[half:half+len(roll_mean)]
    # ---------- 1) Time series ----------
    fig_ts, axes = plt.subplots(F, 1, figsize=(10, 2.5*F), sharex=True)
    if F == 1:
        axes = [axes]  # subplots() returns a bare Axes when F == 1
    for f in range(F):
        ax = axes[f]
        ax.plot(time_sec, arr[:, f], alpha=0.35, linewidth=1.0, label=f'Feature {f}')
        ax.plot(roll_t, roll_mean[:, f], linewidth=2.0, label=f'Rolling mean (w={win})')
        ax.fill_between(roll_t,
                        roll_mean[:, f] - roll_std[:, f],
                        roll_mean[:, f] + roll_std[:, f],
                        alpha=0.2, label='±1 std (rolling)')
        ax.set_ylabel(f'Feature {f}')
        ax.grid(True, linestyle='--', alpha=0.3)
    axes[-1].set_xlabel('Time (s)')
    axes[0].legend(loc='upper right')
    fig_ts.suptitle('Per-feature time series with rolling mean ± std', y=1.02)
    fig_ts.tight_layout()
    fig_ts.savefig(f"output/{out_prefix}_time_series.png", dpi=200)
    plt.close(fig_ts)  # release the figure — avoids matplotlib's open-figure buildup
    # ---------- 2) Heatmap ----------
    fig_hm, ax = plt.subplots(figsize=(10, 2.8))
    # Normalize each feature independently to [0, 1] for a comparable heatmap.
    arr_min = arr.min(axis=0, keepdims=True)
    arr_max = arr.max(axis=0, keepdims=True)
    arr_norm = (arr - arr_min) / (arr_max - arr_min + 1e-12)
    im = ax.imshow(arr_norm.T, aspect='auto', interpolation='nearest',
                   extent=[time_sec[0], time_sec[-1], F-0.5, -0.5])
    ax.set_yticks(np.arange(F))
    ax.set_yticklabels([f'Feat {f}' for f in range(F)])
    ax.set_xlabel('Time (s)')
    ax.set_title('Heatmap (per-feature normalized)')
    fig_hm.colorbar(im, ax=ax, fraction=0.025, pad=0.02)
    fig_hm.tight_layout()
    fig_hm.savefig(f"output/{out_prefix}_heatmap.png", dpi=200)
    plt.close(fig_hm)
    # ---------- 3) Boxplots ----------
    fig_box, ax = plt.subplots(figsize=(7, 3.5))
    ax.boxplot([arr[:, f] for f in range(F)], showmeans=True)
    ax.set_xticklabels([f'Feat {f}' for f in range(F)])
    ax.set_ylabel('Value')
    ax.set_title('Distribution across time (boxplot per feature)')
    ax.grid(True, axis='y', linestyle='--', alpha=0.3)
    fig_box.tight_layout()
    fig_box.savefig(f"output/{out_prefix}_boxplots.png", dpi=200)
    plt.close(fig_box)
    # Report the actual saved paths (the original message omitted the output/ prefix).
    print(f"Saved: output/{out_prefix}_time_series.png, output/{out_prefix}_heatmap.png, output/{out_prefix}_boxplots.png")
def Calc_Eval_table(
    TrainModel,
    TestDataLoadre:DataLoader,
    Youden=False,
    modelName="",
):
    """
    Evaluate a trained model on a test loader and build a metrics table.

    Runs the model over every batch (sigmoid of the flattened logits from
    ``TrainModel(data, key_frame, mask)``), then computes binary
    Accuracy / Precision / Recall / F1 / AUROC / Specificity at threshold 0.5.
    When ``Youden`` is True, two extra rows are appended: one at the
    ROC-optimal (Youden J) threshold, and one at a fallback threshold of
    ``recall + specificity - 1`` (clamped up to 0.5 when non-positive).

    Args:
        TrainModel: torch module taking (data, key_frame, mask) and returning
            (logits, ...).
        TestDataLoadre (DataLoader): yields (data, label, mask, _, key_frame).
        Youden (bool): also report the Youden-threshold rows.
        modelName (str): tag stored in the last column of the extra rows.

    Returns:
        list[list[str]]: rows of formatted metrics —
        [threshold, acc, precision, recall, f1, auroc, specificity, name].

    NOTE(review): relies on the module-level ``device`` global assigned in
    the ``__main__`` section — confirm before reusing this function elsewhere.
    """
    TrainModel.to(device)
    TrainModel.eval()
    Predictions = []
    Labels = []
    with torch.no_grad():
        for batch in tqdm(iter(TestDataLoadre)):
            data, label, mask, _, key_frame = batch
            data = data.float().to(device)
            label = label.to(device)
            key_frame = key_frame.float().to(device)
            mask = mask.to(device)
            pred, _ = TrainModel(data, key_frame, mask)
            pred = pred.flatten()
            Predictions.append(torch.sigmoid(pred))
            Labels.append(label)
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels).int()
    # Default-threshold (0.5) metrics.
    acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
    ]
    if Youden:
        # Compute the ROC curve and both candidate thresholds once — the
        # original recomputed the identical curve on every loop iteration.
        fpr, tpr, thhols = torchmetrics.ROC("binary").to(device)(Predictions, Labels)
        index = torch.argmax(tpr - fpr)  # Youden J = max(tpr - fpr)
        th2 = (recall + specificty - 1).item()
        th2 = 0.5 if th2 <= 0 else th2
        for i in range(2):
            # i == 0 -> ROC-optimal threshold; i == 1 -> fallback threshold.
            th1 = thhols[index].item() if i == 0 else th2
            accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
            precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
            recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
            specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
            f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
    return table
if __name__ == "__main__":
    # ---------------- CLI arguments ----------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'", required=True)
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("-lv","--Log_version", type=int, default=None)
    parser.add_argument("--Workers", type=int, default=0)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("--CNN_name", type=str, default=None, )
    parser.add_argument("--Temp_Head", type=str, default=None, )
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-b", "--Batch_size", type=int, default=8)
    parser.add_argument("--GPU", type=int, default=0)
    parser.add_argument("--pre_train", type=int, default=0)
    parser.add_argument("-k", "--k_windows", type=int, default=1)
    parser.add_argument("--Window_Size", type=int, default=64)
    args = parser.parse_args()
    # Fixed seed so splits/augmentations are reproducible run to run.
    setup_seed(2023)
    device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")
    # Dataset index CSV: one row per video, with a column per fold holding
    # the split assignment ("train" / "val" / "test").
    df = pd.read_csv("./Dataset_RARP_video/dataset_videos_folds.csv")
    FOLD = args.Fold
    WORKERS = args.Workers
    BATCH_SIZE = args.Batch_size
    MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
    PRE_TRAIN = args.pre_train != 0
    K_WIN = args.k_windows
    # Hard-coded: always run the multi-modal (key-frame) pipeline.
    KEY_FRAME = True
    WIN_LENGTH = args.Window_Size
    # NOTE(review): Mean/Std appear unused below — the ImageNet statistics
    # are re-hard-coded inside each transform pipeline. Confirm and remove.
    Mean = [0.485, 0.456, 0.406]
    Std = [0.229, 0.224, 0.225]
    print(f"Fold_{FOLD}")
    # Per-fold pretrained checkpoints fed to RARP_NVB_Multi_MOD as
    # Hybrid_TS_weights (list index == fold number).
    ckpt_paths = [
        Path("./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt"),
        Path("./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt"),
    ]
    # Split the index per fold; sorting gives a deterministic record order.
    train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
    val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
    test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
    # Training-time video augmentation pipeline (runs on the GPU).
    traintransformT2 = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=0),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.GaussianBlur(kernel_size=3),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(device)
    # Extra per-frame augmentation: additive Gaussian noise + random erasing,
    # applied together with probability 0.3.
    traintransform_frame = torch.nn.Sequential(
        transforms.RandomApply([
            transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
            transforms.RandomErasing(1.0, value="random")
        ], 0.3) #small noise
    ).to(device)
    # Deterministic preprocessing for validation/test (no augmentation).
    testVal_transform = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ).to(device)
    # Key-frame branch preprocessing — note these normalization statistics
    # are dataset-specific (raw pixel scale), not the ImageNet ones above.
    key_frame_transform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    ).to(device)
    # ---------------- Datasets ----------------
    train_dataset = Loaders.RARP_Windowed_Video_frames_Dataset(
        train_set,
        train_mode=True,
        window_length=WIN_LENGTH,
        transform=traintransformT2,
        transform_frame=traintransform_frame,
        k_windows=K_WIN,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform
    )
    # Val/test use a fixed stride of 32 frames between windows.
    val_dataset = Loaders.RARP_Windowed_Video_frames_Dataset(
        val_set,
        train_mode=False,
        window_length=WIN_LENGTH,
        stride=32,
        transform=testVal_transform,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform
    )
    test_dataset = Loaders.RARP_Windowed_Video_frames_Dataset(
        test_set,
        train_mode=False,
        window_length=WIN_LENGTH,
        stride=32,
        transform=testVal_transform,
        key_frames=KEY_FRAME,
        key_frame_transform=key_frame_transform
    )
    # ---------------- DataLoaders ----------------
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,   # keep batch shapes constant for training
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    LogFileName = f"{args.Log_Name}"
    # ---------------- Phase dispatch ----------------
    match(args.Phase):
        case "train":
            # Choose the single-modality or multi-modal model; with
            # KEY_FRAME hard-coded True above, the else branch always runs.
            if not KEY_FRAME:
                Model = M.RARP_NVB_Wind_video(
                    num_classes=1,
                    temporal=args.Temp_Head,
                    cnn_name=args.CNN_name,
                    dropout=0.2,
                    lr=1e-4, #3e-4,
                    weight_decay=0.1, #0.05
                    epochs=MAX_EPOCHS,
                    warmup_epochs=3,
                    pre_train=PRE_TRAIN
                )
            else:
                Model = M.RARP_NVB_Multi_MOD(
                    num_classes=1,
                    temporal=args.Temp_Head,
                    cnn_name=args.CNN_name,
                    dropout=0.2,
                    lr=1e-4, #3e-4,
                    weight_decay=0.1, #0.05
                    epochs=MAX_EPOCHS,
                    warmup_epochs=3,
                    pre_train=PRE_TRAIN,
                    Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve())
                )
            print(f"Model Used: {type(Model).__name__}")
            # Keep the 10 best checkpoints ranked by validation video accuracy.
            checkPtCallback = [
                callbk.ModelCheckpoint(monitor='val_video_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
                #callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
            ]
            trainer = L.Trainer(
                deterministic=True,
                accelerator="gpu",
                devices=[args.GPU],
                # NOTE(review): the CSVLogger else-branch is dead code here —
                # this Trainer is only constructed inside case "train".
                logger=TensorBoardLogger(save_dir=f"./{LogFileName}") if args.Phase == "train" else CSVLogger(save_dir=f"./{LogFileName}/Test", version=args.Log_version),
                log_every_n_steps=5,
                callbacks=checkPtCallback,
                max_epochs=MAX_EPOCHS
            )
            print("Train Phase")
            trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
            # Evaluate the best checkpoint on the held-out test split.
            trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
        case "eval_all":
            print("Evaluation Phase")
            rows = []
            # Score every checkpoint saved under the requested log version and
            # collect one metrics table per checkpoint.
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in sorted(pathCkptFile.glob("*.ckpt")):
                print(ckpFile.name)
                #trainer.test(Model, dataloaders=test_loader, ckpt_path=ckpFile)
                #Model = M.RARP_NVB_DINO_MultiTask_A5_Video.load_from_checkpoint(ckpFile)
                if KEY_FRAME:
                    hp_fiel = pathCkptFile.parent / "hparams.yaml"
                    Model = M.RARP_NVB_Multi_MOD.load_from_checkpoint(ckpFile, map_location=device, hparams_file=hp_fiel)
                # NOTE(review): if KEY_FRAME were ever False, Model would be
                # undefined (NameError) on the first iteration — confirm intended.
                temp = Calc_Eval_table(
                    Model,
                    test_loader,
                    True,
                    ckpFile.name
                )
                rows += temp
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])
            #df.style.highlight_max(color="red", axis=0)
            # Write (or replace) one Excel sheet per fold/version combination.
            output_file = Path(f"./{LogFileName}/output.xlsx")
            if not output_file.exists():
                df.to_excel(output_file, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            else:
                with pd.ExcelWriter(output_file, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
                    df.to_excel(writer, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
            print("[END] File saved ... ")