import math
from typing import Optional, List
import torch
import torch.nn as nn
import torch.utils.checkpoint as torch_ckp
import torchvision
import torchmetrics
import torchmetrics.classification
import lightning as L
import van
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from Models import RARP_NVB_DINO_MultiTask
from pathlib import Path
class GRUTemporalHead(nn.Module):
    """Bidirectional GRU temporal head over padded per-frame features.

    Takes a padded feature sequence plus a boolean validity mask and returns a
    per-frame temporal feature sequence with the same time length as the input.
    """

    def __init__(self, feat_dim, hidden=256, num_layers=2, bidirectional=True, dropout=0.2):
        super().__init__()
        self.gru = nn.GRU(
            input_size=feat_dim, hidden_size=hidden, num_layers=num_layers,
            batch_first=True,
            # nn.GRU only applies inter-layer dropout when num_layers > 1;
            # passing a nonzero value otherwise triggers a warning.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional,
        )
        self.dropout = nn.Dropout(dropout)
        # Bidirectional GRUs concatenate forward and backward hidden states.
        self.out_dim = hidden * (2 if bidirectional else 1)

    def forward(self, x, mask):
        """
        x: [B, L, F] padded per-frame features.
        mask: [B, L] bool, True = valid frame.
        returns: [B, L, out_dim]
        """
        # pack_padded_sequence requires lengths on CPU; clamp to >= 1 so a
        # fully-padded row does not crash packing (its output is meaningless
        # but downstream masking is expected to ignore it).
        lengths = mask.sum(dim=1).clamp_min(1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.gru(packed)
        # BUG FIX: total_length pins the unpacked sequence to the input length
        # L. Without it the output is [B, max(lengths), out_dim], which breaks
        # downstream code indexing with the original [B, L] mask whenever every
        # sample in the batch is shorter than the padded length.
        seq, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        seq = self.dropout(seq)
        return seq  # [B, L, out_dim]
class TemporalBlock(nn.Module):
    """Dilated 1-D convolutional residual block.

    Two Conv1d -> ReLU -> Dropout stages with "same" padding, plus a skip
    connection that is a 1x1 projection only when the channel count changes.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, dropout=0.2):
        super().__init__()
        # "Same" padding for an odd kernel at the given dilation.
        same_pad = (kernel_size - 1) * dilation // 2

        def stage(channels_in):
            return [
                nn.Conv1d(channels_in, out_ch, kernel_size, padding=same_pad, dilation=dilation),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ]

        self.net = nn.Sequential(*stage(in_ch), *stage(out_ch))
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv1d(in_ch, out_ch, 1)

    def forward(self, x):
        """x: [B, C, L] -> [B, out_ch, L] (residual sum of conv path and skip)."""
        return self.net(x) + self.skip(x)
class TCNTemporalHead(nn.Module):
    """Stack of dilated TemporalBlocks used as a temporal head.

    NOTE(review): the convolutions run over the full padded sequence; `mask`
    is accepted for interface parity with the other heads but is not used, so
    padded frames can leak into valid frames within the receptive field.
    """

    def __init__(self, feat_dim, channels: Optional[List[int]] = None,
                 dilations: Optional[List[int]] = None,
                 kernel_size=3, dropout=0.2):
        super().__init__()
        # BUG FIX: avoid mutable (list) default arguments; None selects the
        # original default stack, so existing callers are unaffected.
        if channels is None:
            channels = [256] * 6
        if dilations is None:
            dilations = [1, 2, 4, 8, 16, 32]
        # Raise instead of assert: asserts are stripped under `python -O`.
        if len(channels) != len(dilations):
            raise ValueError("channels and dilations must have the same length")
        blocks = []
        in_ch = feat_dim
        for ch, d in zip(channels, dilations):
            blocks.append(
                TemporalBlock(in_ch, ch, kernel_size=kernel_size, dilation=d, dropout=dropout)
            )
            in_ch = ch
        self.net = nn.Sequential(*blocks)
        self.out_dim = channels[-1]

    def forward(self, x, mask):
        """x: [B, L, C] -> [B, L, out_dim]; mask is unused (see class note)."""
        x = x.transpose(1, 2)  # [B, C, L]
        x = self.net(x)
        return x.transpose(1, 2)  # [B, L, out_dim]
class RARP_NVB_Classification_Head(torch.nn.Module):
    """MLP classification head.

    With an empty `layer` list this is a single Linear projection; otherwise
    it builds Linear -> activation -> Dropout(0.4) stages for each hidden
    width, replacing the dropout right before the output projection with a
    lighter Dropout(0.2).
    """

    def __init__(self, in_features: int, out_features: int, layer: list = [],
                 activation_fn: torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.activation = activation_fn
        if not layer:
            self.head = torch.nn.Linear(in_features, out_features)
            return
        modules = []
        prev = in_features
        for width in layer:
            modules += [
                torch.nn.Linear(prev, width),
                self.activation,
                torch.nn.Dropout(0.4),
            ]
            prev = width
        # Lighter dropout just before the final projection.
        modules[-1] = torch.nn.Dropout(0.2)
        modules.append(torch.nn.Linear(prev, out_features))
        self.head = torch.nn.Sequential(*modules)

    def forward(self, x):
        return self.head(x)
class TemporalAttentionPool(nn.Module):
    """Additive (tanh) attention pooling over the time axis.

    Scores each time step, masks padded steps with a large negative value so
    they receive ~zero weight after softmax, and returns the weighted sum
    along with the attention distribution itself.
    """

    def __init__(self, dim, hidden=128):
        super().__init__()
        self.proj = nn.Linear(dim, hidden)
        self.v = nn.Linear(hidden, 1, bias=False)

    def forward(self, seq, mask):
        """seq: [B, L, D]; mask: [B, L] bool -> (pooled [B, D], attn [B, L])."""
        scores = self.v(torch.tanh(self.proj(seq))).squeeze(-1)  # [B, L]
        scores = scores.masked_fill(~mask, -1e9)
        attn = torch.softmax(scores, dim=1)
        pooled = torch.einsum("bl,bld->bd", attn, seq)
        return pooled, attn
class LearnedPositionalEncoding(nn.Module):
    """Adds a learned per-position embedding to a [B, L, C] sequence."""

    def __init__(self, max_len: int, dim: int):
        super().__init__()
        self.pos_embed = nn.Embedding(max_len, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, L, C] -> [B, L, C] with position embeddings added."""
        seq_len = x.size(1)
        idx = torch.arange(seq_len, device=x.device)  # [L]
        # Embedding output [L, C] broadcasts over the batch dimension.
        return x + self.pos_embed(idx).unsqueeze(0)
class TemporalTransformerHead(nn.Module):
    """Small Transformer encoder over per-frame CNN features.

    Adds learned positional encodings, runs a TransformerEncoder with
    key-padding masking, and optionally adds the raw input back as a residual.
    """

    def __init__(
        self,
        dim: int,               # feature dim C = CNN output dim
        depth: int = 2,         # number of Transformer layers
        n_heads: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        max_len: int = 1024,
        use_residual: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.use_residual = use_residual
        self.pos_encoding = LearnedPositionalEncoding(max_len=max_len, dim=dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=n_heads,
                dim_feedforward=int(dim * mlp_ratio),
                dropout=dropout,
                activation="gelu",
                batch_first=True,  # IMPORTANT: [B, L, C]
            ),
            num_layers=depth,
        )
        self.out_dim = dim

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x: [B, L, C] per-frame CNN features.
        mask: [B, L] bool, True = valid, False = padding.
        returns: [B, L, C] sequence features.
        """
        residual = x
        encoded = self.encoder(
            self.pos_encoding(x),
            # nn.Transformer expects True = PAD, i.e. the inverse of our mask.
            src_key_padding_mask=~mask,
        )
        if self.use_residual:
            # Cast back in case the encoder changed the dtype (e.g. autocast).
            return encoded.to(residual.dtype) + residual
        return encoded
class RARP_NVB_Wind_video (L.LightningModule):
def __init__(
self,
num_classes: int,
temporal: str = "gru", # "gru" or "tcn"
cnn_name: str = "resnet18",
dropout: float = 0.2,
pre_train:bool = False,
# optimization
lr: float = 3e-4,
weight_decay: float = 0.05,
epochs: int = 50,
warmup_epochs: int = 3,
label_smoothing: float = 0.0,
frizze_cnn:bool = True,
):
super().__init__()
self.save_hyperparameters()
self.attn_reg_weight = 1e-3
self.last_attn = None
match(cnn_name.lower()):
case "resnet18":
backbone = torchvision.models.resnet18(weights=None if not pre_train else torchvision.models.ResNet18_Weights.DEFAULT)
feature_dim = backbone.fc.in_features
backbone.fc = nn.Identity()
self.layers_to_unfreeze = ["layer3", "layer4"]
case "resnet34":
backbone = torchvision.models.resnet34(weights=None if not pre_train else torchvision.models.ResNet34_Weights.DEFAULT)
feature_dim = backbone.fc.in_features
backbone.fc = nn.Identity()
self.layers_to_unfreeze = ["layer3", "layer4"]
case "resnet50":
backbone = torchvision.models.resnet50(weights=None if not pre_train else torchvision.models.ResNet50_Weights.DEFAULT)
feature_dim = backbone.fc.in_features
backbone.fc = nn.Identity()
self.layers_to_unfreeze = ["layer3", "layer4"]
case "van_b1":
backbone = van.van_b1(pretrained=pre_train)
feature_dim = backbone.head.in_features
backbone.head = nn.Identity()
self.layers_to_unfreeze = ["block3", "block4"]
case "van_b2":
backbone = van.van_b2(pretrained=pre_train)
feature_dim = backbone.head.in_features
backbone.head = nn.Identity()
self.layers_to_unfreeze = ["block3", "block4"]
case _:
raise NotImplementedError(f"CNN name = '{cnn_name}' is not implemented yet")
if pre_train and frizze_cnn:
for p in backbone.parameters():
p.requires_grad = False
self.cnn = backbone
match(temporal.lower()):
case "gru":
temp_head = GRUTemporalHead(feature_dim, hidden=255, num_layers=2, bidirectional=True, dropout=dropout)
head_dim = temp_head.out_dim
case "tcn":
temp_head = TCNTemporalHead(feature_dim, channels=[256]*6, dilations=[1, 2, 4, 8, 16, 32], kernel_size=3, dropout=dropout)
head_dim = temp_head.out_dim
case "transf":
temp_head = TemporalTransformerHead(
feature_dim,
depth=2,
n_heads=4,
mlp_ratio=2.0,
dropout=0.1,
max_len=64, # Window size
use_residual=True
)
head_dim = temp_head.out_dim
case _:
raise NotImplementedError(f"Temporal head = '{temporal}' is not implemented yet")
self.video_feature_dim = head_dim
self.temporal_head = temp_head
self.pool = nn.Sequential(
nn.LayerNorm(head_dim),
nn.GELU(),
nn.Dropout(dropout)
)
self.pool_att = TemporalAttentionPool(head_dim, 128)
self.classifier = RARP_NVB_Classification_Head(head_dim, num_classes, [8], nn.SiLU())
self.multi_class = num_classes > 1
if not self.multi_class: # only one class
self.loss = nn.BCEWithLogitsLoss()
else:
self.loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
self.base_lr = lr
self.weight_decay = weight_decay
self.total_epochs = epochs
self.warmup_epochs = warmup_epochs
type_metric = "binary" if not self.multi_class else "multiclass"
self.train_acc = torchmetrics.Accuracy(type_metric)
self.val_acc = torchmetrics.Accuracy(type_metric)
self.val_video_acc = torchmetrics.Accuracy(type_metric)
self.test_acc = torchmetrics.Accuracy(type_metric)
self.f1ScoreTest = torchmetrics.F1Score(type_metric)
self.val_outputs = []
def forward(self, data:torch.Tensor, mask:torch.Tensor):
assert len(data.shape) == 5, "Expeted 5-D tensor in [B, L, C, H, W] format"
B, L, C, H, W = data.shape
data = data.view(B * L, C, H, W) # Flaten the video (Big Batch)
cnn_features = self.cnn(data) # [B*L, F]
cnn_features = cnn_features.view(B, L, -1) # separate B from L [B, L, F]
time_features = self.temporal_head(cnn_features, mask)
video_features, self.last_attn = self.pool_att(time_features, mask)
# ---* Mask pooling *---
#mask_float = mask.float().unsqueeze(-1)
#time_features *= mask_float
#denom = mask_float.sum(1).clamp_min(1.0)
#video_features = time_features.sum(1) / denom
video_features = self.pool(video_features)
logits = self.classifier(video_features)
return logits
def _attention_entropy(self, mask):
eps = 1e-8
attn = self.last_attn * mask
log_attn = (attn + eps).log()
ent = -(attn * log_attn).sum(dim=1)
lens = mask.sum(dim=1).clamp_min(1)
ent = ent / lens
return ent
def _lr_lambda(self, epoch):
if epoch < self.warmup_epochs:
lambda_val = float(epoch + 1) / max(1, self.warmup_epochs)
progress = (epoch - self.warmup_epochs) / max(1, (self.total_epochs - self.warmup_epochs))
lambda_val = 0.5 * (1 + math.cos(math.pi * progress))
return lambda_val
def _shared_step(self, batch, val_step:bool):
match(len(batch)):
case 4:
data, label, mask, _ = batch
case 3:
data, label, mask = batch
case _:
raise ValueError("Batch must be (x,y,mask[,meta])")
logits = self(data, mask)
if not self.multi_class:
label = label.float()
logits = logits.flatten()
pred = torch.sigmoid(logits)
else:
pred = torch.softmax(logits)
attn_loss = 0
if self.attn_reg_weight > 0:
attn_loss = self._attention_entropy(mask)
attn_loss = attn_loss.mean()
attn_loss *= self.attn_reg_weight
loss = self.loss(logits, label) + attn_loss
return loss, label, logits, [attn_loss] #pred
def on_after_backward(self):
# Computes global L2 norm of all gradients
total_norm = 0.0
for p in self.parameters():
if p.grad is not None:
param_norm = p.grad.norm(2) # L2 norm of this tensor
total_norm += param_norm.pow(2) # accumulate square
total_norm = total_norm.sqrt() # take sqrt at end
# Log gradient norm
self.log("grad_norm", total_norm)
# Optional vanishing gradient warning
if total_norm < 1e-8:
self.log("grad_warning", 1.0)
def on_train_epoch_start(self):
lr = self.trainer.optimizers[0].param_groups[0]['lr']
self.log("lr", lr)
def training_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, False)
self.log("train_loss", loss, on_epoch=True)
self.log("train_attn_loss", extra_losses[0], on_epoch=True)
self.train_acc.update(predicted_labels, true_labels)
self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
return loss
def validation_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, True)
self.log("val_wind_loss", loss, on_epoch=True, on_step=False)
self.log("val_win_attn_loss", extra_losses[0], on_epoch=True)
self.val_acc.update(predicted_labels, true_labels)
self.log("val_wind_acc", self.val_acc, on_epoch=True, on_step=False)
video_idx = batch[3]["video_idx"]
self.val_outputs.append({
"logits": predicted_labels.detach().cpu(),
"targets": true_labels.detach().cpu(),
"video_idx": video_idx.detach().cpu()
})
def on_validation_epoch_end(self):
all_logits = torch.cat([o["logits"] for o in self.val_outputs], dim=0)
all_targets = torch.cat([o["targets"] for o in self.val_outputs], dim=0)
all_vids = torch.cat([o["video_idx"] for o in self.val_outputs], dim=0)
self.val_outputs.clear()
video_logits = defaultdict(list)
video_targets = {}
for l, t, v in zip(all_logits, all_targets, all_vids):
v = int(v.item())
video_logits[v].append(l)
video_targets[v] = t
agg_logits = []
agg_targets = []
for v, parts in video_logits.items():
avg_logit = torch.stack(parts).mean()
agg_logits.append(avg_logit)
agg_targets.append(video_targets[v])
agg_logits = torch.stack(agg_logits)
agg_targets = torch.stack(agg_targets)
video_loss = self.loss(agg_logits, agg_targets.float())
self.log("val_video_loss", video_loss, on_epoch=True, on_step=False, prog_bar=True)
self.val_video_acc.update(agg_logits, agg_targets)
self.log("val_video_acc", self.val_video_acc, on_epoch=True, on_step=False)
def test_step(self, batch, batch_idx):
_, true_labels, predicted_labels, _ = self._shared_step(batch, True)
self.test_acc.update(predicted_labels, true_labels)
self.f1ScoreTest.update(predicted_labels, true_labels)
self.log("test_win_acc", self.test_acc, on_epoch=True, on_step=False)
self.log("test_win_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
def configure_optimizers(self):
trainable_params = [p for p in self.parameters() if p.requires_grad]
opt = torch.optim.AdamW(
trainable_params,
lr=self.base_lr,
weight_decay=self.weight_decay,
)
return opt
#sch = torch.optim.lr_scheduler.LambdaLR(opt, self._lr_lambda)
#return {
# "optimizer": opt,
# "lr_scheduler": {
# "scheduler": sch,
# "interval": "epoch",
# "frequency": 1,
# },
#}
class VideoProjector(nn.Module):
    """Two-layer MLP projecting video features into a smaller common space."""

    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        # Hidden width: a quarter of the input, but never below 256.
        hidden = max(256, in_dim // 4)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return self.net(x)
class FiLM(nn.Module):
    """Feature-wise linear modulation of a temporal sequence by image features.

    The conditioning MLP's last layer is zero-initialized, so at construction
    gamma == 1 and beta == 0 and FiLM is the identity map; modulation is
    learned gradually during training.
    """

    def __init__(self, dim_img, b_g_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim_img, 2 * b_g_size),
            nn.ReLU(),
            nn.Linear(2 * b_g_size, 2 * b_g_size),
        )
        # Zero init of the last layer => identity modulation initially.
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.zeros_(self.mlp[-1].bias)
        self.C = b_g_size

    def forward(self, x, i):
        """x: [B, L, C]; i: [B, dim_img] -> modulated [B, L, C]."""
        gamma_beta = self.mlp(i)
        gamma, beta = torch.chunk(gamma_beta, 2, dim=-1)
        # Bounded modulation: scale in [0.9, 1.1], shift in [-0.1, 0.1].
        scale = 1 + 0.1 * torch.tanh(gamma)
        shift = 0.1 * torch.tanh(beta)
        return scale.unsqueeze(1) * x + shift.unsqueeze(1)
class KD_BCE_Loss(nn.Module):
    """Soft knowledge-distillation loss for binary logits.

    Student and teacher logits are temperature-softened with a sigmoid; the
    BCE between the two probability vectors is scaled by lambda * T^2 (the
    usual temperature correction for distillation losses).
    """

    def __init__(self, lambda_kd:float, t:float):
        super().__init__()
        self.Lambda = lambda_kd
        self.T = t

    def forward(self, z_video, z_key):
        student = torch.sigmoid(z_video / self.T)
        # Teacher is detached: gradients only flow into the student branch.
        teacher = torch.sigmoid(z_key / self.T).detach()
        bce = nn.functional.binary_cross_entropy(student, teacher)
        return self.Lambda * (self.T ** 2) * bce
class RARP_NVB_Multi_MOD(RARP_NVB_Wind_video):
    """Multi-modal variant of the window video model.

    Window features are FiLM-conditioned on features from a frozen key-frame
    teacher (Hybrid_TS) and the window logits are additionally distilled
    against the teacher's logits with a temperature-scaled BCE KD loss.
    """

    def _unfreeze_last_n_layers(self, model:nn.Module):
        """Freeze all of `model`, then re-enable gradients only for the child
        modules named in self.layers_to_unfreeze (set by the backbone choice)."""
        for p in model.parameters():
            p.requires_grad = False
        for name, module in model.named_children():
            if name in self.layers_to_unfreeze:
                for p in module.parameters():
                    p.requires_grad = True

    def __init__(
        self,
        num_classes,
        temporal = "gru",
        cnn_name = "resnet18",
        dropout = 0.2,
        pre_train = False,
        lr = 0.0003,
        weight_decay = 0.05,
        epochs = 50,
        warmup_epochs = 3,
        label_smoothing = 0,
        frizze_cnn = True,
        Hybrid_TS_weights:str = ""
    ):
        super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn)
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not Hybrid_TS_weights:
            raise ValueError("The key-frame model requires pre-trained weights")
        # Frozen key-frame teacher.
        self.Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(Path(Hybrid_TS_weights), map_location=self.device)
        self.Hybrid_TS.eval()
        for p in self.Hybrid_TS.parameters():
            p.requires_grad = False
        self._unfreeze_last_n_layers(self.cnn)
        self.img_feature_dim = 1024  # comes from the Hybrid TS model hyperparameters
        self.mid_dim = 256
        self.proy_video = VideoProjector(self.video_feature_dim, self.mid_dim)
        self.film = FiLM(self.img_feature_dim, self.mid_dim)
        # Replace the parent's pooling/classifier with mid_dim-sized ones.
        self.pool = nn.Sequential(
            nn.LayerNorm(self.mid_dim),
            nn.Dropout(dropout),
        )
        self.classifier = RARP_NVB_Classification_Head(self.mid_dim, num_classes, [], nn.SiLU())
        self.kd_lambda = 0.4      # weight of soft distillation loss
        self.kd_T = 2.0           # temperature for logit distillation
        self.attn_reg_weight = 0  # attention regularizer disabled in this variant
        self.pool_att = None if self.attn_reg_weight <= 0 else TemporalAttentionPool(self.mid_dim, 128)
        self.kd_loss = KD_BCE_Loss(self.kd_lambda, self.kd_T)

    def _mask_pooling(self, time_features:torch.Tensor, mask:torch.Tensor):
        """Mean over valid time steps.

        BUG FIX: the original used an in-place `*=`, silently mutating the
        caller's tensor; this version is out-of-place and numerically identical.
        """
        mask_float = mask.float().unsqueeze(-1)
        summed = (time_features * mask_float).sum(dim=1)
        denom = mask_float.sum(dim=1).clamp_min(1.0)
        return summed / denom

    def forward(self, data:torch.Tensor, key_frame:torch.Tensor, mask:torch.Tensor):
        """data: [B, L, C, H, W]; key_frame: teacher input; mask: [B, L] bool.
        Returns (window logits, teacher key-frame logits)."""
        assert len(data.shape) == 5, "Expected 5-D tensor in [B, L, C, H, W] format"
        B, L, C, H, W = data.shape
        frames = data.contiguous().view(B * L, C, H, W)  # flatten time into batch
        cnn_features = self.cnn(frames).view(B, L, -1)   # [B, L, F]
        time_features = self.temporal_head(cnn_features, mask)
        # --- FiLM conditioning on the frozen teacher's features ---
        h_mid = self.proy_video(time_features)
        with torch.no_grad():
            pred_key_frame, _, _ = self.Hybrid_TS(key_frame)
            img_features = torch.cat((self.Hybrid_TS.last_conv_output_S, self.Hybrid_TS.last_conv_output_T), dim=1)
            img_features = nn.functional.adaptive_avg_pool2d(img_features, (1, 1)).flatten(1)
        h_film = self.film(h_mid, img_features)
        # Masked mean pooling over time, then normalization + classifier.
        video_features = self.pool(self._mask_pooling(h_film, mask))
        logits = self.classifier(video_features)
        return logits, pred_key_frame

    def _shared_step(self, batch, val_step:bool):
        """Forward pass + task loss + KD loss (+ optional attention regularizer).

        Returns (loss, labels, logits, [extra_loss]) where extra_loss is the
        sum of the regularizer and the KD term (matches parent logging).
        """
        data, label, mask, _, key_frame = batch
        logits, key_frame_logits = self(data, key_frame, mask)
        if not self.multi_class:
            label = label.float()
            logits = logits.flatten()
        reg_loss = 0
        if self.attn_reg_weight > 0:
            reg_loss = self._attention_entropy(mask).mean() * self.attn_reg_weight
        # BUG FIX: the KD term used to overwrite the attention regularizer;
        # they are now summed (numerically identical while attn_reg_weight == 0).
        kd = self.kd_loss(logits, key_frame_logits.flatten())
        extra_loss = reg_loss + kd
        loss = self.loss(logits, label) + extra_loss
        return loss, label, logits, [extra_loss]

    def configure_optimizers(self):
        """AdamW with per-module learning rates: slow CNN, main-rate FiLM and
        temporal encoder, faster head (anything unmatched joins the head group)."""
        cnn_params, film_params, encoder_params, head_params = [], [], [], []
        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if "cnn" in name:
                cnn_params.append(p)
            elif "film" in name:
                film_params.append(p)
            elif "temporal_head" in name:
                encoder_params.append(p)
            else:
                # classifier and any remaining trainable modules
                head_params.append(p)
        return torch.optim.AdamW([
            {"params": cnn_params, "lr": self.base_lr * 0.1},   # slow LR
            {"params": film_params, "lr": self.base_lr},        # main LR
            {"params": encoder_params, "lr": self.base_lr},     # main LR
            {"params": head_params, "lr": self.base_lr * 1.5},  # faster LR
        ])
class RARP_NVB_Multi_MOD_A1(RARP_NVB_Multi_MOD):
    """Variant A1 of RARP_NVB_Multi_MOD; currently identical to its parent."""

    def __init__(self, num_classes, temporal="gru", cnn_name="resnet18", dropout=0.2,
                 pre_train=False, lr=0.0003, weight_decay=0.05, epochs=50,
                 warmup_epochs=3, label_smoothing=0, frizze_cnn=True, Hybrid_TS_weights=""):
        # Forward everything to the parent by keyword for readability.
        super().__init__(
            num_classes=num_classes,
            temporal=temporal,
            cnn_name=cnn_name,
            dropout=dropout,
            pre_train=pre_train,
            lr=lr,
            weight_decay=weight_decay,
            epochs=epochs,
            warmup_epochs=warmup_epochs,
            label_smoothing=label_smoothing,
            frizze_cnn=frizze_cnn,
            Hybrid_TS_weights=Hybrid_TS_weights,
        )