# RARP / Models_video.py
# Last update: 2025-12-19 (delAguila)
import math
from typing import Optional, List
import torch
import torch.nn as nn
import torch.utils.checkpoint as torch_ckp
import torchvision
import torchmetrics
import torchmetrics.classification
import lightning as L
import van
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from Models import RARP_NVB_DINO_MultiTask
from pathlib import Path



class GRUTemporalHead(nn.Module):
    """Bi-directional (by default) GRU over a padded per-frame feature sequence.

    Packs the sequence by the valid lengths derived from ``mask`` so padded
    steps never enter the recurrence, then re-pads to the input length.

    NOTE(review): a row whose mask is all-False (length 0) will make
    ``pack_padded_sequence`` raise — callers must guarantee >= 1 valid frame.
    """

    def __init__(self, feat_dim, hidden=256, num_layers=2, bidirectional=True, dropout=0.2):
        super().__init__()
        self.gru = nn.GRU(
            input_size=feat_dim, hidden_size=hidden, num_layers=num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional
        )
        self.dropout = nn.Dropout(dropout)
        self.out_dim = hidden * (2 if bidirectional else 1)

    def forward(self, x, mask):
        # x: [B, L, F], mask: [B, L] (bool, True = valid frame)
        lengths = mask.sum(dim=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.gru(packed)
        # total_length keeps the output padded to the input length L. Without it
        # the output is [B, max(lengths), out_dim], which misaligns with `mask`
        # whenever every sequence in the batch is shorter than L.
        seq, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=x.size(1)
        )  # [B, L, out_dim]
        seq = self.dropout(seq)
        return seq  # [B, L, out_dim]
    
class TemporalBlock(nn.Module):
    """Dilated 1-D convolutional residual block.

    Two (Conv1d -> ReLU -> Dropout) stages with "same" padding, plus a skip
    path (1x1 conv when channel counts differ) added to the output.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, dropout=0.2):
        super().__init__()
        same_pad = (kernel_size - 1) * dilation // 2  # preserves temporal length
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv1d(stage_in, out_ch, kernel_size, padding=same_pad, dilation=dilation),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ]
        self.net = nn.Sequential(*stages)
        # 1x1 projection so the residual sum is shape-valid when channels change.
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv1d(in_ch, out_ch, 1)

    def forward(self, x):  # x: [B, C, L]
        return self.net(x) + self.skip(x)
    
class TCNTemporalHead(nn.Module):
    """Temporal convolutional network: a stack of dilated TemporalBlocks.

    Operates on [B, L, C] sequences; transposes to channel-first for the
    convolutions and back. ``mask`` is accepted for interface parity with the
    other temporal heads but is not used (convolutions see padded steps).
    """

    def __init__(self, feat_dim, channels: Optional[List[int]] = None,
                 dilations: Optional[List[int]] = None, kernel_size=3, dropout=0.2):
        super().__init__()

        # Avoid the mutable-default-argument pitfall: materialize defaults per call.
        channels = [256] * 6 if channels is None else channels
        dilations = [1, 2, 4, 8, 16, 32] if dilations is None else dilations

        assert len(channels) == len(dilations)

        blocks = []
        in_ch = feat_dim
        for ch, d in zip(channels, dilations):
            blocks.append(
                TemporalBlock(in_ch, ch, kernel_size=kernel_size, dilation=d, dropout=dropout)
            )
            in_ch = ch

        self.net = nn.Sequential(*blocks)
        self.out_dim = channels[-1]

    def forward(self, x, mask):
        # x => [B, L, C]
        x = x.transpose(1, 2)  # [B, C, L]
        x = self.net(x)
        x = x.transpose(1, 2)  # [B, L, C]

        return x

class RARP_NVB_Classification_Head(torch.nn.Module):
    """Configurable MLP classification head.

    With no hidden ``layer`` sizes this is a single Linear projection.
    Otherwise it stacks (Linear -> activation -> Dropout(0.4)) per hidden size,
    replaces the dropout after the last hidden layer with a lighter
    Dropout(0.2), and appends the output Linear.
    """

    def __init__(self, in_features: int, out_features: int, layer: Optional[list] = None,
                 activation_fn: torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
        super().__init__(*args, **kwargs)

        # `layer=None` sentinel instead of a mutable default list.
        layer = [] if layer is None else layer

        self.activation = activation_fn

        if len(layer) == 0:
            self.head = torch.nn.Linear(in_features, out_features)
        else:
            temp_head = []
            next_input = in_features
            for num in layer:
                temp_head.append(torch.nn.Linear(next_input, num))
                temp_head.append(self.activation)
                temp_head.append(torch.nn.Dropout(0.4))
                next_input = num

            # Lighter dropout right before the output projection.
            temp_head[-1] = torch.nn.Dropout(0.2)
            temp_head.append(torch.nn.Linear(next_input, out_features))

            self.head = torch.nn.Sequential(*temp_head)

    def forward(self, x):
        return self.head(x)
    
class TemporalAttentionPool(nn.Module):
    """Masked additive-attention pooling over the time axis.

    Scores each time step with v^T tanh(W s), masks invalid steps, and returns
    the attention-weighted sum together with the attention weights.
    """

    def __init__(self, dim, hidden=128):
        super().__init__()
        self.proj = nn.Linear(dim, hidden)
        self.v = nn.Linear(hidden, 1, bias=False)

    def forward(self, seq, mask):
        # seq: [B, L, C]; mask: [B, L] bool (True = valid frame)
        scores = self.v(torch.tanh(self.proj(seq))).squeeze(-1)  # [B, L]
        # Large negative score on padding -> ~zero attention weight after softmax.
        scores = scores.masked_fill(~mask, -1e9)
        attn = torch.softmax(scores, dim=1)
        pooled = torch.sum(seq * attn.unsqueeze(-1), dim=1)  # [B, C]
        return pooled, attn
    
class LearnedPositionalEncoding(nn.Module):
    """Adds a trainable absolute position embedding to a [B, L, C] sequence.

    Requires L <= max_len (out-of-range positions would make the embedding
    lookup fail).
    """

    def __init__(self, max_len: int, dim: int):
        super().__init__()
        self.pos_embed = nn.Embedding(max_len, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, L, C] -> x plus per-position embeddings, same shape."""
        batch, seq_len, _ = x.shape
        idx = torch.arange(seq_len, device=x.device)
        # Every batch element gets the same position row, broadcast over B.
        return x + self.pos_embed(idx.unsqueeze(0).expand(batch, seq_len))
    
class TemporalTransformerHead(nn.Module):
    """Transformer encoder over per-frame CNN features.

    Adds learned positional encodings, runs ``depth`` encoder layers with
    key-padding masking, and optionally adds the input back as a residual.
    Output feature dimension equals the input ``dim``.
    """

    def __init__(
        self,
        dim: int,             # feature dim C = CNN output dim
        depth: int = 2,       # number of Transformer layers
        n_heads: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        max_len: int = 1024,
        use_residual: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.use_residual = use_residual

        self.pos_encoding = LearnedPositionalEncoding(max_len=max_len, dim=dim)

        layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=int(dim * mlp_ratio),
            dropout=dropout,
            activation="gelu",
            batch_first=True,         # IMPORTANT: inputs are [B, L, C]
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

        self.out_dim = dim

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x:    [B, L, C]  per-frame CNN features
        mask: [B, L]     bool, True = valid, False = padding
        returns: [B, L, C] sequence features
        """
        residual = x
        # TransformerEncoder expects True = *ignore*, so invert the mask.
        encoded = self.encoder(self.pos_encoding(x), src_key_padding_mask=~mask)
        if self.use_residual:
            # Cast back in case the encoder ran in a different dtype (e.g. AMP).
            encoded = encoded.to(residual.dtype) + residual
        return encoded

       
class RARP_NVB_Wind_video (L.LightningModule):
    def __init__(
        self, 
        num_classes: int,
        temporal: str = "gru",            # "gru" or "tcn"
        cnn_name: str = "resnet18",
        dropout: float = 0.2,
        pre_train:bool = False,
        # optimization
        lr: float = 3e-4,
        weight_decay: float = 0.05,
        epochs: int = 50,
        warmup_epochs: int = 3,
        label_smoothing: float = 0.0,
        frizze_cnn:bool = True,
    ):
        super().__init__()
        
        self.save_hyperparameters()
        
        self.attn_reg_weight = 1e-3
        self.last_attn = None
        
        match(cnn_name.lower()):
            case "resnet18":
                backbone = torchvision.models.resnet18(weights=None if not pre_train else torchvision.models.ResNet18_Weights.DEFAULT)
                feature_dim = backbone.fc.in_features
                backbone.fc = nn.Identity()
                self.layers_to_unfreeze = ["layer3", "layer4"]
            case "resnet34":
                backbone = torchvision.models.resnet34(weights=None if not pre_train else torchvision.models.ResNet34_Weights.DEFAULT)
                feature_dim = backbone.fc.in_features
                backbone.fc = nn.Identity()
                self.layers_to_unfreeze = ["layer3", "layer4"]
            case "resnet50":
                backbone = torchvision.models.resnet50(weights=None if not pre_train else torchvision.models.ResNet50_Weights.DEFAULT)
                feature_dim = backbone.fc.in_features
                backbone.fc = nn.Identity()
                self.layers_to_unfreeze = ["layer3", "layer4"]
            case "van_b1":
                backbone = van.van_b1(pretrained=pre_train)
                feature_dim = backbone.head.in_features
                backbone.head = nn.Identity()
                self.layers_to_unfreeze = ["block3", "block4"]
            case "van_b2":
                backbone = van.van_b2(pretrained=pre_train)
                feature_dim = backbone.head.in_features
                backbone.head = nn.Identity()
                self.layers_to_unfreeze = ["block3", "block4"]
            case _:
                raise NotImplementedError(f"CNN name = '{cnn_name}' is not implemented yet")
        
        if pre_train and frizze_cnn:
            for p in backbone.parameters():
                p.requires_grad = False
                        
        self.cnn = backbone
        
        match(temporal.lower()):
            case "gru":
                temp_head = GRUTemporalHead(feature_dim, hidden=255, num_layers=2, bidirectional=True, dropout=dropout)
                head_dim = temp_head.out_dim
            case "tcn":
                temp_head = TCNTemporalHead(feature_dim, channels=[256]*6, dilations=[1, 2, 4, 8, 16, 32], kernel_size=3, dropout=dropout)
                head_dim = temp_head.out_dim
            case "transf":
                temp_head = TemporalTransformerHead(
                    feature_dim,
                    depth=2,
                    n_heads=4,
                    mlp_ratio=2.0,
                    dropout=0.1,
                    max_len=64, # Window size
                    use_residual=True
                )
                head_dim = temp_head.out_dim
            case _:
                raise NotImplementedError(f"Temporal head = '{temporal}' is not implemented yet")
        
        self.video_feature_dim = head_dim    
        self.temporal_head = temp_head
        self.pool = nn.Sequential(
            nn.LayerNorm(head_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        self.pool_att = TemporalAttentionPool(head_dim, 128)
        
        self.classifier = RARP_NVB_Classification_Head(head_dim, num_classes, [8], nn.SiLU())
        
        self.multi_class = num_classes > 1
        
        if not self.multi_class: # only one class
            self.loss = nn.BCEWithLogitsLoss()
        else:
            self.loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
            
        self.base_lr = lr
        self.weight_decay = weight_decay
        self.total_epochs = epochs
        self.warmup_epochs = warmup_epochs
        
        type_metric = "binary" if not self.multi_class else "multiclass"
        
        self.train_acc = torchmetrics.Accuracy(type_metric)
        self.val_acc = torchmetrics.Accuracy(type_metric)
        self.val_video_acc = torchmetrics.Accuracy(type_metric)
        self.test_acc = torchmetrics.Accuracy(type_metric)
        self.f1ScoreTest = torchmetrics.F1Score(type_metric)
        
        self.val_outputs = [] 
        
    def forward(self, data:torch.Tensor, mask:torch.Tensor):
        
        assert len(data.shape) == 5, "Expeted 5-D tensor in [B, L, C, H, W] format"
        
        B, L, C, H, W = data.shape
        data = data.view(B * L, C, H, W)            # Flaten the video (Big Batch)
        
        cnn_features = self.cnn(data)               # [B*L, F]
        cnn_features = cnn_features.view(B, L, -1)  # separate B from L [B, L, F]
        
        time_features = self.temporal_head(cnn_features, mask)
        
        video_features, self.last_attn = self.pool_att(time_features, mask)
        
        # ---* Mask pooling *---        
        #mask_float = mask.float().unsqueeze(-1)
        #time_features *= mask_float
        #denom = mask_float.sum(1).clamp_min(1.0)
        #video_features = time_features.sum(1) / denom
        
        video_features = self.pool(video_features)
        
        logits = self.classifier(video_features)
        
        return logits
    
    def _attention_entropy(self, mask):
        eps = 1e-8
        attn = self.last_attn * mask
        log_attn = (attn + eps).log()
        ent = -(attn * log_attn).sum(dim=1)
        
        lens = mask.sum(dim=1).clamp_min(1)
        ent = ent / lens
        
        return ent
    
    def _lr_lambda(self, epoch):
        if epoch < self.warmup_epochs:
            lambda_val = float(epoch + 1) / max(1, self.warmup_epochs)
            
        progress = (epoch - self.warmup_epochs) / max(1, (self.total_epochs - self.warmup_epochs))
        lambda_val = 0.5 * (1 + math.cos(math.pi * progress))
        
        return lambda_val
        
    def _shared_step(self, batch, val_step:bool):
        
        match(len(batch)):
            case 4:
                data, label, mask, _ = batch
            case 3:
                data, label, mask = batch
            case _:
                raise ValueError("Batch must be (x,y,mask[,meta])")
            
        logits = self(data, mask)
        
        if not self.multi_class:
            label = label.float()
            logits = logits.flatten()
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits)
            
        attn_loss = 0
        if self.attn_reg_weight > 0:
            attn_loss = self._attention_entropy(mask)
            attn_loss = attn_loss.mean()
            attn_loss *= self.attn_reg_weight
                
        loss = self.loss(logits, label) + attn_loss
        
        return loss, label, logits, [attn_loss] #pred
    
    def on_after_backward(self):
        # Computes global L2 norm of all gradients
        total_norm = 0.0

        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.norm(2)          # L2 norm of this tensor
                total_norm += param_norm.pow(2)      # accumulate square

        total_norm = total_norm.sqrt()               # take sqrt at end

        # Log gradient norm
        self.log("grad_norm", total_norm)

        # Optional vanishing gradient warning
        if total_norm < 1e-8:
            self.log("grad_warning", 1.0) 
       
    def on_train_epoch_start(self):
        lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log("lr", lr)    
        
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, False)
        
        self.log("train_loss", loss, on_epoch=True)
        self.log("train_attn_loss", extra_losses[0], on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, True)
        
        self.log("val_wind_loss", loss, on_epoch=True, on_step=False)
        self.log("val_win_attn_loss", extra_losses[0], on_epoch=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_wind_acc", self.val_acc, on_epoch=True, on_step=False)
            
        video_idx = batch[3]["video_idx"]
        
        self.val_outputs.append({
            "logits": predicted_labels.detach().cpu(),
            "targets": true_labels.detach().cpu(),
            "video_idx": video_idx.detach().cpu()
        }) 
        
    def on_validation_epoch_end(self):
        all_logits = torch.cat([o["logits"] for o in self.val_outputs], dim=0)
        all_targets = torch.cat([o["targets"] for o in self.val_outputs], dim=0)
        all_vids = torch.cat([o["video_idx"] for o in self.val_outputs], dim=0)
        
        self.val_outputs.clear()
        
        video_logits = defaultdict(list)
        video_targets = {}
        
        for l, t, v in zip(all_logits, all_targets, all_vids):
            v = int(v.item())
            video_logits[v].append(l)
            video_targets[v] = t
            
        agg_logits = []
        agg_targets = []
        
        for v, parts in video_logits.items():
            avg_logit = torch.stack(parts).mean()
            agg_logits.append(avg_logit)
            agg_targets.append(video_targets[v])
            
        agg_logits = torch.stack(agg_logits)
        agg_targets = torch.stack(agg_targets)
        
        video_loss = self.loss(agg_logits, agg_targets.float())
        self.log("val_video_loss", video_loss, on_epoch=True, on_step=False, prog_bar=True)
        
        self.val_video_acc.update(agg_logits, agg_targets)
        self.log("val_video_acc", self.val_video_acc, on_epoch=True, on_step=False)
        
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels, _ = self._shared_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_win_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_win_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
        
    def configure_optimizers(self):
        trainable_params = [p for p in self.parameters() if p.requires_grad]

        opt = torch.optim.AdamW(
            trainable_params,
            lr=self.base_lr,
            weight_decay=self.weight_decay,
        )

        return opt
        #sch = torch.optim.lr_scheduler.LambdaLR(opt, self._lr_lambda)
        
        #return {
        #    "optimizer": opt,
        #    "lr_scheduler": {
        #        "scheduler": sch,
        #        "interval": "epoch",
        #        "frequency": 1,
        #    },
        #}

class VideoProjector(nn.Module):
    """Two-layer MLP projecting video features to ``out_dim``.

    Hidden width is in_dim // 4, but never narrower than 256.
    """

    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        hidden = max(256, in_dim // 4)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return self.net(x)
    
class FiLM(nn.Module):
    """Feature-wise linear modulation of a temporal sequence by an image embedding.

    An MLP maps the image feature to per-channel (gamma, beta); gamma is
    bounded to [0.9, 1.1] and beta to [-0.1, 0.1] via tanh. The final MLP
    layer is zero-initialized, so the module starts as an exact identity
    (gamma = 1, beta = 0) and conditioning only appears as training proceeds.
    """

    def __init__(self, dim_img, b_g_size=256):
        super().__init__()

        self.mlp = nn.Sequential(
            nn.Linear(dim_img, 2 * b_g_size),
            nn.ReLU(),
            nn.Linear(2 * b_g_size, 2 * b_g_size),
        )

        # Zero-init the last layer -> identity modulation at initialization.
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.zeros_(self.mlp[-1].bias)

        self.C = b_g_size

    def forward(self, x, i):
        """x: [B, L, C]; i: [B, dim_img] -> FiLM-modulated [B, L, C]."""
        gamma, beta = self.mlp(i).chunk(2, -1)
        gamma = 1 + 0.1 * torch.tanh(gamma)  # confined to [0.9, 1.1]
        beta = 0.1 * torch.tanh(beta)        # confined to [-0.1, 0.1]
        return gamma.unsqueeze(1) * x + beta.unsqueeze(1)
        
class KD_BCE_Loss(nn.Module):
    """Binary knowledge-distillation loss.

    BCE between the temperature-softened sigmoid probabilities of the student
    (z_video) and the detached teacher (z_key), scaled by lambda * T^2 to keep
    gradient magnitude comparable across temperatures.
    """

    def __init__(self, lambda_kd: float, t: float):
        super().__init__()
        self.Lambda = lambda_kd
        self.T = t

    def forward(self, z_video, z_key):
        student = torch.sigmoid(z_video / self.T)
        teacher = torch.sigmoid(z_key / self.T).detach()  # no grad through the teacher
        bce = nn.functional.binary_cross_entropy(student, teacher)
        return self.Lambda * (self.T ** 2) * bce
   
class RARP_NVB_Multi_MOD(RARP_NVB_Wind_video):
    """Multi-modal variant of RARP_NVB_Wind_video.

    FiLM-conditions the projected video features on image features from a
    frozen key-frame teacher (RARP_NVB_DINO_MultiTask), pools with a masked
    mean instead of attention, and adds a soft logit-distillation (KD) loss
    against the teacher's key-frame prediction.
    """

    def _unfreeze_last_n_layers(self, model: nn.Module):
        """Freeze all of ``model``, then re-enable grads only for the top-level
        children named in ``self.layers_to_unfreeze``."""
        for p in model.parameters():
            p.requires_grad = False

        # unfreeze the selected children
        for name, module in model.named_children():
            if name in self.layers_to_unfreeze:
                for p in module.parameters():
                    p.requires_grad = True

    def __init__(
        self,
        num_classes,
        temporal="gru",
        cnn_name="resnet18",
        dropout=0.2,
        pre_train=False,
        lr=0.0003,
        weight_decay=0.05,
        epochs=50,
        warmup_epochs=3,
        label_smoothing=0,
        frizze_cnn=True,
        Hybrid_TS_weights: str = ""
    ):
        super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn)

        assert len(Hybrid_TS_weights) > 0, "The Key frame model require pre-trained weigths"

        # Frozen key-frame teacher.
        self.Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(Path(Hybrid_TS_weights), map_location=self.device)
        self.Hybrid_TS.eval()
        for p in self.Hybrid_TS.parameters():
            p.requires_grad = False

        # Fine-tune only the deepest CNN stages.
        self._unfreeze_last_n_layers(self.cnn)

        self.img_feature_dim = 1024  # this comes from the Hybrid TS model hyperparameters
        self.mid_dim = 256
        self.proy_video = VideoProjector(self.video_feature_dim, self.mid_dim)
        self.film = FiLM(self.img_feature_dim, self.mid_dim)

        # Replace the parent's pooling/classifier so dims match mid_dim.
        self.pool = nn.Sequential(
            nn.LayerNorm(self.mid_dim),
            nn.Dropout(dropout),
        )

        self.classifier = RARP_NVB_Classification_Head(self.mid_dim, num_classes, [], nn.SiLU())

        self.kd_lambda = 0.4  # weight of soft distillation loss
        self.kd_T = 2.0       # temperature for logit distillation

        # Attention regularization disabled in this variant (mask pooling is used).
        self.attn_reg_weight = 0
        self.pool_att = None if self.attn_reg_weight <= 0 else TemporalAttentionPool(self.mid_dim, 128)

        self.kd_loss = KD_BCE_Loss(self.kd_lambda, self.kd_T)

    def _mask_pooling(self, time_features: torch.Tensor, mask: torch.Tensor):
        """Mean over valid (mask=True) time steps: [B, L, C] -> [B, C]."""
        mask_float = mask.float().unsqueeze(-1)
        # Non-in-place multiply (was `*=`), so the caller's tensor is not mutated.
        masked = time_features * mask_float
        denom = mask_float.sum(1).clamp_min(1.0)
        video_features = masked.sum(1) / denom

        return video_features

    def forward(self, data: torch.Tensor, key_frame: torch.Tensor, mask: torch.Tensor):
        """data: [B, L, C, H, W]; key_frame: teacher input; mask: [B, L] bool.

        Returns (window logits, teacher key-frame logits).
        """
        assert len(data.shape) == 5, "Expected 5-D tensor in [B, L, C, H, W] format"

        B, L, C, H, W = data.shape
        data = data.contiguous()
        data = data.view(B * L, C, H, W)            # flatten time into the batch

        cnn_features = self.cnn(data)               # [B*L, F]
        cnn_features = cnn_features.view(B, L, -1)  # separate B from L: [B, L, F]

        time_features = self.temporal_head(cnn_features, mask)

        # --- FiLM conditioning on frozen teacher image features ---
        h_mid = self.proy_video(time_features)

        with torch.no_grad():
            pred_key_frame, _, _ = self.Hybrid_TS(key_frame)
            # Teacher's spatial (S) and temporal (T) conv maps, concatenated
            # channel-wise and global-average-pooled to a flat vector.
            img_features = torch.cat((self.Hybrid_TS.last_conv_output_S, self.Hybrid_TS.last_conv_output_T), dim=1)
            img_features = nn.functional.adaptive_avg_pool2d(img_features, (1, 1)).flatten(1)

        h_film = self.film(h_mid, img_features)

        # ---* Mask pooling *---
        video_features = self._mask_pooling(h_film, mask)

        video_features = self.pool(video_features)

        logits = self.classifier(video_features)

        return logits, pred_key_frame

    def _shared_step(self, batch, val_step: bool):
        """Forward + classification loss + KD loss.

        batch: (x, y, mask, meta, key_frame). Returns
        (loss, labels, logits, [aux_loss]) where aux_loss is the attention
        regularizer (if enabled) plus the distillation term.
        """
        data, label, mask, _, key_frame = batch

        logits, key_frame_logits = self(data, key_frame, mask)

        if not self.multi_class:
            label = label.float()
            logits = logits.flatten()
            pred = torch.sigmoid(logits)
        else:
            # Bug fix: torch.softmax requires an explicit dim.
            pred = torch.softmax(logits, dim=-1)

        attn_loss = 0
        if self.attn_reg_weight > 0:
            attn_loss = self._attention_entropy(mask)
            attn_loss = attn_loss.mean()
            attn_loss *= self.attn_reg_weight

        # Bug fix: accumulate instead of overwrite — previously the KD term
        # replaced the attention regularizer (harmless only because
        # attn_reg_weight is set to 0 in __init__).
        attn_loss = attn_loss + self.kd_loss(logits, key_frame_logits.flatten())

        loss = self.loss(logits, label) + attn_loss

        return loss, label, logits, [attn_loss]

    def configure_optimizers(self):
        """AdamW with per-module learning rates: slow CNN, base FiLM/temporal
        encoder, faster head (the head group also catches unmatched params).

        NOTE(review): self.weight_decay is not passed here, so AdamW's default
        decay applies — confirm whether that is intended.
        """
        cnn_params = []
        film_params = []
        encoder_params = []
        head_params = []

        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if "cnn" in name:
                cnn_params.append(p)
            elif "film" in name:
                film_params.append(p)
            elif "temporal_head" in name:
                encoder_params.append(p)
            else:
                # classifier and anything unmatched train at the faster rate
                head_params.append(p)

        optimizer = torch.optim.AdamW([
            {"params": cnn_params, "lr": self.base_lr * 0.1},      # slow LR
            {"params": film_params, "lr": self.base_lr},           # main LR
            {"params": encoder_params, "lr": self.base_lr},        # main LR
            {"params": head_params, "lr": self.base_lr * 1.5},     # faster LR
        ])

        return optimizer
        
class RARP_NVB_Multi_MOD_A1(RARP_NVB_Multi_MOD):
    """Ablation variant A1 — currently identical to RARP_NVB_Multi_MOD."""

    def __init__(self, num_classes, temporal="gru", cnn_name="resnet18", dropout=0.2,
                 pre_train=False, lr=0.0003, weight_decay=0.05, epochs=50, warmup_epochs=3,
                 label_smoothing=0, frizze_cnn=True, Hybrid_TS_weights=""):
        super().__init__(
            num_classes,
            temporal=temporal,
            cnn_name=cnn_name,
            dropout=dropout,
            pre_train=pre_train,
            lr=lr,
            weight_decay=weight_decay,
            epochs=epochs,
            warmup_epochs=warmup_epochs,
            label_smoothing=label_smoothing,
            frizze_cnn=frizze_cnn,
            Hybrid_TS_weights=Hybrid_TS_weights,
        )