import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models_video as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
# Global numeric/determinism policy:
# - allow TF32 tensor cores for matmul and relax float32 matmul precision
#   ("high") for throughput on Ampere+ GPUs,
# - force deterministic cuDNN kernels for reproducibility (re-asserted in
#   setup_seed below).
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.deterministic = True
def setup_seed(seed):
    """Seed torch (CPU + all GPUs), NumPy and Lightning for reproducibility.

    ``seed_everything(workers=True)`` also seeds Python's ``random`` and
    configures dataloader-worker seeding; the explicit torch/NumPy calls
    before it are redundant but harmless, and ``cudnn.deterministic`` is
    re-asserted afterwards in case Lightning changed it.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
def Calc_Eval_table(
    TrainModel,
    TestDataLoadre: DataLoader,
    Youden=False,
    modelName="",
):
    """Evaluate a model on a test loader and build a formatted metrics table.

    Runs ``TrainModel`` over ``TestDataLoadre`` (batches of
    ``(data, label, mask, key_frame)``), collects sigmoid probabilities and
    computes binary metrics at the default 0.5 threshold. When ``Youden`` is
    True, two extra rows are appended: one at the ROC threshold maximizing
    Youden's J (TPR - FPR), and one at a fallback threshold derived from
    ``recall + specificity - 1`` measured at 0.5 (clamped to 0.5 if <= 0).

    Parameters
    ----------
    TrainModel : torch.nn.Module
        Called as ``TrainModel(data, key_frame, mask)``; the first returned
        value is treated as the logits.
    TestDataLoadre : DataLoader
        Test-set loader (name kept for caller compatibility).
    Youden : bool
        Append the Youden-threshold rows when True.
    modelName : str
        Tag stored in the last column of the appended rows.

    Returns
    -------
    list[list[str]]
        Rows of [threshold, acc, precision, recall, f1, auroc, specificity,
        checkpoint-name], all pre-formatted to 4 decimals.

    Notes
    -----
    Relies on the module-level ``device`` bound in the ``__main__`` section.
    """
    TrainModel.to(device)
    TrainModel.eval()
    Predictions = []
    Labels = []
    with torch.no_grad():
        # Passing the loader directly (no iter()) lets tqdm display a total.
        for data, label, mask, key_frame in tqdm(TestDataLoadre):
            data = data.to(device, dtype=torch.float32)
            key_frame = key_frame.to(device, dtype=torch.float32)
            mask = mask.to(device)
            label = label.to(device)
            pred, _ = TrainModel(data, key_frame, mask)
            pred = pred.flatten()
            Predictions.append(torch.sigmoid(pred))
            Labels.append(label)
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels).int()
    acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
    ]
    if Youden:
        # The ROC curve and both candidate thresholds do not depend on the
        # loop variable — compute them once instead of per iteration.
        fpr, tpr, thhols = torchmetrics.ROC("binary").to(device)(Predictions, Labels)
        index = torch.argmax(tpr - fpr)  # Youden's J statistic
        th2 = (recall + specificty - 1).item()  # J measured at th=0.5
        th2 = 0.5 if th2 <= 0 else th2
        for th1 in (thhols[index].item(), th2):
            accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
            precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
            recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
            specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
            f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
            # AUROC is threshold-free, so the 0.5-row value is reused.
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
    return table
def ensure_list(x):
    """Coerce a tensor, list, tuple or scalar into a plain Python list."""
    if torch.is_tensor(x):
        return x.detach().cpu().tolist()
    return list(x) if isinstance(x, (list, tuple)) else [x]
def compute_uniform_starts(T: int, L: int, W: int):
    """Return W start offsets spread uniformly over a length-T video.

    Each start is rounded to the nearest integer and clamped to ``T - L`` so
    a window of length L never overruns the video. With ``W == 1`` the
    single start is 0.
    """
    step = (T - L) / max(W - 1, 1)
    return [min(int(round(k * step)), T - L) for k in range(W)]
def get_topk_nearest_to_centroids(df: pd.DataFrame, X_space: np.ndarray, kmeans, topk: int = 20):
    """
    For every KMeans cluster, keep the `topk` rows of `df` whose embeddings
    lie closest (Euclidean) to that cluster's centroid.

    df: rows correspond 1:1 with X_space (same ordering) and carry a
        "cluster" column with the assigned cluster id
    X_space: embeddings used for KMeans
    kmeans: fitted sklearn KMeans (only `cluster_centers_` is accessed)
    Returns: df_top with extra columns dist_to_centroid and rank_in_cluster,
             sorted by (cluster, rank_in_cluster); empty clusters are skipped
    """
    assert len(df) == X_space.shape[0], "df and X_space must have same #rows"
    centroids = kmeans.cluster_centers_  # [K, D]
    assigned = df["cluster"].to_numpy()
    selected = []
    for cluster_id, center in enumerate(centroids):
        members = np.where(assigned == cluster_id)[0]
        if members.size == 0:
            continue
        dists = np.linalg.norm(X_space[members] - center, axis=1)
        nearest = np.argsort(dists)[:topk]
        block = df.iloc[members[nearest]].copy()
        block["dist_to_centroid"] = dists[nearest]
        block["rank_in_cluster"] = np.arange(1, len(block) + 1)
        selected.append(block)
    return pd.concat(selected, axis=0).sort_values(["cluster", "rank_in_cluster"])
def save_window_montage(clip_uint8, out_path, title="", n_frames=6):
    """Save a 1 x n_frames montage of frames sampled evenly from a clip.

    clip_uint8: [L, H, W, 3] uint8 array (preferred); individual frames in
        CHW layout are tolerated and transposed to HWC before plotting.
    out_path: destination image path, passed straight to ``plt.savefig``.
    title: figure-level suptitle.
    n_frames: number of evenly spaced frames to display.
    """
    # NOTE(review): local L shadows the module-level `lightning as L` alias;
    # harmless inside this function scope.
    L = clip_uint8.shape[0]
    # evenly spaced indices
    idx = np.linspace(0, L - 1, n_frames).round().astype(int)
    fig, axes = plt.subplots(1, n_frames, figsize=(2.2*n_frames, 2.2))
    if n_frames == 1:
        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid
        axes = [axes]
    for ax, t in zip(axes, idx):
        img = clip_uint8[t]
        # safety: if CHW, convert to HWC
        if img.ndim == 3 and img.shape[0] == 3 and img.shape[-1] != 3:
            img = np.transpose(img, (1, 2, 0))
        ax.imshow(img)
        ax.set_title(f"t={t}")
        ax.axis("off")
    fig.suptitle(title, fontsize=10)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    # close explicitly so repeated calls don't accumulate open figures
    plt.close(fig)
def load_window_rgb(arrays, case_to_vidx, case_id, start_sec, L):
    """Slice an L-frame RGB window out of the cached video for `case_id`.

    arrays: sequence of per-video frame arrays (memmap or ndarray)
    case_to_vidx: maps case_id -> index into `arrays`
    Returns the [L, H, W, 3] slice starting at `start_sec`.
    """
    video = arrays[case_to_vidx[case_id]]
    return video[start_sec:start_sec + L]
def export_cluster_examples(df_top,
                            arrays,
                            case_to_vidx,
                            L_win: int,
                            out_dir = "./cluster_examples",
                            n_frames_per_montage: int = 6):
    """Write one montage PNG per row of `df_top`, grouped into per-cluster
    sub-folders named cluster_KK.

    df_top must provide the columns produced by get_topk_nearest_to_centroids:
    case_id, cluster, win_start_sec, win_idx, dist_to_centroid,
    rank_in_cluster.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for row in tqdm(df_top.itertuples(index=False), desc="Centroids Montage"):
        cid = row.case_id
        k, w = int(row.cluster), int(row.win_idx)
        s = int(row.win_start_sec)
        dist, rank = float(row.dist_to_centroid), int(row.rank_in_cluster)
        cluster_dir = out_dir / f"cluster_{k:02d}"
        cluster_dir.mkdir(exist_ok=True)
        clip = load_window_rgb(arrays, case_to_vidx, cid, s, L_win)
        png_path = cluster_dir / f"rank{rank:02d}_case{cid}_w{w:02d}_s{s:04d}_d{dist:.4f}.png"
        caption = f"cluster={k} rank={rank} case={cid} win={w} start={s}s dist={dist:.4f}"
        save_window_montage(clip, str(png_path.resolve()), title=caption, n_frames=n_frames_per_montage)
    print(f"Saved cluster example montages to: {out_dir}")
def extract_and_cluster_windows(
    test_loader,
    encoder,
    device,
    out_dir="./cluster_out",
    n_clusters=4,
    pca_dim=128,
    pca_vis_dim=2,
    random_state=505,
    # fallback params if win_start not in info
    T_total=1200, # 20 min @ 1 fps
    L_win=None, # will infer from data if None
    vid_array = None,
    dict_vid_array = None,
    num_samples:int = 1,
    Hybrid_TS=None
):
    """Embed every test window, cluster the embeddings, save diagnostics.

    For each batch of windows [B, N, L, C, H, W], one frame per window (or
    `num_samples` evenly spaced frames) is embedded either by `encoder` or —
    when `Hybrid_TS` is given — by pooling that model's cached spatial and
    temporal conv feature maps. Features are L2-normalized, PCA-reduced to
    `pca_dim`, clustered with KMeans, and the results are written under
    `out_dir` + "K{n_clusters}": two CSV tables, centroid example montages,
    a 2-D PCA scatter and per-case cluster timelines.

    Returns (df, sil): a per-window DataFrame (case_id, win_idx,
    win_start_sec, label_nvb, win_start_min, cluster) and the silhouette
    score of the clustering.

    NOTE(review): `info` is assumed to be a dict-like collate that may carry
    "case_id"/"case" and optional [B, N] "win_idx" / "win_start" entries —
    confirm against the Loaders dataset. `Hybrid_TS` is assumed to populate
    `last_conv_output_S` / `last_conv_output_T` as a forward side effect.
    """
    out_dir = Path(out_dir+f"K{n_clusters}")
    out_dir.mkdir(parents=True, exist_ok=True)
    features = []
    case_ids = []
    win_idx = []
    win_start = []
    y_video = []
    with torch.no_grad():
        for winds, label, _, info in tqdm(test_loader, desc="Windows Analysis"):
            # winds: [B, N, L, C, H, W]
            B, N, L, C, H, W = winds.shape
            if L_win is None:
                # infer the window length from the first batch
                L_win = L
            # middle frame per window
            if num_samples == 1:
                mid = winds[:, :, L // 2] # [B, N, C, H, W]
                mid = mid.reshape(B*N, C, H, W) # [B*N, C, H, W]
            else:
                # sample `num_samples` evenly spaced frames per window
                idx = np.linspace(0, L - 1, num_samples).round().astype(int)
                mid = winds[:, :, idx]
                K = mid.shape[2]
                mid = mid.reshape(B*N*K, C, H, W) # [B*N*K, C, H, W]
            if Hybrid_TS is None:
                f = encoder(mid.to(device, non_blocking=True)) # [B*N, D] or [B*N*K, D]
            else:
                # forward only for its side effect: it caches the last conv
                # outputs, which we pool into a feature vector below
                _ = Hybrid_TS(mid.to(device, non_blocking=True))
                _fs = torch.cat((Hybrid_TS.last_conv_output_S, Hybrid_TS.last_conv_output_T), dim=1)
                f = torch.nn.functional.adaptive_avg_pool2d(_fs, (1,1)).flatten(1)
            # L2-normalize so KMeans distances approximate cosine distance
            f = torch.nn.functional.normalize(f, dim=1)
            if num_samples > 1:
                f = f.view(B, N, K, -1).mean(dim=2) #mean the K dim
                f = f.reshape(B*N, -1)
            features.append(f.cpu())
            # labels: video-level (B,) or (B,1). Repeat per window
            if torch.is_tensor(label):
                label_b = label.detach().cpu().view(B).tolist()
            else:
                label_b = list(label)
            # case_ids: list length B
            cids = info.get("case_id", info.get("case", None))
            cids = ensure_list(cids)
            # window indices
            # Preferred: info["win_idx"] is [B, N]
            if "win_idx" in info:
                widx = info["win_idx"]
                if torch.is_tensor(widx):
                    widx = widx.detach().cpu().view(B, N).tolist()
                # flatten by batch item
                for b in range(B):
                    case_ids.extend([cids[b]] * N)
                    win_idx.extend(widx[b])
                    y_video.extend([label_b[b]] * N)
            else:
                # fallback: use 0..N-1
                for b in range(B):
                    case_ids.extend([cids[b]] * N)
                    win_idx.extend(list(range(N)))
                    y_video.extend([label_b[b]] * N)
            # window start seconds
            # Preferred: info["win_start"] is [B, N]
            if "win_start" in info:
                ws = info["win_start"]
                if torch.is_tensor(ws):
                    ws = ws.detach().cpu().view(B, N).tolist()
                for b in range(B):
                    win_start.extend(ws[b])
            else:
                # fallback: compute from uniform starts, same for every case
                starts = compute_uniform_starts(T=T_total, L=L_win, W=N)
                # repeat for B cases
                win_start.extend(starts * B)
    X = torch.cat(features, dim=0).numpy() # [TotalWindows, D]
    df = pd.DataFrame({
        "case_id": case_ids,
        "win_idx": win_idx,
        "win_start_sec": win_start,
        "label_nvb": y_video
    })
    # Series.__round__ rounds element-wise to 4 decimals
    df["win_start_min"] = round (df["win_start_sec"] / 60.0, 4)
    # PCA for clustering
    X_pca = PCA(n_components=pca_dim, random_state=random_state).fit_transform(X)
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto")
    df["cluster"] = kmeans.fit_predict(X_pca)
    sil = silhouette_score(X_pca, df["cluster"].values)
    print(f"Silhouette (PCA-{pca_dim}, K={n_clusters}): {sil:.6f}")
    df_top = get_topk_nearest_to_centroids(df, X_pca, kmeans, topk=5)
    # Save table
    csv_path = out_dir / "windows_clusters.csv"
    df.to_csv(csv_path, index=False)
    print(f"Saved: {str(csv_path)}")
    csv_path = out_dir / "cluster_topk_centroid_nearest.csv"
    df_top.to_csv(csv_path, index=False)
    print(f"Saved: {str(csv_path)}")
    #Montage of centroids
    export_cluster_examples(df_top, arrays=vid_array, case_to_vidx=dict_vid_array, L_win=L_win, out_dir=(out_dir/"centroids"), n_frames_per_montage=6)
    # Plot 1: PCA-2D scatter colored by cluster
    # (fresh PCA to pca_vis_dim on the raw features, independent of X_pca)
    X_vis = PCA(n_components=pca_vis_dim, random_state=random_state).fit_transform(X)
    plt.figure(figsize=(7, 6))
    plt.scatter(X_vis[:, 0], X_vis[:, 1], c=df["cluster"].values, s=6, alpha=0.5)
    plt.title(f"Window Embeddings (PCA-{pca_vis_dim}) — KMeans K={n_clusters}")
    plt.xlabel("PC1"); plt.ylabel("PC2")
    plt.grid(True, alpha=0.2)
    pca_path = out_dir / "pca2d_clusters.png"
    plt.savefig(str(pca_path.resolve()), dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved: {str(pca_path)}")
    # Plot 2: cluster timeline for a few cases
    # (Pick some cases; you can also pass a list explicitly)
    unique_cases = df["case_id"].unique().tolist()
    show_cases = unique_cases[:min(6, len(unique_cases))]
    _, axes = plt.subplots(len(show_cases), 1, figsize=(10, 2.0 * len(show_cases)), sharex=True)
    if len(show_cases) == 1:
        # plt.subplots returns a bare Axes for a single row
        axes = [axes]
    for ax, cid in zip(axes, show_cases):
        sub = df[df["case_id"] == cid].sort_values("win_start_sec")
        ax.scatter(sub["win_start_min"], sub["cluster"], s=20)
        ax.set_ylabel("Cluster")
        ax.set_title(f"Case {cid} — cluster vs time")
        ax.grid(True)
    axes[-1].set_xlabel("Time (minutes)")
    timeline_path = out_dir / "cluster_timelines.png"
    plt.tight_layout()
    plt.savefig(str(timeline_path.resolve()), dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved: {str(timeline_path)}")
    return df, sil
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'", required=True)
parser.add_argument("--Fold", type=int, default=0)
parser.add_argument("-lv","--Log_version", type=int, default=None)
parser.add_argument("--Workers", type=int, default=0)
parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
parser.add_argument("--CNN_name", type=str, default=None, )
parser.add_argument("--Temp_Head", type=str, default=None, )
parser.add_argument("-me", "--maxEpochs", type=int, default=None)
parser.add_argument("-b", "--Batch_size", type=int, default=8)
parser.add_argument("--GPU", type=int, default=0)
parser.add_argument("--pre_train", type=int, default=0)
parser.add_argument("-k", "--k_windows", type=int, default=1)
parser.add_argument("--Window_Size", type=int, default=64)
parser.add_argument("--Num_Window", type=int, default=8)
parser.add_argument("--cached_features", type=bool, default=False)
parser.add_argument("--seed", type=int, default=2023)
args = parser.parse_args()
setup_seed(args.seed)
device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")
df = pd.read_csv("./Dataset_RARP_video/dataset_videos_folds.csv")
FOLD = args.Fold
WORKERS = args.Workers
BATCH_SIZE = args.Batch_size
MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
PRE_TRAIN = args.pre_train != 0
K_WIN = args.k_windows
KEY_FRAME = True if args.Phase != "cluster" else False
WIN_LENGTH = args.Window_Size
NUM_WIN = args.Num_Window
CACHED_FEATURES = args.cached_features
NO_NORM_VIDEO = False if args.CNN_name != "hybrid_t-s" else True
Mean = [0.485, 0.456, 0.406]
Std = [0.229, 0.224, 0.225]
print(f"Fold_{FOLD}")
ckpt_paths = [
Path("./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt"),
Path("./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt"),
Path("./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt"),
Path("./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt"),
Path("./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt"),
]
train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
traintransformT2 = torch.nn.Sequential(
transforms.CenterCrop(300),
transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=0),
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
transforms.GaussianBlur(kernel_size=3),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
).to(device)
traintransform_frame = torch.nn.Sequential(
transforms.RandomApply([
transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
transforms.RandomErasing(1.0, value="random")
], 0.3) #small noise
).to(device)
testVal_transform = torch.nn.Sequential(
transforms.CenterCrop(300),
transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
).to(device)
testVal_transform_GSViT = torch.nn.Sequential(
transforms.CenterCrop(300),
transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
).to(device)
key_frame_transform = torch.nn.Sequential(
transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
).to(device)
train_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
train_set,
train_mode=True,
num_windows=NUM_WIN,
window_length=WIN_LENGTH,
transform=traintransformT2,
transform_frame=traintransform_frame,
key_frames=KEY_FRAME,
key_frame_transform=key_frame_transform,
load_key_frame_cache=CACHED_FEATURES,
Fold_index=FOLD
)
val_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
val_set,
train_mode=False,
num_windows=NUM_WIN,
window_length=WIN_LENGTH,
transform=testVal_transform,
key_frames=KEY_FRAME,
key_frame_transform=key_frame_transform,
load_key_frame_cache=CACHED_FEATURES,
Fold_index=FOLD
)
test_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
test_set,
train_mode=False,
num_windows=NUM_WIN,
window_length=WIN_LENGTH,
transform=testVal_transform if args.CNN_name != "hybrid_t-s" else testVal_transform_GSViT,
key_frames=KEY_FRAME,
key_frame_transform=key_frame_transform,
load_key_frame_cache=CACHED_FEATURES,
Fold_index=FOLD,
no_norm_video=NO_NORM_VIDEO
)
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
val_loader = DataLoader(
val_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
test_loader = DataLoader(
test_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
LogFileName = f"{args.Log_Name}"
checkPtCallback = [
callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
#callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
]
trainer = L.Trainer(
precision="32-true" if args.CNN_name == "gsvit" else "16-mixed",
deterministic=True,
accelerator="gpu",
devices=[args.GPU],
#devices=[0, 1], strategy="ddp",
logger=TensorBoardLogger(save_dir=f"./{LogFileName}") if args.Phase == "train" else CSVLogger(save_dir=f"./{LogFileName}/Test", version=args.Log_version),
log_every_n_steps=5,
callbacks=checkPtCallback,
max_epochs=MAX_EPOCHS
)
match(args.Phase):
case "cache_key_frame":
from Models import RARP_NVB_DINO_MultiTask
print (f"Load Export model for the FOLD #{FOLD}")
Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(ckpt_paths[FOLD], map_location=device)
Hybrid_TS.eval()
namelist = ["TRAIN", "VAL", "TEST"]
for _i, _s in enumerate([train_set, val_set, test_set]):
print (f"[{namelist[_i]} Set] of FOLD # {FOLD}")
key_frame_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
_s,
key_frames=True,
key_frame_transform=key_frame_transform,
key_frame_only=True,
)
key_frameloader = DataLoader(
key_frame_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
print (f"[SAVE] caching Image features and Soft lables from Expert Model in FOLD #{FOLD}")
with torch.no_grad():
for img, case_id in tqdm(iter(key_frameloader)):
B = img.shape[0]
img = img.to(device, dtype=torch.float)
Soft_label, _, _ = Hybrid_TS(img)
Img_features = torch.cat((Hybrid_TS.last_conv_output_S, Hybrid_TS.last_conv_output_T), dim=1)
Img_features = torch.nn.functional.adaptive_avg_pool2d(Img_features, (1,1)).flatten(1)
for i in range(B):
parent_path = next((r for r in _s if r.get("case") == case_id[i]), None)
parent_path = Path(parent_path["path"]).resolve().parent
parent_path = parent_path / "chache"
parent_path.mkdir(exist_ok=True)
np.savez((parent_path / f"F{FOLD}_{case_id[i]}.npz"), soft_label=Soft_label[i].cpu().numpy(), img_features=Img_features[i].cpu().numpy())
print (f"[DONE] FOLD #{FOLD}")
case "train":
Model = M.RARP_NVB_Multi_MOD_MIL(
num_classes=1,
temporal=args.Temp_Head,
cnn_name=args.CNN_name,
dropout=0.3,
lr=1e-4, #3e-4,
weight_decay=0.1, #0.05
epochs=MAX_EPOCHS,
pre_train=PRE_TRAIN,
Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve()) if not CACHED_FEATURES else None,
FOLD=FOLD,
attn_entropy_target=0.4,
attn_reg_warmup_epochs=5,
attn_reg_weight=0.02
)
print(f"Model Used: {type(Model).__name__}")
print("Train Phase")
trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
case "eval_all":
print("Evaluation Phase")
rows = []
pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
for ckpFile in sorted(pathCkptFile.glob("*.ckpt")):
print(ckpFile.name)
#trainer.test(Model, dataloaders=test_loader, ckpt_path=ckpFile)
#Model = M.RARP_NVB_DINO_MultiTask_A5_Video.load_from_checkpoint(ckpFile)
hp_fiel = pathCkptFile.parent / "hparams.yaml"
Model = M.RARP_NVB_Multi_MOD_MIL_TESTMode.load_from_checkpoint(ckpFile, map_location=device, hparams_file=hp_fiel)
trainer.test(Model, dataloaders=test_loader)
#temp = Calc_Eval_table(Model, test_loader, True, ckpFile.name)
temp = Model._test_results
rows += temp
df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])
#df.style.highlight_max(color="red", axis=0)
output_file = Path(f"./{LogFileName}/output.xlsx")
if not output_file.exists():
df.to_excel(output_file, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
else:
with pd.ExcelWriter(output_file, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
df.to_excel(writer, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
print("[END] File saved ... ")
case "cluster":
ckpt_paths_MIL = [
Path("./log_XT6/lightning_logs/version_0/checkpoints/RARP-epoch=23.ckpt"),
Path("./log_XT6/lightning_logs/version_1/checkpoints/RARP-epoch=27.ckpt"),
Path("./log_XT6/lightning_logs/version_2/checkpoints/RARP-epoch=20.ckpt"),
Path("./log_XT6/lightning_logs/version_3/checkpoints/RARP-epoch=30.ckpt"),
Path("./log_XT6/lightning_logs/version_4/checkpoints/RARP-epoch=29.ckpt"),
]
hp_file = ckpt_paths_MIL[FOLD].parent.parent / "hparams.yaml"
Model = M.RARP_NVB_Multi_MOD_MIL.load_from_checkpoint(ckpt_paths_MIL[FOLD], map_location=device, hparams_file=hp_file)
Model = Model.to(device)
Model.eval()
encoder = None
df, sil = extract_and_cluster_windows(
test_loader,
encoder,
device,
random_state=0,
out_dir=f"./{args.CNN_name}_cluster_out_F{FOLD}",
vid_array=test_dataset.arrays,
dict_vid_array=test_dataset.case_index,
num_samples=1,
n_clusters=3,
Hybrid_TS=Model.Hybrid_TS
)