import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as PL
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models_video as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse
# Allow TF32 tensor cores for float32 matmuls and opt into the faster
# "high" matmul precision mode; force deterministic cuDNN kernels so runs
# are reproducible (also re-asserted inside setup_seed).
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.deterministic = True
def setup_seed(seed):
    """Seed every RNG used by this script (torch, CUDA, numpy, Lightning).

    `seed_everything` already seeds python/numpy/torch itself, so the
    explicit calls before it are redundant but harmless belt-and-braces.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # workers=True also seeds DataLoader worker processes.
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
def rolling_mean_std(a, w):
    """Sliding-window mean and std (window size w) along axis 0 of `a`.

    Implemented with prefix sums so the whole computation is O(T) rather
    than O(T*w).  Returns (mean, std), each of shape [T - w + 1, F].
    """
    def _window_sums(values):
        # Prefix sums with a leading zero row; the difference of entries
        # w apart gives every length-w window sum in one vectorized step.
        prefix = np.cumsum(values, axis=0)
        prefix = np.concatenate([np.zeros_like(prefix[:1]), prefix], axis=0)
        return prefix[w:] - prefix[:-w]

    mean = _window_sums(a) / float(w)
    mean_of_squares = _window_sums(a**2) / float(w)
    variance = mean_of_squares - mean**2
    # Floor at a tiny positive value so rounding error cannot make sqrt NaN.
    return mean, np.sqrt(np.maximum(variance, 1e-12))
def plot_tensor_analysis(x, fps=30, win=None, out_prefix="tensor_analysis"):
    """
    Visualize a tensor of shape [T, F] with:
    1) Time series per feature (raw + rolling mean ± std)
    2) Heatmap overview (per-feature normalized to [0,1])
    3) Distribution boxplots per feature
    Figures are written under ./output (created if missing).
    Args:
        x (torch.Tensor): Input tensor of shape [T, F].
        fps (int): Frames per second (for x-axis in seconds).
        win (int or None): Rolling window size in frames. Default = fps.
        out_prefix (str): Prefix for saved file names.
    """
    # --- check input ---
    if not torch.is_tensor(x):
        raise ValueError("x must be a torch.Tensor")
    if x.ndim != 2:
        raise ValueError("x must have shape [T, F]")
    T, F = x.shape
    time_sec = np.arange(T) / float(fps)
    arr = x.detach().cpu().numpy()
    # BUGFIX: savefig raised FileNotFoundError when ./output did not exist.
    os.makedirs("output", exist_ok=True)
    # --- rolling mean/std ---
    if win is None:
        win = max(3, fps)  # default = ~1 second
    half = win // 2
    roll_mean, roll_std = rolling_mean_std(arr, win)
    # Center each window's statistics on the window's middle frame.
    roll_t = time_sec[half:half + len(roll_mean)]
    # ---------- 1) Time series ----------
    fig_ts, axes = plt.subplots(F, 1, figsize=(10, 2.5 * F), sharex=True)
    if F == 1:
        axes = [axes]
    for f in range(F):
        ax = axes[f]
        ax.plot(time_sec, arr[:, f], alpha=0.35, linewidth=1.0, label=f'Feature {f}')
        ax.plot(roll_t, roll_mean[:, f], linewidth=2.0, label=f'Rolling mean (w={win})')
        ax.fill_between(roll_t,
                        roll_mean[:, f] - roll_std[:, f],
                        roll_mean[:, f] + roll_std[:, f],
                        alpha=0.2, label='±1 std (rolling)')
        ax.set_ylabel(f'Feature {f}')
        ax.grid(True, linestyle='--', alpha=0.3)
    axes[-1].set_xlabel('Time (s)')
    axes[0].legend(loc='upper right')
    fig_ts.suptitle('Per-feature time series with rolling mean ± std', y=1.02)
    fig_ts.tight_layout()
    fig_ts.savefig(f"output/{out_prefix}_time_series.png", dpi=200)
    plt.close(fig_ts)  # BUGFIX: figures were never closed (leaked per call)
    # ---------- 2) Heatmap ----------
    fig_hm, ax = plt.subplots(figsize=(10, 2.8))
    # Normalize each feature to [0,1]; epsilon guards constant features.
    arr_min = arr.min(axis=0, keepdims=True)
    arr_max = arr.max(axis=0, keepdims=True)
    arr_norm = (arr - arr_min) / (arr_max - arr_min + 1e-12)
    im = ax.imshow(arr_norm.T, aspect='auto', interpolation='nearest',
                   extent=[time_sec[0], time_sec[-1], F - 0.5, -0.5])
    ax.set_yticks(np.arange(F))
    ax.set_yticklabels([f'Feat {f}' for f in range(F)])
    ax.set_xlabel('Time (s)')
    ax.set_title('Heatmap (per-feature normalized)')
    fig_hm.colorbar(im, ax=ax, fraction=0.025, pad=0.02)
    fig_hm.tight_layout()
    fig_hm.savefig(f"output/{out_prefix}_heatmap.png", dpi=200)
    plt.close(fig_hm)
    # ---------- 3) Boxplots ----------
    fig_box, ax = plt.subplots(figsize=(7, 3.5))
    ax.boxplot([arr[:, f] for f in range(F)], showmeans=True)
    ax.set_xticklabels([f'Feat {f}' for f in range(F)])
    ax.set_ylabel('Value')
    ax.set_title('Distribution across time (boxplot per feature)')
    ax.grid(True, axis='y', linestyle='--', alpha=0.3)
    fig_box.tight_layout()
    fig_box.savefig(f"output/{out_prefix}_boxplots.png", dpi=200)
    plt.close(fig_box)
    # BUGFIX: report the actual paths written (they include output/).
    print(f"Saved: output/{out_prefix}_time_series.png, "
          f"output/{out_prefix}_heatmap.png, output/{out_prefix}_boxplots.png")
def Calc_Eval_table(
    TrainModel,
    TestDataLoadre: DataLoader,
    Youden=False,
    modelName="",
):
    """Evaluate a binary classifier and return pre-formatted metric rows.

    Runs `TrainModel` over `TestDataLoadre`, collecting sigmoid
    probabilities, then computes Accuracy / Precision / Recall / F1 /
    AUROC / Specificity at the default 0.5 threshold.  When `Youden` is
    True, two extra rows are appended: one at the ROC-optimal Youden
    threshold (argmax of tpr - fpr) and one using the J statistic value
    measured at 0.5 as the threshold itself.

    NOTE(review): relies on the module-level `device` created under
    __main__ — this function is only callable after that setup has run.

    Returns:
        list[list[str]]: rows of [threshold, acc, precision, recall, f1,
        auroc, specificity, checkpoint-name], each value already
        formatted to 4 decimal places.
    """
    TrainModel.to(device)
    TrainModel.eval()
    Predictions = []
    Labels = []
    with torch.no_grad():
        # Loader is expected to yield (data, label, mask, key_frame) tuples.
        for data, label, mask, key_frame in tqdm(iter(TestDataLoadre)):
            data = data.to(device, dtype=torch.float32)
            key_frame = key_frame.to(device, dtype=torch.float32)
            mask = mask.to(device)
            label = label.to(device)
            #pred, *_ = TrainModel(data)
            pred, _ = TrainModel(data, key_frame, mask)
            pred = pred.flatten()
            # Store probabilities, not logits, so fixed thresholds are meaningful.
            Predictions.append(torch.sigmoid(pred))
            Labels.append(label)
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels).int()
    #print(Predictions, Labels)
    acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
    # Row 0: all metrics at the default 0.5 decision threshold.
    table = [
        ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
    ]
    if Youden:
        for i in range(2):
            # ROC is recomputed each pass; it does not depend on i.
            aucCurve = torchmetrics.ROC("binary").to(device)
            fpr, tpr, thhols = aucCurve(Predictions, Labels)
            index = torch.argmax(tpr - fpr)
            # NOTE(review): Youden's J evaluated at the 0.5 operating
            # point is reused directly as a probability threshold below —
            # confirm this is the intended heuristic.
            th2 = (recall + specificty - 1).item()
            th2 = 0.5 if th2 <= 0 else th2
            # Pass 0: ROC-derived threshold; pass 1: the J-statistic value.
            th1 = thhols[index].item() if i == 0 else th2
            accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
            precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
            recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
            specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
            f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
            #cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
            #cm2.update(Predictions, Labels)
            #_, ax = cm2.plot()
            #ax.set_title(f"NVB Classifier (th={th1:.4f})")
            # AUROC is threshold-free, so the 0.5-row value is reused here.
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName])
    return table
def ensure_list(x):
    """Convert a tensor/list/tuple (or scalar) to a plain Python list.

    Args:
        x: torch.Tensor, list, tuple, or any scalar-like object.

    Returns:
        list: tensor contents as (possibly nested) Python values, the
        sequence converted to a list, or ``[x]`` for anything else.
    """
    if torch.is_tensor(x):
        out = x.detach().cpu().tolist()
        # BUGFIX: a 0-d tensor's .tolist() returns a bare scalar, not a
        # list — wrap it so callers can always index/iterate the result.
        return out if isinstance(out, list) else [out]
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x]
def compute_uniform_starts(T: int, L: int, W: int):
    """Return W window-start indices spread uniformly over a length-T sequence.

    Windows have length L; the first starts at 0, the last at T - L, and
    the rest are evenly spaced (rounded to the nearest integer).

    Args:
        T: total number of frames in the sequence.
        L: window length in frames.
        W: number of windows to place.

    Returns:
        list[int]: W start indices, each clamped to [0, max(T - L, 0)].
    """
    stride = (T - L) / max(W - 1, 1)  # W == 1 -> single window at start 0
    starts = []
    for i in range(W):
        s = int(round(i * stride))
        # BUGFIX: also clamp at 0 so T < L can no longer produce negative
        # start indices (min(s, T - L) alone returned T - L < 0).
        s = max(0, min(s, T - L))
        starts.append(s)
    return starts
def extract_and_cluster_windows(
    test_loader,
    encoder,
    device,
    out_dir="./cluster_out",
    n_clusters=4,
    pca_dim=128,
    pca_vis_dim=2,
    random_state=505,
    # fallback params if win_start not in info
    T_total=1200,  # 20 min @ 1 fps
    L_win=None,  # will infer from data if None
):
    """Embed every video window with `encoder`, cluster the embeddings,
    and save a CSV table plus diagnostic plots to `out_dir`.

    For each batch, the middle frame of every window is encoded, the
    L2-normalized features are pooled over the whole loader, reduced with
    PCA, and clustered with KMeans.  Outputs: windows_clusters.csv,
    pca2d_clusters.png, cluster_timelines.png.

    NOTE(review): `PCA`, `KMeans` and `silhouette_score` must already be
    module-level names — the "cluster" phase under __main__ imports them
    from sklearn before calling this function.

    Returns:
        (pd.DataFrame, float): per-window table (case_id, win_idx,
        win_start_sec/min, label_nvb, cluster) and the silhouette score.
    """
    os.makedirs(out_dir, exist_ok=True)
    encoder = encoder.to(device).eval()
    features = []
    case_ids = []
    win_idx = []
    win_start = []
    y_video = []
    with torch.no_grad():
        for winds, label, _, info in test_loader:
            # winds: [B, N, L, C, H, W]
            B, N, L, C, H, W = winds.shape
            if L_win is None:
                L_win = L
            # middle frame per window
            mid = winds[:, :, L // 2]  # [B, N, C, H, W]
            mid = mid.reshape(B * N, C, H, W)  # [B*N, C, H, W]
            f = encoder(mid.to(device, non_blocking=True))  # [B*N, D]
            # Unit-length features so Euclidean distance tracks cosine distance.
            f = torch.nn.functional.normalize(f, dim=1).cpu()
            features.append(f)
            # labels: video-level (B,) or (B,1). Repeat per window
            if torch.is_tensor(label):
                label_b = label.detach().cpu().view(B).tolist()
            else:
                label_b = list(label)
            # case_ids: list length B (info is a mapping — it supports .get)
            cids = info.get("case_id", info.get("case", None))
            cids = ensure_list(cids)
            # window indices
            # Preferred: info["win_idx"] is [B, N]
            if "win_idx" in info:
                widx = info["win_idx"]
                if torch.is_tensor(widx):
                    widx = widx.detach().cpu().view(B, N).tolist()
                # flatten by batch item
                for b in range(B):
                    case_ids.extend([cids[b]] * N)
                    win_idx.extend(widx[b])
                    y_video.extend([label_b[b]] * N)
            else:
                # fallback: use 0..N-1
                for b in range(B):
                    case_ids.extend([cids[b]] * N)
                    win_idx.extend(list(range(N)))
                    y_video.extend([label_b[b]] * N)
            # window start seconds
            # Preferred: info["win_start"] is [B, N]
            if "win_start" in info:
                ws = info["win_start"]
                if torch.is_tensor(ws):
                    ws = ws.detach().cpu().view(B, N).tolist()
                for b in range(B):
                    win_start.extend(ws[b])
            else:
                # fallback: compute from uniform starts, same for every case
                starts = compute_uniform_starts(T=T_total, L=L_win, W=N)
                # repeat for B cases
                win_start.extend(starts * B)
    X = torch.cat(features, dim=0).numpy()  # [TotalWindows, D]
    df = pd.DataFrame({
        "case_id": case_ids,
        "win_idx": win_idx,
        "win_start_sec": win_start,
        "label_nvb": y_video
    })
    # builtin round() delegates to Series.__round__ (element-wise, 4 dp)
    df["win_start_min"] = round(df["win_start_sec"] / 60.0, 4)
    # PCA for clustering
    X_pca = PCA(n_components=min(pca_dim, X.shape[1]), random_state=random_state).fit_transform(X)
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto")
    df["cluster"] = kmeans.fit_predict(X_pca)
    sil = silhouette_score(X_pca, df["cluster"].values)
    print(f"Silhouette (PCA-{min(pca_dim, X.shape[1])}, K={n_clusters}): {sil:.6f}")
    # Save table
    csv_path = os.path.join(out_dir, "windows_clusters.csv")
    df.to_csv(csv_path, index=False)
    print(f"Saved: {csv_path}")
    # Plot 1: PCA-2D scatter colored by cluster
    X_vis = PCA(n_components=pca_vis_dim, random_state=random_state).fit_transform(X)
    plt.figure(figsize=(7, 6))
    plt.scatter(X_vis[:, 0], X_vis[:, 1], c=df["cluster"].values, s=6, alpha=0.5)
    plt.title(f"Window Embeddings (PCA-{pca_vis_dim}) — KMeans K={n_clusters}")
    plt.xlabel("PC1"); plt.ylabel("PC2")
    plt.grid(True, alpha=0.2)
    pca_path = os.path.join(out_dir, "pca2d_clusters.png")
    plt.savefig(pca_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved: {pca_path}")
    # Plot 2: cluster timeline for a few cases
    # (Pick some cases; you can also pass a list explicitly)
    unique_cases = df["case_id"].unique().tolist()
    show_cases = unique_cases[:min(6, len(unique_cases))]
    _, axes = plt.subplots(len(show_cases), 1, figsize=(10, 2.0 * len(show_cases)), sharex=True)
    if len(show_cases) == 1:
        axes = [axes]
    for ax, cid in zip(axes, show_cases):
        sub = df[df["case_id"] == cid].sort_values("win_start_sec")
        ax.scatter(sub["win_start_min"], sub["cluster"], s=20)
        ax.set_ylabel("Cluster")
        ax.set_title(f"Case {cid} — cluster vs time")
        ax.grid(True)
    axes[-1].set_xlabel("Time (minutes)")
    timeline_path = os.path.join(out_dir, "cluster_timelines.png")
    plt.tight_layout()
    plt.savefig(timeline_path, dpi=200, bbox_inches="tight")
    plt.close()
    print(f"Saved: {timeline_path}")
    return df, sil
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'", required=True)
parser.add_argument("--Fold", type=int, default=0)
parser.add_argument("-lv","--Log_version", type=int, default=None)
parser.add_argument("--Workers", type=int, default=0)
parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
parser.add_argument("--CNN_name", type=str, default=None, )
parser.add_argument("--Temp_Head", type=str, default=None, )
parser.add_argument("-me", "--maxEpochs", type=int, default=None)
parser.add_argument("-b", "--Batch_size", type=int, default=8)
parser.add_argument("--GPU", type=int, default=0)
parser.add_argument("--pre_train", type=int, default=0)
parser.add_argument("-k", "--k_windows", type=int, default=1)
parser.add_argument("--Window_Size", type=int, default=64)
parser.add_argument("--Num_Window", type=int, default=8)
parser.add_argument("--cached_features", type=bool, default=False)
args = parser.parse_args()
setup_seed(2023)
device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")
df = pd.read_csv("../dataset/Dataset_RARP_video/dataset_videos_folds.csv")
FOLD = args.Fold
WORKERS = args.Workers
BATCH_SIZE = args.Batch_size
MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
PRE_TRAIN = args.pre_train != 0
K_WIN = args.k_windows
KEY_FRAME = True if args.Phase != "cluster" else False
WIN_LENGTH = args.Window_Size
NUM_WIN = args.Num_Window
CACHED_FEATURES = args.cached_features
Mean = [0.485, 0.456, 0.406]
Std = [0.229, 0.224, 0.225]
print(f"Fold_{FOLD}")
ckpt_paths = [
Path("./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt"),
Path("./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt"),
Path("./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt"),
Path("./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt"),
Path("./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt"),
]
train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
traintransformT2 = torch.nn.Sequential(
transforms.CenterCrop(300),
transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=0),
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
transforms.GaussianBlur(kernel_size=3),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
).to(device)
traintransform_frame = torch.nn.Sequential(
transforms.RandomApply([
transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
transforms.RandomErasing(1.0, value="random")
], 0.3) #small noise
).to(device)
testVal_transform = torch.nn.Sequential(
transforms.CenterCrop(300),
transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
).to(device)
key_frame_transform = torch.nn.Sequential(
transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
).to(device)
train_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
train_set,
train_mode=True,
num_windows=NUM_WIN,
window_length=WIN_LENGTH,
transform=traintransformT2,
transform_frame=traintransform_frame,
key_frames=KEY_FRAME,
key_frame_transform=key_frame_transform,
load_key_frame_cache=CACHED_FEATURES,
Fold_index=FOLD
)
val_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
val_set,
train_mode=False,
num_windows=NUM_WIN,
window_length=WIN_LENGTH,
transform=testVal_transform,
key_frames=KEY_FRAME,
key_frame_transform=key_frame_transform,
load_key_frame_cache=CACHED_FEATURES,
Fold_index=FOLD
)
test_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
test_set,
train_mode=False,
num_windows=NUM_WIN,
window_length=WIN_LENGTH,
transform=testVal_transform,
key_frames=KEY_FRAME,
key_frame_transform=key_frame_transform,
load_key_frame_cache=CACHED_FEATURES,
Fold_index=FOLD
)
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
val_loader = DataLoader(
val_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
test_loader = DataLoader(
test_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
LogFileName = f"{args.Log_Name}"
checkPtCallback = [
callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
#callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
]
trainer = PL.Trainer(
precision="32-true" if args.CNN_name == "gsvit" else "16-mixed",
deterministic=True,
accelerator="gpu",
devices=[args.GPU],
#devices=[0, 1], strategy="ddp",
logger=TensorBoardLogger(save_dir=f"./{LogFileName}") if args.Phase == "train" else CSVLogger(save_dir=f"./{LogFileName}/Test", version=args.Log_version),
log_every_n_steps=5,
callbacks=checkPtCallback,
max_epochs=MAX_EPOCHS
)
match(args.Phase):
case "cache_key_frame":
from Models import RARP_NVB_DINO_MultiTask
print (f"Load Export model for the FOLD #{FOLD}")
Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(ckpt_paths[FOLD], map_location=device)
Hybrid_TS.eval()
namelist = ["TRAIN", "VAL", "TEST"]
for _i, _s in enumerate([train_set, val_set, test_set]):
print (f"[{namelist[_i]} Set] of FOLD # {FOLD}")
key_frame_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset(
_s,
key_frames=True,
key_frame_transform=key_frame_transform,
key_frame_only=True,
)
key_frameloader = DataLoader(
key_frame_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
pin_memory=True,
num_workers=WORKERS,
persistent_workers=WORKERS>0
)
print (f"[SAVE] caching Image features and Soft lables from Expert Model in FOLD #{FOLD}")
with torch.no_grad():
for img, case_id in tqdm(iter(key_frameloader)):
B = img.shape[0]
img = img.to(device, dtype=torch.float)
Soft_label, _, _ = Hybrid_TS(img)
Img_features = torch.cat((Hybrid_TS.last_conv_output_S, Hybrid_TS.last_conv_output_T), dim=1)
Img_features = torch.nn.functional.adaptive_avg_pool2d(Img_features, (1,1)).flatten(1)
for i in range(B):
parent_path = next((r for r in _s if r.get("case") == case_id[i]), None)
parent_path = Path(parent_path["path"]).resolve().parent
parent_path = parent_path / "chache"
parent_path.mkdir(exist_ok=True)
np.savez((parent_path / f"F{FOLD}_{case_id[i]}.npz"), soft_label=Soft_label[i].cpu().numpy(), img_features=Img_features[i].cpu().numpy())
print (f"[DONE] FOLD #{FOLD}")
case "train":
Model = M.RARP_NVB_Multi_MOD_MIL(
num_classes=1,
temporal=args.Temp_Head,
cnn_name=args.CNN_name,
dropout=0.3,
lr=1e-4, #3e-4,
weight_decay=0.1, #0.05
epochs=MAX_EPOCHS,
pre_train=PRE_TRAIN,
Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve()) if not CACHED_FEATURES else None,
FOLD=FOLD,
attn_entropy_target=0.4,
attn_reg_warmup_epochs=5,
attn_reg_weight=0.02
)
print(f"Model Used: {type(Model).__name__}")
print("Train Phase")
trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
case "eval_all":
print("Evaluation Phase")
rows = []
pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
for ckpFile in sorted(pathCkptFile.glob("*.ckpt")):
print(ckpFile.name)
#trainer.test(Model, dataloaders=test_loader, ckpt_path=ckpFile)
#Model = M.RARP_NVB_DINO_MultiTask_A5_Video.load_from_checkpoint(ckpFile)
hp_fiel = pathCkptFile.parent / "hparams.yaml"
Model = M.RARP_NVB_Multi_MOD_MIL_TESTMode.load_from_checkpoint(ckpFile, map_location=device, hparams_file=hp_fiel)
trainer.test(Model, dataloaders=test_loader)
#temp = Calc_Eval_table(Model, test_loader, True, ckpFile.name)
temp = Model._test_results
rows += temp
df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])
#df.style.highlight_max(color="red", axis=0)
output_file = Path(f"./{LogFileName}/output.xlsx")
if not output_file.exists():
df.to_excel(output_file, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
else:
with pd.ExcelWriter(output_file, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
df.to_excel(writer, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}")
print("[END] File saved ... ")
case "cluster":
os.environ["OMP_NUM_THREADS"] = "2"
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
Model = M.RARP_NVB_Multi_MOD_MIL(
num_classes=1,
temporal=args.Temp_Head,
cnn_name=args.CNN_name,
dropout=0.3,
lr=1e-4, #3e-4,
weight_decay=0.1, #0.05
epochs=MAX_EPOCHS,
pre_train=PRE_TRAIN,
Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve()) if not CACHED_FEATURES else None,
FOLD=FOLD,
attn_entropy_target=0.4,
attn_reg_warmup_epochs=5,
attn_reg_weight=0.02
)
Model = Model.to(device)
Model.eval()
encoder = Model.cnn
df, sil = extract_and_cluster_windows(test_loader, encoder, device, random_state=505)