import math
from typing import Optional, List
import torch
import torch.nn as nn
import torch.utils.checkpoint as torch_ckp
import torchvision
import torchmetrics
import lightning as L
import van
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from Models import RARP_NVB_DINO_MultiTask, TypeLossFunction
from pathlib import Path
import pandas as pd
#from EfficientViT.GSViT import EfficientViT_GSViT
from EfficientViT.GSViT_RARP import EfficientViT_GSViT



class GRUTemporalHead(nn.Module):
    def __init__(self, feat_dim, hidden=256, num_layers=2, bidirectional=True, dropout=0.2):
        super().__init__()
        self.gru = nn.GRU(
            input_size=feat_dim, hidden_size=hidden, num_layers=num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional
        )
        self.dropout = nn.Dropout(dropout)
        self.out_dim = hidden * (2 if bidirectional else 1)

    def forward(self, x, mask):
        # x: [B, L, F], mask: [B, L] (bool)
        lengths = mask.sum(dim=1).clamp_min(1).cpu()  # pack_padded_sequence rejects zero-length sequences
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.gru(packed)
        # total_length keeps the output aligned with the input mask even when the
        # longest sequence in the batch is shorter than L
        seq, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))  # [B, L, out_dim]
        seq = self.dropout(seq)
        return seq  # [B, L, out_dim]
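
# Illustrative shape check for GRUTemporalHead (hypothetical sizes, not part of
# the training pipeline):
#
#   head = GRUTemporalHead(feat_dim=512)          # hidden=256, bidirectional
#   x = torch.randn(4, 16, 512)                   # [B, L, F]
#   mask = torch.ones(4, 16, dtype=torch.bool)    # all 16 frames valid
#   seq = head(x, mask)                           # -> [4, 16, 512] = [B, L, 2 * hidden]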
    
class TemporalBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, dropout=0.2):
        super().__init__()
        pad = (kernel_size - 1) * dilation // 2
        self.net = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size, padding=pad, dilation=dilation),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Conv1d(out_ch, out_ch, kernel_size, padding=pad, dilation=dilation),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        )
        self.skip = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x):  # x: [B, C, L]
        y = self.net(x)
        return y + self.skip(x)
    
class TCNTemporalHead(nn.Module):
    def __init__(self, feat_dim, channels: Optional[List[int]] = None, dilations: Optional[List[int]] = None,
                 kernel_size=3, dropout=0.2):
        super().__init__()

        # avoid mutable default arguments
        channels = channels if channels is not None else [256] * 6
        dilations = dilations if dilations is not None else [1, 2, 4, 8, 16, 32]

        assert len(channels) == len(dilations)

        blocks = []
        in_ch = feat_dim
        for ch, d in zip(channels, dilations):
            blocks.append(
                TemporalBlock(in_ch, ch, kernel_size=kernel_size, dilation=d, dropout=dropout)
            )
            in_ch = ch

        self.net = nn.Sequential(*blocks)
        self.out_dim = channels[-1]
        
    def forward(self, x, mask):
        # mask is accepted for interface parity with the other temporal heads,
        # but the convolutions also run over padded frames
        x = x.transpose(1, 2)  # [B, L, C] -> [B, C, L]
        x = self.net(x)
        x = x.transpose(1, 2)  # [B, C, L] -> [B, L, C]

        return x
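
# Receptive-field sketch for the default TCN configuration (kernel_size=3,
# dilations 1, 2, 4, 8, 16, 32; two convs per block). Each conv widens the
# receptive field by (kernel_size - 1) * dilation frames, so:
#
#   RF = 1 + sum(2 * (3 - 1) * d for d in [1, 2, 4, 8, 16, 32])
#      = 1 + 4 * 63 = 253 frames
#
# i.e. every output step can, in principle, see about 253 input frames.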

class RARP_NVB_Classification_Head(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, layer: Optional[list] = None, activation_fn: torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = activation_fn
        layer = layer if layer is not None else []

        if len(layer) == 0:
            self.head = torch.nn.Linear(in_features, out_features)
        else:
            temp_head = []
            next_input = in_features
            for num in layer:
                temp_head.append(torch.nn.Linear(next_input, num))
                temp_head.append(self.activation)
                temp_head.append(torch.nn.Dropout(0.4))
                next_input = num

            # lighter dropout right before the output layer
            temp_head[-1] = torch.nn.Dropout(0.2)
            temp_head.append(torch.nn.Linear(next_input, out_features))

            self.head = torch.nn.Sequential(*temp_head)
    
    def forward(self, x):
        return self.head(x)
    
class TemporalAttentionPool(nn.Module):
    def __init__(self, dim, hidden=128):
        super().__init__()
        
        self.proj = nn.Linear(dim, hidden)
        self.v = nn.Linear(hidden, 1, bias=False)
        
    def forward(self, seq, mask):
        h = torch.tanh(self.proj(seq))
        w = self.v(h).squeeze(-1)
        
        w = w.masked_fill(~mask, -1e9)
        
        attn = torch.softmax(w, dim=1)
                
        pooled = (seq * attn.unsqueeze(-1)).sum(dim=1)
        
        return pooled, attn        
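
# Masked-softmax sketch for TemporalAttentionPool (hypothetical sizes): padded
# positions are filled with -1e9 before the softmax, so they receive ~0 weight
# and the pooled vector only mixes valid frames:
#
#   pool = TemporalAttentionPool(dim=512)
#   seq = torch.randn(2, 16, 512)                  # [B, L, C]
#   mask = torch.zeros(2, 16, dtype=torch.bool)
#   mask[:, :10] = True                            # only the first 10 frames are valid
#   pooled, attn = pool(seq, mask)                 # pooled: [2, 512], attn: [2, 16]
#   # attn[:, 10:] is ~0 and attn.sum(dim=1) == 1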
    
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, max_len: int, dim: int):
        super().__init__()
        self.pos_embed = nn.Embedding(max_len, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [B, L, C]
        """
        B, L, _ = x.shape
        device = x.device
        positions = torch.arange(L, device=device).unsqueeze(0).expand(B, L)  # [B, L]
        pos = self.pos_embed(positions)  # [B, L, C]
        return x + pos
    
class TemporalTransformerHead(nn.Module):
    def __init__(
        self,
        dim: int,             # feature dim C = CNN output dim
        depth: int = 2,       # number of Transformer layers
        n_heads: int = 4,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        max_len: int = 1024,
        use_residual: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.use_residual = use_residual

        self.pos_encoding = LearnedPositionalEncoding(max_len=max_len, dim=dim)

        ff_dim = int(dim * mlp_ratio)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,         # IMPORTANT: [B, L, C]
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=depth,
        )

        self.out_dim = dim

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x:    [B, L, C]  per-frame CNN features
        mask: [B, L]    bool, True = valid, False = padding
        returns: [B, L, C] sequence features
        """
        x_original = x
        
        x = self.pos_encoding(x)
        src_key_padding_mask = ~mask
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        x = (x.to(x_original.dtype) + x_original) if self.use_residual else x
        
        return x
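
# Usage sketch for TemporalTransformerHead (hypothetical sizes): the padding
# mask is inverted into src_key_padding_mask, where True marks positions the
# encoder must ignore:
#
#   head = TemporalTransformerHead(dim=512, depth=2, n_heads=4, max_len=64)
#   x = torch.randn(2, 64, 512)                    # [B, L, C], L <= max_len
#   mask = torch.ones(2, 64, dtype=torch.bool)     # True = valid frame
#   out = head(x, mask)                            # -> [2, 64, 512]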

       
class RARP_NVB_Wind_video(L.LightningModule):
    def __init__(
        self, 
        num_classes: int,
        temporal: str = "gru",            # "gru" or "tcn"
        cnn_name: str = "resnet18",
        dropout: float = 0.2,
        pre_train:bool = False,
        # optimization
        lr: float = 3e-4,
        weight_decay: float = 0.05,
        epochs: int = 50,
        warmup_epochs: int = 3,
        label_smoothing: float = 0.0,
        frizze_cnn:bool = True,
    ):
        super().__init__()
        
        self.save_hyperparameters()
        
        self.attn_reg_weight = 1e-3
        self.last_attn = None
        self.pre_train = pre_train
        
        match(cnn_name.lower()):
            case "resnet18":
                backbone = torchvision.models.resnet18(weights=None if not pre_train else torchvision.models.ResNet18_Weights.DEFAULT)
                feature_dim = backbone.fc.in_features
                backbone.fc = nn.Identity()
                self.layers_to_unfreeze = ["layer3", "layer4"]
            case "resnet34":
                backbone = torchvision.models.resnet34(weights=None if not pre_train else torchvision.models.ResNet34_Weights.DEFAULT)
                feature_dim = backbone.fc.in_features
                backbone.fc = nn.Identity()
                self.layers_to_unfreeze = ["layer3", "layer4"]
            case "resnet50":
                backbone = torchvision.models.resnet50(weights=None if not pre_train else torchvision.models.ResNet50_Weights.DEFAULT)
                feature_dim = backbone.fc.in_features
                backbone.fc = nn.Identity()
                self.layers_to_unfreeze = ["layer3", "layer4"]
            case "van_b1":
                backbone = van.van_b1(pretrained=pre_train)
                feature_dim = backbone.head.in_features
                backbone.head = nn.Identity()
                self.layers_to_unfreeze = ["block3", "block4"]
            case "van_b2":
                backbone = van.van_b2(pretrained=pre_train)
                feature_dim = backbone.head.in_features
                backbone.head = nn.Identity()
                self.layers_to_unfreeze = ["block3", "block4"]
            case "gsvit":
                backbone = EfficientViT_GSViT(str(Path("./EfficientViT/EfficientViT_GSViT.pth").resolve()))
                feature_dim = 384
                self.layers_to_unfreeze = ["blocks2", "blocks3"]
            case _:
                raise NotImplementedError(f"CNN name = '{cnn_name}' is not implemented yet")
        
        if pre_train and frizze_cnn:
            for p in backbone.parameters():
                p.requires_grad = False
                        
        self.cnn = backbone
        
        match(temporal.lower()):
            case "gru":
                temp_head = GRUTemporalHead(feature_dim, hidden=256, num_layers=2, bidirectional=True, dropout=dropout)
                head_dim = temp_head.out_dim
            case "tcn":
                temp_head = TCNTemporalHead(feature_dim, channels=[256]*6, dilations=[1, 2, 4, 8, 16, 32], kernel_size=3, dropout=dropout)
                head_dim = temp_head.out_dim
            case "transf":
                temp_head = TemporalTransformerHead(
                    feature_dim,
                    depth=2,
                    n_heads=4,
                    mlp_ratio=2.0,
                    dropout=0.1,
                    max_len=64, # Window size
                    use_residual=True
                )
                head_dim = temp_head.out_dim
            case _:
                raise NotImplementedError(f"Temporal head = '{temporal}' is not implemented yet")
        
        self.video_feature_dim = head_dim    
        self.temporal_head = temp_head
        self.pool = nn.Sequential(
            nn.LayerNorm(head_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        self.pool_att = TemporalAttentionPool(head_dim, 128)
        
        self.classifier = RARP_NVB_Classification_Head(head_dim, num_classes, [8], nn.SiLU())
        
        self.multi_class = num_classes > 1
        
        if not self.multi_class: # only one class
            self.loss = nn.BCEWithLogitsLoss()
        else:
            self.loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
            
        self.base_lr = lr
        self.weight_decay = weight_decay
        self.total_epochs = epochs
        self.warmup_epochs = warmup_epochs
        
        type_metric = "binary" if not self.multi_class else "multiclass"
        
        self.train_acc = torchmetrics.Accuracy(type_metric)
        self.val_acc = torchmetrics.Accuracy(type_metric)
        self.val_video_acc = torchmetrics.Accuracy(type_metric)
        self.test_acc = torchmetrics.Accuracy(type_metric)
        self.f1ScoreTest = torchmetrics.F1Score(type_metric)
        
        self.val_outputs = [] 
        
    def forward(self, data:torch.Tensor, mask:torch.Tensor):
        
        assert len(data.shape) == 5, "Expeted 5-D tensor in [B, L, C, H, W] format"
        
        B, L, C, H, W = data.shape
        data = data.view(B * L, C, H, W)            # Flaten the video (Big Batch)
        
        cnn_features = self.cnn(data)               # [B*L, F]
        cnn_features = cnn_features.view(B, L, -1)  # separate B from L [B, L, F]
        
        time_features = self.temporal_head(cnn_features, mask)
        
        video_features, self.last_attn = self.pool_att(time_features, mask)
        
        # ---* Mask pooling *---        
        #mask_float = mask.float().unsqueeze(-1)
        #time_features *= mask_float
        #denom = mask_float.sum(1).clamp_min(1.0)
        #video_features = time_features.sum(1) / denom
        
        video_features = self.pool(video_features)
        
        logits = self.classifier(video_features)
        
        return logits
    
    def _attention_entropy(self, mask):
        eps = 1e-8
        attn = self.last_attn * mask
        log_attn = (attn + eps).log()
        ent = -(attn * log_attn).sum(dim=1)
        
        lens = mask.sum(dim=1).clamp_min(1)
        ent = ent / lens
        
        return ent
    
    def _lr_lambda(self, epoch):
        # linear warmup followed by cosine decay
        if epoch < self.warmup_epochs:
            return float(epoch + 1) / max(1, self.warmup_epochs)

        progress = (epoch - self.warmup_epochs) / max(1, (self.total_epochs - self.warmup_epochs))
        return 0.5 * (1 + math.cos(math.pi * progress))
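
    # Worked example of the schedule above with warmup_epochs=3, epochs=50
    # (values are the multiplier applied to the base LR):
    #   epoch 0 -> 1/3, epoch 1 -> 2/3, epoch 2 -> 1.0     (linear warmup)
    #   epoch 3 -> 1.0, ..., epoch 49 -> ~0.001            (cosine decay toward 0)
    # Wire it up via the commented-out LambdaLR block in configure_optimizers.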
        
    def _shared_step(self, batch, val_step:bool):
        
        match(len(batch)):
            case 4:
                data, label, mask, _ = batch
            case 3:
                data, label, mask = batch
            case _:
                raise ValueError("Batch must be (x,y,mask[,meta])")
            
        logits = self(data, mask)
        
        if not self.multi_class:
            label = label.float()
            logits = logits.flatten()
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, dim=-1)
            
        attn_loss = 0
        if self.attn_reg_weight > 0:
            attn_loss = self._attention_entropy(mask)
            attn_loss = attn_loss.mean()
            attn_loss *= self.attn_reg_weight
                
        loss = self.loss(logits, label) + attn_loss
        
        return loss, label, logits, [attn_loss]
    
    def on_after_backward(self):
        # Computes the global L2 norm of all gradients
        total_norm = torch.zeros((), device=self.device)

        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.norm(2)          # L2 norm of this tensor
                total_norm += param_norm.pow(2)      # accumulate squares

        total_norm = total_norm.sqrt()               # take the sqrt at the end

        # Log gradient norm
        self.log("grad_norm", total_norm)

        # Optional vanishing gradient warning
        if total_norm < 1e-8:
            self.log("grad_warning", 1.0) 
       
    def on_train_epoch_start(self):
        lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log("lr", lr)    
        
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, False)
        
        self.log("train_loss", loss, on_epoch=True)
        self.log("train_attn_loss", extra_losses[0], on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, True)
        
        self.log("val_wind_loss", loss, on_epoch=True, on_step=False)
        self.log("val_win_attn_loss", extra_losses[0], on_epoch=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_wind_acc", self.val_acc, on_epoch=True, on_step=False)
            
        video_idx = batch[3]["video_idx"]  # validation batches must include the meta dict
        
        self.val_outputs.append({
            "logits": predicted_labels.detach().cpu(),
            "targets": true_labels.detach().cpu(),
            "video_idx": video_idx.detach().cpu()
        }) 
        
    def on_validation_epoch_end(self):
        all_logits = torch.cat([o["logits"] for o in self.val_outputs], dim=0)
        all_targets = torch.cat([o["targets"] for o in self.val_outputs], dim=0)
        all_vids = torch.cat([o["video_idx"] for o in self.val_outputs], dim=0)
        
        self.val_outputs.clear()
        
        video_logits = defaultdict(list)
        video_targets = {}
        
        for l, t, v in zip(all_logits, all_targets, all_vids):
            v = int(v.item())
            video_logits[v].append(l)
            video_targets[v] = t
            
        agg_logits = []
        agg_targets = []
        
        for v, parts in video_logits.items():
            avg_logit = torch.stack(parts).mean()
            agg_logits.append(avg_logit)
            agg_targets.append(video_targets[v])
            
        agg_logits = torch.stack(agg_logits)
        agg_targets = torch.stack(agg_targets)
        
        # NOTE: this aggregation assumes binary classification (one scalar logit per window)
        video_loss = self.loss(agg_logits, agg_targets.float())
        self.log("val_video_loss", video_loss, on_epoch=True, on_step=False, prog_bar=True)
        
        self.val_video_acc.update(agg_logits, agg_targets)
        self.log("val_video_acc", self.val_video_acc, on_epoch=True, on_step=False)
        
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels, _ = self._shared_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_win_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_win_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
        
    def configure_optimizers(self):
        trainable_params = [p for p in self.parameters() if p.requires_grad]

        opt = torch.optim.AdamW(
            trainable_params,
            lr=self.base_lr,
            weight_decay=self.weight_decay,
        )

        return opt
        #sch = torch.optim.lr_scheduler.LambdaLR(opt, self._lr_lambda)
        
        #return {
        #    "optimizer": opt,
        #    "lr_scheduler": {
        #        "scheduler": sch,
        #        "interval": "epoch",
        #        "frequency": 1,
        #    },
        #}

class VideoProjector(nn.Module):
    def __init__(self, in_dim, out_dim=256):
        super().__init__()
        mid = max(256, in_dim // 4)
        self.net = nn.Sequential(
            nn.Linear(in_dim, mid), 
            nn.ReLU(), 
            nn.Linear(mid, out_dim)
        )
    def forward(self, x):
        return self.net(x) 
    
class FiLM(nn.Module):
    def __init__(self, dim_img, b_g_size=256):
        super().__init__()
        
        self.mlp = nn.Sequential(
            nn.Linear(dim_img, 2*b_g_size), 
            nn.ReLU(), 
            nn.Linear(2*b_g_size, 2*b_g_size)
        )
        
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.zeros_(self.mlp[-1].bias)
        
        self.C = b_g_size
        
    def forward(self, x, i):  # x: [B, L, C], i: [B, dim_img]
        gb = self.mlp(i)
        gamma, beta = gb.chunk(2, -1)  # extract gamma and beta from the MLP output

        gamma = 1 + 0.1 * torch.tanh(gamma)  # constrain gamma to the [0.9, 1.1] range
        beta = 0.1 * torch.tanh(beta)        # constrain beta to the [-0.1, 0.1] range

        return gamma.unsqueeze(1) * x + beta.unsqueeze(1)  # FiLM modulation
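
# FiLM shape sketch (hypothetical sizes): the key-frame embedding conditions
# every time step of the video sequence through one (gamma, beta) pair per
# batch element, broadcast over L:
#
#   film = FiLM(dim_img=1024, b_g_size=256)
#   x = torch.randn(2, 16, 256)        # [B, L, C] video features
#   i = torch.randn(2, 1024)           # [B, dim_img] key-frame features
#   y = film(x, i)                     # -> [2, 16, 256]
#
# With the zero-initialized last layer, gamma starts at 1 and beta at 0, so
# FiLM is an identity map at the start of training.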
        
class KD_BCE_Loss(nn.Module):
    def __init__(self, lambda_kd:float, t:float):
        super().__init__()
        
        self.Lambda = lambda_kd
        self.T = t
        
    def forward(self, z_video, z_key):
        
        z_video = z_video.float()
        z_key   = z_key.float()
        
        # teacher probabilities (no gradient)
        with torch.no_grad():
            p_key = torch.sigmoid(z_key / self.T)    # [B] in [0,1]

        # stable BCE with logits; equivalent to BCE(sigmoid(z_video/T), p_key)
        loss = nn.functional.binary_cross_entropy_with_logits(z_video / self.T, p_key)
        
        return self.Lambda * (self.T ** 2) * loss
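
# The distillation objective above is, in math form (sigma = sigmoid):
#
#   L_KD = lambda * T^2 * BCE( sigma(z_video / T), sigma(z_key / T) )
#
# where the teacher term sigma(z_key / T) is detached. The T^2 factor keeps the
# gradient magnitude roughly constant across temperatures, as in standard logit
# distillation (Hinton et al., 2015).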
   
class RARP_NVB_Multi_MOD(RARP_NVB_Wind_video):
    
    def _unfreeze_last_n_layers(self, model:nn.Module):
        if self.pre_train:
            for p in model.parameters():
                p.requires_grad = False

            # unfreeze last n
            for name, module in model.named_children():
                if name in self.layers_to_unfreeze:
                    for p in module.parameters():
                        p.requires_grad = True
    
    def __init__(
        self, 
        num_classes, 
        temporal = "gru", 
        cnn_name = "resnet18", 
        dropout = 0.2, 
        pre_train = False, 
        lr = 0.0003, 
        weight_decay = 0.05, 
        epochs = 50, 
        warmup_epochs = 3, 
        label_smoothing = 0, 
        frizze_cnn = True,
        Hybrid_TS_weights:str = ""
    ):
        super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn)
        
        if Hybrid_TS_weights is not None:
            assert len(Hybrid_TS_weights) > 0, "The key-frame model requires pre-trained weights"
            
            if Hybrid_TS_weights == "train":
                self.Hybrid_TS = RARP_NVB_DINO_MultiTask(
                    TypeLossFunction.BCEWithLogits,
                    std=[40.63141752, 44.26910074, 50.29294373],
                    mean=[30.38144216, 42.03988769, 97.8896116],
                    L1= 1.31E-04,
                    L2= 0,
                    SoftAdptAlgo=0,
                )
            else:
                self.Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(Path(Hybrid_TS_weights), map_location=self.device)
                self.Hybrid_TS.eval()
                for p in self.Hybrid_TS.parameters():
                    p.requires_grad = False 
        else:
            self.Hybrid_TS = None       
        
        self._unfreeze_last_n_layers(self.cnn)
        
        self.img_feature_dim = 1024 #this comes from the Hybrid TS model hyperparameters 
        self.mid_dim = 256
        self.proy_video = VideoProjector(self.video_feature_dim, self.mid_dim)
        self.film = FiLM(self.img_feature_dim, self.mid_dim)
        
        self.pool = nn.Sequential(
            nn.LayerNorm(self.mid_dim),
            nn.Dropout(dropout),
        )
        
        self.classifier = RARP_NVB_Classification_Head(self.mid_dim, num_classes, [], nn.SiLU())
        
        self.kd_lambda = 0.4  # weight of soft distillation loss
        self.kd_T = 2.0       # temperature for logit distillation
        
        self.attn_reg_weight = 0
        self.pool_att = None if self.attn_reg_weight <= 0 else TemporalAttentionPool(self.mid_dim, 128)
        
        self.kd_loss = KD_BCE_Loss(self.kd_lambda, self.kd_T)
        
    def _mask_pooling(self, time_features:torch.Tensor, mask:torch.Tensor):
        mask_float = mask.float().unsqueeze(-1)
        time_features = time_features * mask_float  # zero out padded frames (out-of-place, autograd-safe)
        denom = mask_float.sum(1).clamp_min(1.0)
        video_features = time_features.sum(1) / denom

        return video_features
        
    def forward(self, data:torch.Tensor, key_frame:torch.Tensor, mask:torch.Tensor):
        
        assert len(data.shape) == 5, "Expeted 5-D tensor in [B, L, C, H, W] format"
        
        B, L, C, H, W = data.shape
        data = data.contiguous()
        data = data.view(B * L, C, H, W)            # Flaten the video (Big Batch)
        
        cnn_features = self.cnn(data)               # [B*L, F]
        cnn_features = cnn_features.view(B, L, -1)  # separate B from L [B, L, F]
        
        time_features = self.temporal_head(cnn_features, mask)
        
        # --- FiLM ---
        h_mid = self.proy_video(time_features)
        
        with torch.no_grad():
            pred_key_frame, _, _ = self.Hybrid_TS(key_frame)
            img_features = torch.cat((self.Hybrid_TS.last_conv_output_S, self.Hybrid_TS.last_conv_output_T), dim=1)
            img_features = nn.functional.adaptive_avg_pool2d(img_features, (1,1)).flatten(1) 
            
        h_film = self.film(h_mid, img_features)   
        
        #video_features, self.last_attn = self.pool_att(time_features, mask)
        
        # ---* Mask pooling *---        
        video_features = self._mask_pooling(h_film, mask)
        
        video_features = self.pool(video_features)
        
        logits = self.classifier(video_features)
        
        return logits, pred_key_frame
    
    def _shared_step(self, batch, val_step:bool):

        data, label, mask, _, key_frame = batch

        logits, key_frame_logits = self(data, key_frame, mask)

        if not self.multi_class:
            label = label.float()
            logits = logits.flatten()
            pred = torch.sigmoid(logits)
        else:
            pred = torch.softmax(logits, dim=-1)

        # the attention-entropy regularizer is disabled in this class
        # (attn_reg_weight = 0); knowledge distillation from the key-frame
        # expert takes its place in the auxiliary-loss slot
        kd_loss = self.kd_loss(logits, key_frame_logits.flatten())

        loss = self.loss(logits, label) + kd_loss

        return loss, label, logits, [kd_loss]
    
    def configure_optimizers(self):
        cnn_params = []
        film_params = []
        encoder_params = []
        head_params = []

        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if "cnn" in name:
                cnn_params.append(p)
            elif "film" in name:
                film_params.append(p)
            elif "temporal_head" in name:
                encoder_params.append(p)
            elif "classifier" in name:
                head_params.append(p)
            else:
                head_params.append(p)

        optimizer = torch.optim.AdamW([
            {"params": cnn_params, "lr": self.base_lr * 0.1},      # slow LR
            {"params": film_params, "lr": self.base_lr},           # main LR
            {"params": encoder_params, "lr": self.base_lr},        # main LR
            {"params": head_params, "lr": self.base_lr * 1.5},     # faster LR
        ])

        return optimizer  

class WindowAttentionMIL(nn.Module):
    def __init__(self, dim, att_dim=128):
        super().__init__()
        self.att_v = nn.Linear(dim, att_dim)
        self.att_u = nn.Linear(att_dim, 1)
        
    def forward(self, H):
        A = torch.tanh(self.att_v(H))
        logits = self.att_u(A)
                    
        alpha = torch.softmax(logits, dim=1)
        v = (alpha * H).sum(dim=1)
        return v, alpha
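
# Shape sketch for WindowAttentionMIL (hypothetical sizes): a bag of N window
# embeddings is reduced to one video embedding plus per-window weights, in the
# spirit of attention-based MIL (Ilse et al., 2018):
#
#   mil = WindowAttentionMIL(dim=256)
#   H = torch.randn(2, 8, 256)        # [B, N, D] window embeddings
#   v, alpha = mil(H)                 # v: [2, 256], alpha: [2, 8, 1]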

class AttentionEntropyRangeLoss(nn.Module):
    def __init__(self, target_entropy:float, eps:float = 1e-8):
        super().__init__()
        
        self.H_0 = target_entropy
        self.eps = eps
        
    def forward(self, alpha:torch.Tensor)->torch.Tensor:
        if alpha.dim() == 3:
            alpha = alpha.squeeze(-1)

        alpha = alpha.clamp(min=self.eps)  # clamp is not in-place; keep the result

        H = -(alpha * alpha.log()).sum(dim=1)
        W = alpha.shape[1]

        H_norm = H / math.log(W)  # normalize by the maximum entropy log(W)

        loss = (H_norm - self.H_0) ** 2

        return loss.mean()
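
# Worked example for AttentionEntropyRangeLoss: with W windows, the normalized
# entropy H_norm = H / log(W) is 1.0 for uniform attention and ~0.0 for a
# one-hot distribution. With target_entropy = 0.40 the loss therefore pushes
# the attention away from both extremes:
#
#   uniform alpha -> (1.0 - 0.4)^2 = 0.36
#   one-hot alpha -> (0.0 - 0.4)^2 = 0.16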
        
class RARP_NVB_Multi_MOD_MIL(RARP_NVB_Multi_MOD):
    def __init__(
        self, 
        num_classes, 
        temporal="gru",
        cnn_name="resnet18",
        dropout=0.2, 
        pre_train=False, 
        lr=0.0003, 
        weight_decay=0.05, 
        epochs=50, 
        warmup_epochs=3, 
        label_smoothing=0, 
        frizze_cnn=True, 
        Hybrid_TS_weights = "",
        attn_reg_weight:float=0.02, 
        attn_entropy_target:float=0.40, 
        attn_reg_warmup_epochs:int=5,
        FOLD: Optional[int] = None
    ):
        super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn, Hybrid_TS_weights)
        
        self.win_mil_att = WindowAttentionMIL(self.mid_dim, att_dim=128)
        self.win_pool = nn.Sequential(
            nn.LayerNorm(self.mid_dim),
            nn.Dropout(dropout),
        )
        
        self.attn_reg_weight = attn_reg_weight
        self.attn_reg_warmup_epochs = attn_reg_warmup_epochs
        self.attn_loss = AttentionEntropyRangeLoss(attn_entropy_target)
    
    def _attn_reg_lambda(self) -> float:
        # Linear warmup from 0 to attn_reg_weight over attn_reg_warmup_epochs
        if self.attn_reg_warmup_epochs <= 0:
            return float(self.attn_reg_weight)
        t = min(1.0, self.current_epoch / self.attn_reg_warmup_epochs)
        return float(self.attn_reg_weight * t)
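
    # Example with the defaults (attn_reg_weight=0.02, attn_reg_warmup_epochs=5):
    #   epoch 0 -> 0.000, epoch 1 -> 0.004, ..., epoch 5 and later -> 0.020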
        
    def _frame_wise_pass(self, data:torch.Tensor, key_frame:torch.Tensor, mask:torch.Tensor):
        
        B_N, L, C, H, W = data.shape
        
        data = data.contiguous()                        # ensure view() is valid
        data = data.view(B_N * L, C, H, W)              # flatten all windows into one big batch
        
        B = key_frame.shape[0]
        n_win = B_N // B
        
        cnn_features = self.cnn(data)                   # [B_N*L, F]
        cnn_features = cnn_features.view(B_N, L, -1)    # [B_N, L, F]
        
        time_features = self.temporal_head(cnn_features, mask)
        
        # --- FiLM ---
        h_mid = self.proy_video(time_features)
        
        if self.Hybrid_TS is not None:
            with torch.no_grad():
                pred_key_frame, _, _ = self.Hybrid_TS(key_frame)
                img_features = torch.cat((self.Hybrid_TS.last_conv_output_S, self.Hybrid_TS.last_conv_output_T), dim=1)
                img_features = nn.functional.adaptive_avg_pool2d(img_features, (1,1)).flatten(1) 
        else:
            pred_key_frame = None
            img_features = key_frame
        
        img_features = (
            img_features
            .unsqueeze(1)                        # [B, 1, F_img]
            .expand(B, n_win, img_features.size(1))  # [B, n_win, F_img]
            .contiguous()
            .view(B_N, -1)                       # [B_N, F_img]
        )    
        
        h_film = self.film(h_mid, img_features)
        
        # --- Mask Pooling ---
        video_features = self._mask_pooling(h_film, mask)
        
        video_features = self.pool(video_features)
        
        return video_features, pred_key_frame 
        
    def forward(self, data:torch.Tensor, key_frame:torch.Tensor, mask:torch.Tensor):
        B, N, L, C, H, W = data.shape
        BM, NM, V = mask.shape
        
        data = data.contiguous()
        mask = mask.contiguous()
        data = data.view(B*N, L, C, H, W)  # flatten the bags of windows
        mask = mask.view(BM*NM, V)
        
        video_features, pred_key_frame = self._frame_wise_pass(data, key_frame, mask) #[B*N, D]
        
        # --- Window-wise pass ---
        video_features = video_features.view(B, N, -1)
        vid_emb, alpha = self.win_mil_att(video_features)
        vid_emb = self.win_pool(vid_emb)
        
        logits = self.classifier(vid_emb)
        
        return logits, pred_key_frame, alpha
    
    def _shared_step(self, batch, val_step:bool=False):

        match (len(batch)):
            case 5:
                data, label, mask, key_frame, meta = batch
                soft_label = None
            case 6:
                data, label, mask, key_frame, soft_label, meta = batch
            case _:
                raise NotImplementedError()

        logits, key_frame_logits, alpha_w = self(data, key_frame, mask)

        # fall back to precomputed soft labels when no key-frame expert is attached
        key_frame_logits = soft_label if key_frame_logits is None else key_frame_logits
        
        label = label.float()
        logits = logits.flatten()
        
        soft_loss = self.kd_loss(logits, key_frame_logits.flatten())
        hard_loss = self.loss(logits, label)
        attn_win_loss = self._attn_reg_lambda() * self.attn_loss(alpha_w)
        
        total_loss = hard_loss + soft_loss + attn_win_loss
        
        return total_loss, label, logits, [soft_loss, alpha_w, meta["case_id"], attn_win_loss]
    
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, False)
        
        self.log("train_loss", loss, on_epoch=True)
        self.log("train_soft_loss", extra_losses[0], on_epoch=True)
        self.log("train_attn_loss", extra_losses[3], on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, True)
        
        val_main_loss = loss - extra_losses[3]  # remove attention regularizer
        
        self.log("val_loss", val_main_loss, on_epoch=True, on_step=False)
        self.log("val_soft_loss", extra_losses[0], on_epoch=True)
        self.log("val_attn_loss", extra_losses[3], on_epoch=True)
        self.log("val_total_loss", loss, on_epoch=True, on_step=False)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False)
        
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels, _ = self._shared_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    
    def on_validation_epoch_end(self):
        pass
    
class RARP_NVB_M3_TwoBranchModel(RARP_NVB_Multi_MOD_MIL):
    def __init__(
        self,
        num_classes,
        temporal="gru",
        cnn_name="resnet18",
        dropout=0.2,
        pre_train=False,
        lr=0.0003,
        weight_decay=0.05,
        epochs=50,
        warmup_epochs=3,
        label_smoothing=0,
        frizze_cnn=True,
        Hybrid_TS_weights="",
        attn_reg_weight=0.02,
        attn_entropy_target=0.4,
        attn_reg_warmup_epochs=5,
        FOLD=None,
        GPU_list: list = []
    ):
        super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn, Hybrid_TS_weights, attn_reg_weight, attn_entropy_target, attn_reg_warmup_epochs, FOLD)
        
        self.Expert_Branch_device = torch.device(GPU_list[0])
        self.Video_Branch_device = torch.device(GPU_list[1])
        
        self.automatic_optimization = False
        
        self.Hybrid_TS.to(self.Expert_Branch_device)
        
    def on_fit_start(self):
        self.Hybrid_TS.to(self.Expert_Branch_device)
        
    def configure_optimizers(self):
        
        cnn_params = []
        film_params = []
        encoder_params = []
        head_params = []

        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            
            if "cnn" in name:
                cnn_params.append(p)
            elif "film" in name:
                film_params.append(p)
            elif "temporal_head" in name:
                encoder_params.append(p)
            elif "classifier" in name:
                head_params.append(p)
            elif "Hybrid_TS" in name:
                continue
            else:
                head_params.append(p)
        
        vid_optimizer = torch.optim.AdamW([
            {"params": cnn_params, "lr": self.base_lr * 0.1},      # slow LR
            {"params": film_params, "lr": self.base_lr},           # main LR
            {"params": encoder_params, "lr": self.base_lr},        # main LR
            {"params": head_params, "lr": self.base_lr * 1.5},     # faster LR
        ])
        
        img_optimizer = torch.optim.AdamW(self.Hybrid_TS.parameters(), lr=self.Hybrid_TS.lr)  #, weight_decay=self.Lambda_L2
        
        return [img_optimizer, vid_optimizer]
    
    def training_step(self, batch, batch_idx):
        opt_img, opt_vid = self.optimizers()
        
        match (len(batch)):
            case 5:        
                data, label, mask, key_frame, meta = batch
            case 6:
                data, label, mask, key_frame, soft_label, meta = batch
            case _:
                raise NotImplementedError()
            
        img_label = label.detach().clone()  # torch tensors use clone(), not copy()
        img_label = img_label.to(self.Expert_Branch_device)
        key_frame = [k.to(self.Expert_Branch_device, non_blocking=True) for k in key_frame]
        
        img_pred, img_TS_DINO, img_recons = self.Hybrid_TS(key_frame, val_step=False)
        S_Dino, T_Dino = img_TS_DINO
        
        img_pred = img_pred.flatten()
        original_img = torch.cat([key_frame[0].float() for _ in range(len(T_Dino))], dim=0)
        
        Img_loss_Dino = self.Hybrid_TS.lossFN_DINO(S_Dino, T_Dino)
        Img_loss_HL = self.Hybrid_TS.lossFN(img_pred, img_label)
        Img_loss_img = self.Hybrid_TS.ReconstructionLoss(img_recons, original_img)
        
        if self.Hybrid_TS.Lambda_L1 is not None:
            Img_loss_HL += self.Hybrid_TS._calc_L1(self.Hybrid_TS.clasiffier.parameters())
            
        self.Hybrid_TS.loss_history["loss_DINO"].append(Img_loss_Dino.item())
        self.Hybrid_TS.loss_history["loss_Reconstruction"].append(Img_loss_img.item())
        self.Hybrid_TS.loss_history["loss_Binary"].append(Img_loss_HL.item())
        
        Img_Expert_loss = self.Hybrid_TS.weights[0] * Img_loss_Dino + self.Hybrid_TS.weights[1] * Img_loss_img + self.Hybrid_TS.weights[2] * Img_loss_HL
        
        opt_img.zero_grad(set_to_none=True)
        self.manual_backward(Img_Expert_loss)
        opt_img.step()
        
        img_features = torch.cat((self.Hybrid_TS.last_conv_output_S.detach(), self.Hybrid_TS.last_conv_output_T.detach()), dim=1)
        
        ###  TO-DO
        
        
        
        
        
class RARP_NVB_Multi_MOD_MIL_TESTMode(RARP_NVB_Multi_MOD_MIL):
    def __init__(self, num_classes, temporal="gru", cnn_name="resnet18", dropout=0.2, pre_train=False, lr=0.0003, weight_decay=0.05, epochs=50, warmup_epochs=3, label_smoothing=0, frizze_cnn=True, Hybrid_TS_weights="", attn_reg_weight = 0.02, attn_entropy_target = 0.4, attn_reg_warmup_epochs = 5, FOLD = None):
        super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn, Hybrid_TS_weights, attn_reg_weight, attn_entropy_target, attn_reg_warmup_epochs, FOLD)
        
        self.FOLD = FOLD
        self.Predictions = []
        self.Labels = []
        self._test_results = None
        self.loaded_ckpt_epoch = None
        
        self.test_records = []
    
    def on_load_checkpoint(self, checkpoint: dict):
        self.loaded_ckpt_epoch = checkpoint.get("epoch", None)
        
    def on_test_epoch_start(self):
        self.Predictions = []
        self.Labels = []
        self._test_results = None
        self.test_records = []
        
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels, extra = self._shared_step(batch, True)
        
        probs = torch.sigmoid(predicted_labels)
        B = probs.shape[0]
        
        self.Predictions.append(probs)
        self.Labels.append(true_labels)
        
        for b in range(B):
            rec = {
                "case_id": extra[2][b].item(),
                "y_true": int(true_labels[b].item()),
                "y_pred": (probs[b] > 0.5).int().item(),
                "prob": float(probs[b].item()),
                "alpha": extra[1][b].flatten().cpu().numpy()
            }
            
            self.test_records.append(rec)
        
    def on_test_epoch_end(self):
        out_dir = Path(self.trainer.default_root_dir) / f"test_reports/FOLD_{self.FOLD}"
        out_dir.mkdir(parents=True, exist_ok=True)
        
        rows = []
        for r in self.test_records:
            row = {
                "case_id": r["case_id"],
                "y_true": r["y_true"],
                "y_pred": r["y_pred"],
                "prob": r["prob"],
            }
            # store attention weights as separate columns alpha_0...alpha_{W-1}
            alpha = r["alpha"]
            for i, a in enumerate(alpha):
                row[f"alpha_{i:02d}"] = float(a)
            rows.append(row)
            
        df = pd.DataFrame(rows)
        df.to_csv(out_dir / f"mil_test_predictions_epoch{self.loaded_ckpt_epoch}.csv", index=False)
        
        predictions = torch.cat(self.Predictions)
        labels = torch.cat(self.Labels).int()
        
        device = self.device
        
        acc = torchmetrics.Accuracy('binary').to(device)(predictions, labels)
        precision = torchmetrics.Precision('binary').to(device)(predictions, labels)
        recall = torchmetrics.Recall('binary').to(device)(predictions, labels)
        auc = torchmetrics.AUROC('binary').to(device)(predictions, labels)
        f1Score = torchmetrics.F1Score('binary').to(device)(predictions, labels)
        specificity = torchmetrics.Specificity("binary").to(device)(predictions, labels)

        table = [
            ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificity.item():.4f}", ""]
        ]

        for i in range(2):
            aucCurve = torchmetrics.ROC("binary").to(device)
            fpr, tpr, thresholds = aucCurve(predictions, labels)
            index = torch.argmax(tpr - fpr)  # threshold maximizing Youden's J on the ROC curve
            th2 = (recall + specificity - 1).item()  # Youden's J at the 0.5 threshold, reused as a threshold
            th2 = 0.5 if th2 <= 0 else th2
            th1 = thresholds[index].item() if i == 0 else th2
            accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(predictions, labels)
            precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(predictions, labels)
            recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(predictions, labels)
            specificityY = torchmetrics.Specificity("binary", threshold=th1).to(device)(predictions, labels)
            f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(predictions, labels)
            #cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
            #cm2.update(Predictions, Labels)
            #_, ax = cm2.plot()
            #ax.set_title(f"NVB Classifier (th={th1:.4f})")
            table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specificityY.item():.4f}", self.loaded_ckpt_epoch])
            
        self._test_results = table