# RARP / Models.py — model, loss, and scheduler definitions.
# (Web-view header residue removed; last update 2025-12-19 by delAguila.)
import math
from typing import Any, Union
import torch
import torch.utils.checkpoint as torch_ckp
import torchvision
import torchmetrics
import torchmetrics.classification
import lightning as L
import lightning.pytorch.callbacks as callbk
from enum import Enum
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

import timm
import van
import numpy as np
from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt
from noah import NOAH
import piq


def js_divergence_sigmoid(p_logits, q_logits, eps=1e-12):
    """Element-wise Jensen-Shannon divergence between two Bernoulli distributions.

    Both inputs are raw logits; each element is treated as an independent
    Bernoulli variable via sigmoid. Returns a tensor of the same shape with
    per-element JS divergences in nats (bounded by ln 2).

    BUG FIX: the previous implementation returned
    0.5 * (BCE(m, p) + BCE(m, q)) = 0.5 * (H(p, m) + H(q, m)), i.e. the raw
    cross-entropies without subtracting the entropies H(p) and H(q). That
    quantity is nonzero even when p == q, so it was not the JS divergence.
    JS(p, q) = 0.5 * (KL(p || m) + KL(q || m)) with m = (p + q) / 2.

    Args:
        p_logits: logits of the first distribution.
        q_logits: logits of the second distribution (same shape as p_logits).
        eps: clamp added inside the logs for numerical stability at 0/1.
    """
    p_probs = torch.sigmoid(p_logits)
    q_probs = torch.sigmoid(q_logits)

    m = 0.5 * (p_probs + q_probs)

    def _kl_bernoulli(a, b):
        # KL(Bern(a) || Bern(b)), stabilized with eps.
        return a * torch.log((a + eps) / (b + eps)) + (1 - a) * torch.log((1 - a + eps) / (1 - b + eps))

    return 0.5 * (_kl_bernoulli(p_probs, m) + _kl_bernoulli(q_probs, m))

def getNVL(obj: dict, key, default):
    """Return obj[key] unless the key is missing or mapped to None, else default."""
    value = obj.get(key)
    return default if value is None else value

class Decoder (torch.nn.Module):
    """Convolutional decoder that upsamples a feature map into an image tensor.

    Structure: one stem conv block at the input resolution, then one
    (ConvTranspose 4x4/s2 + conv 3x3) pair per entry of `hidden_channels`,
    and a final pair mapping to `output_channels`. Every conv is followed by
    BatchNorm + ReLU, so the spatial size grows by 2**(num_blocks + 1).
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=[1024, 512, 256, 64], *args, **kwargs):
        super().__init__(*args, **kwargs)

        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."

        self.Activation = torch.nn.ReLU
        self.input_channels = input_channels

        def conv_bn_act(ch_in, ch_out):
            # 3x3 same-resolution refinement: conv -> BN -> activation.
            return [
                torch.nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=False),
                torch.nn.BatchNorm2d(ch_out),
                self.Activation(),
            ]

        def up_bn_act(ch_in, ch_out):
            # 2x upsampling step: transposed conv -> BN -> activation.
            return [
                torch.nn.ConvTranspose2d(ch_in, ch_out, kernel_size=4, stride=2, padding=1, bias=False),
                torch.nn.BatchNorm2d(ch_out),
                self.Activation(),
            ]

        # NOTE: layers are created in the same order as before so that random
        # weight initialization is reproducible under a fixed seed.
        layers = conv_bn_act(input_channels, input_channels)

        previous = input_channels
        for width in hidden_channels:
            layers += up_bn_act(previous, width)
            layers += conv_bn_act(width, width)
            previous = width

        layers += up_bn_act(previous, output_channels)
        layers += conv_bn_act(output_channels, output_channels)

        self.decoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run the full decoder stack on feature map `x`."""
        return self.decoder(x)
    
class DynamicDecoder(torch.nn.Module):
    """Stack of 2x-upsampling transposed-conv blocks with optional dropout.

    Each hidden block doubles the spatial size (ConvTranspose 3x3/s2 +
    BatchNorm + ReLU [+ Dropout]); a final transposed conv maps to
    `output_channels` with one more 2x upsampling and no activation.
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=[1024, 512, 256, 64], drop_out: float = None):
        super(DynamicDecoder, self).__init__()

        # The per-block channel list must describe exactly num_blocks blocks.
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."

        modules = []
        previous = input_channels

        for width in hidden_channels:
            modules.extend([
                torch.nn.ConvTranspose2d(previous, width, kernel_size=3, stride=2, padding=1, output_padding=1),
                torch.nn.BatchNorm2d(width),
                torch.nn.ReLU(inplace=True),
            ])
            if drop_out is not None:
                modules.append(torch.nn.Dropout(drop_out))
            previous = width

        # Output head: raw values (no sigmoid) so callers choose the range.
        modules.append(torch.nn.ConvTranspose2d(previous, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1))

        self.decoder = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Decode feature map `x` through the sequential stack."""
        return self.decoder(x)
    
class DecoderUnet(torch.nn.Module):
    """U-Net-style decoder head over a tuple of encoder feature maps.

    `forward` expects (skip0, skip1, skip2, skip3, bottleneck). Each skip is
    first upsampled 2x (so it matches the decoder resolution at that stage),
    concatenated with the upsampled decoder state, and refined by a conv
    block. Dropout2d(0.4) follows the first three stages only.
    """

    def __init__(self, input_channels=2048, output_channels=3):
        super().__init__()

        self.dropout = torch.nn.Dropout2d(0.4)

        self.upConv_0 = torch.nn.ConvTranspose2d(input_channels, 512, kernel_size=2, stride=2)
        self.decoder_0 = self._conv_block(1024, 512)

        self.upConv_1 = torch.nn.ConvTranspose2d(512, 320, kernel_size=2, stride=2)
        self.decoder_1 = self._conv_block(640, 320)

        self.upConv_2 = torch.nn.ConvTranspose2d(320, 128, kernel_size=2, stride=2)
        self.decoder_2 = self._conv_block(256, 128)

        self.upConv_3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder_3 = self._conv_block(128, 64)

        self.upConv_4 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder_4 = self._conv_block(32, 16)

        # 2x upsamplers that bring each skip connection to decoder resolution.
        self.enc3_upsample = torch.nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.enc2_upsample = torch.nn.ConvTranspose2d(320, 320, kernel_size=2, stride=2)
        self.enc1_upsample = torch.nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.enc0_upsample = torch.nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)

        self.last_conv = torch.nn.Conv2d(16, output_channels, kernel_size=1)

    def _conv_block(self, in_ch, out_ch):
        """Two 3x3 convolutions, each followed by SiLU."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        )

    def forward(self, x):
        skip0, skip1, skip2, skip3, bottleneck = x

        # One tuple per stage: (decoder upsampler, skip upsampler, refiner, skip).
        stages = (
            (self.upConv_0, self.enc3_upsample, self.decoder_0, skip3),
            (self.upConv_1, self.enc2_upsample, self.decoder_1, skip2),
            (self.upConv_2, self.enc1_upsample, self.decoder_2, skip1),
            (self.upConv_3, self.enc0_upsample, self.decoder_3, skip0),
        )

        out = bottleneck
        for stage_idx, (up, skip_up, refine, skip) in enumerate(stages):
            out = refine(torch.cat((up(out), skip_up(skip)), dim=1))
            # Dropout after every stage except the last, as in the original.
            if stage_idx < len(stages) - 1:
                out = self.dropout(out)

        out = self.decoder_4(self.upConv_4(out))
        return self.last_conv(out)

class ModelsList(Enum):
    """Backbone identifiers used to select a classifier architecture.

    NOTE(review): "RestNet50" and "Droput" appear to be typos of "ResNet50"
    and "Dropout"; kept unchanged because renaming enum members would break
    existing references and serialized configs.
    """
    RestNet50 = 1
    DenseNet169 = 2
    Efficientnet_b0 = 3
    MobileNetV2 = 4
    Inception3 = 5
    ResNeXt_50_32x4d = 6
    RestNet50_Droput = 7

class TypeLossFunction(Enum):
    """Loss-function selector passed to the training modules below."""
    CrossEntropy = 0
    BCEWithLogits = 1
    HingeLoss = 2
    FocalLoss = 3
    ContrastiveLoss = 4
    
class ReconstructionLoss(torch.nn.Module):
    """L1 reconstruction loss with an optional total-variation smoothness term.

    loss = l1_weight * mean|pred - target| + tv_weight * TV(pred),
    where TV is the mean absolute difference of vertically and horizontally
    adjacent elements of `pred`. The TV term is skipped when tv_weight <= 0.
    """

    def __init__(self, l1_weight=1.0, tv_weight=1e-5):
        super().__init__()

        self.l1w = l1_weight
        self.tvw = tv_weight

    def _tv_loss(self, x):
        # Mean absolute difference between neighbours along the last two axes.
        vertical = torch.mean(torch.abs(x[..., 1:, :] - x[..., :-1, :]))
        horizontal = torch.mean(torch.abs(x[..., :, 1:] - x[..., :, :-1]))
        return self.tvw * (vertical + horizontal)

    def forward(self, pred, target):
        data_term = torch.mean(torch.abs(pred - target))
        smooth_term = self._tv_loss(pred) if self.tvw > 0 else 0.0
        return self.l1w * data_term + smooth_term

class FeatureAlignmentLoss(torch.nn.Module):
    """Cosine-distance alignment loss between student and teacher features.

    Both inputs are L2-normalized along the last dimension; the loss is the
    mean of (1 - cosine similarity), so identical directions give 0 and
    orthogonal directions give 1.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, F_student, F_teacher):
        student_unit = torch.nn.functional.normalize(F_student, p=2, dim=-1)
        teacher_unit = torch.nn.functional.normalize(F_teacher, p=2, dim=-1)

        cosine = (student_unit * teacher_unit).sum(dim=-1)
        return (1 - cosine).mean()
    
class CosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    # NOTE(review): despite the name, this is a *linear-warmup* + cosine
    # annealing schedule (pl_bolts style) and shadows torch's own
    # torch.optim.lr_scheduler.CosineAnnealingLR.
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.00001,
        eta_min: float = 0.00001,
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer (torch.optim.Optimizer):
                Wrapped optimizer instance.
            warmup_epochs (int):
                Number of epochs of linear warmup.
            max_epochs (int):
                Total number of training epochs; end point of the cosine curve.
            warmup_start_lr (float):
                Learning rate at epoch 0 of the linear warmup.
            eta_min (float):
                Lower bound of the cosine curve.
            last_epoch (int):
                Phase offset of the cosine curve.
        Schedules the learning rate along a cosine curve until max_epochs;
        from epoch 0 to warmup_epochs the learning rate ramps up linearly.
        https://pytorch-lightning-bolts.readthedocs.io/en/stable/schedulers/warmup_cosine_annealing.html
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch, verbose=False)
        return None

    def get_lr(self):
        # Epoch 0: start the linear warmup at warmup_start_lr.
        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        # Warmup phase: add a constant linear increment each epoch.
        if self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Warmup just finished: hand back the base learning rates unchanged.
        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        # Wrap point of a cosine period: closed-form step (see linked docs).
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        # General case: recursive cosine-annealing update of the previous lr
        # (ratio of consecutive cosine factors, rescaled around eta_min).
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)))
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]    

class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over embedding pairs (label 0 = similar, 1 = dissimilar).

    Similar pairs are penalized by squared Euclidean distance; dissimilar
    pairs by the squared hinge max(0, margin - distance).
    """

    def __init__(self, margin=1.0, distance: int = 0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        # NOTE(review): stored but not used anywhere in this class.
        self.TypeDistance = distance

    def forward(self, output1, output2, label):
        # Euclidean distance between the paired embeddings.
        dist = torch.nn.functional.pairwise_distance(output1, output2)

        similar_term = (1 - label) * dist.pow(2)
        dissimilar_term = label * torch.clamp(self.margin - dist, min=0.0).pow(2)

        return torch.mean(similar_term + dissimilar_term)

class GradCAM:
    """Grad-CAM heat-map generator for a single-logit (binary) classifier.

    Hooks `target_layer` to capture its forward activations and the gradient
    of the loss w.r.t. its output, then weights each activation channel by
    its spatially-pooled gradient and averages over channels.

    Fixes vs. the previous version:
      * uses register_full_backward_hook (register_backward_hook is
        deprecated and can report wrong gradients),
      * no in-place mutation of the hooked activation tensor (the weighting
        is vectorized on a detached copy),
      * guards the min-max normalization against a constant map (div by 0).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # grad of the loss w.r.t. target_layer output
        self.activations = None  # forward activations of target_layer
        self.hook()

    def hook(self):
        def forward_hook(module, Input, Output):
            self.activations = Output

        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]

        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_full_backward_hook(backward_hook)

    def generate_cam(self, data: torch.Tensor, target_class):
        """Return a (H, W) numpy CAM normalized to [0, 1] for `data`.

        `target_class` is the BCE target tensor matching the flattened model
        output (e.g. torch.tensor([1.0]) for a positive-class map).
        """
        self.model.zero_grad()
        data.requires_grad = True
        output = self.model(data).flatten()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(output, target_class)
        loss.backward()

        # Channel weights: gradients averaged over batch and spatial dims.
        pooled_grad = torch.mean(self.gradients, dim=[0, 2, 3])
        # Weight every channel at once; work on a detached copy so the
        # hooked activation tensor is left untouched.
        weighted = self.activations.detach() * pooled_grad.view(1, -1, 1, 1)

        cam = torch.mean(weighted, dim=1).squeeze()
        cam = np.maximum(cam.detach().numpy(), 0)
        denom = cam.max() - cam.min()
        # A constant map (e.g. all-zero after ReLU) would divide by zero.
        cam = (cam - cam.min()) / (denom if denom > 0 else 1.0)
        return cam

class UNet_RN18(torch.nn.Module):
    """U-Net with an ImageNet-pretrained ResNet-18 encoder.

    Skip connections are captured with forward hooks on the encoder blocks:
    `forward` runs the encoder once to fill `self.feature_maps`, then decodes
    with transposed convolutions and skip concatenations down to a
    `out_channels`-channel map at twice the conv1 resolution.
    """
    def _conv_block(self, in_ch, out_ch):
        # Two 3x3 convolutions with SiLU activations (decoder building block).
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        )
    
    def _hook_fn(self, module, input, output):
        # Forward hook: record each encoder block's output for the skip paths.
        self.feature_maps.append(output)    
    
    def _register_encoder_hooks(self):
        # Hooks fire in execution order, so feature_maps ends up as
        # [conv1, layer1, layer2, layer3, layer4] outputs.
        for layer in self.list_blocks:
            self.hooks.append(layer.register_forward_hook(self._hook_fn))
    
    def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):
        # NOTE(review): `in_channels` is not wired into the encoder — the
        # ResNet-18 stem always expects 3-channel input. Confirm intent.
        super().__init__(*args, **kwargs)
        
        self.hooks = []         # handles returned by register_forward_hook
        self.feature_maps = []  # repopulated by the hooks on every forward
        
        # Pretrained backbone; the classification head is replaced by Identity.
        self.encoder_base = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        self.encoder_base.fc = torch.nn.Identity()
        
        #for parms in self.encoder_base.parameters():
        #    parms.requires_grad = False
        
        # Encoder stages whose outputs feed the skip connections.
        # NOTE(review): conv1 output is taken *before* bn1/relu/maxpool.
        self.list_blocks = [
            self.encoder_base.conv1,
            self.encoder_base.layer1,
            self.encoder_base.layer2,
            self.encoder_base.layer3,
            self.encoder_base.layer4
        ]
        
        self._register_encoder_hooks()
        
        self.dropout = torch.nn.Dropout2d(0.4)
        
        # Decoder: each upConv doubles the spatial size and halves channels.
        self.upConv_0 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upConv_3 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        
        self.upConv_extra = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)

        # Refinement blocks; in-channels account for the concatenated skip.
        self.decoder_0 = self._conv_block(512, 256)     #0
        self.decoder_1 = self._conv_block(256, 128)     #1
        self.decoder_2 = self._conv_block(128, 64)      #2
        self.decoder_3 = self._conv_block(32 + 64, 32)  #3  (upConv_3 out + conv1 skip)
        
        self.decoder_extra = self._conv_block(16, 16)
        
        self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1)
        
    def forward(self, x):
        # Reset the hook buffer so stale maps from a previous call never leak.
        self.feature_maps = []
        
        _ = self.encoder_base(x) # forward to encoder and call hooks 
        
        encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck = self.feature_maps
                
        # Stage 0: deepest features + layer3 skip.
        decoder_l3 = self.upConv_0(btlneck) 
        decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1) 
        decoder_l3 = self.decoder_0(decoder_l3)    
        
        decoder_l3 = self.dropout(decoder_l3)
        
        # Stage 1: + layer2 skip.
        decoder_l2 = self.upConv_1(decoder_l3) 
        decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1) 
        decoder_l2 = self.decoder_1(decoder_l2)  
        
        decoder_l2 = self.dropout(decoder_l2)
        
        # Stage 2: + layer1 skip.
        decoder_l1 = self.upConv_2(decoder_l2) 
        decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1) 
        decoder_l1 = self.decoder_2(decoder_l1)  
        
        decoder_l1 = self.dropout(decoder_l1)
        
        # Stage 3: + conv1 skip (no dropout after this stage).
        decoder_l0 = self.upConv_3(decoder_l1) 
        decoder_l0 = torch.cat((decoder_l0, encoder_l0), dim=1) 
        decoder_l0 = self.decoder_3(decoder_l0)
        
        # Extra 2x upsampling to recover the pre-conv1 resolution.
        decoder_last = self.upConv_extra(decoder_l0)
        decoder_last = self.decoder_extra(decoder_last)
                
        return self.last_conv(decoder_last)
    
class RARP_NVB_ROI_Mask_Unet(L.LightningModule):
    """Binary-mask segmentation module wrapping `UNet_RN18`.

    Trains with BCE-with-logits, tracks IoU for train/val, optionally adds an
    L1 penalty over decoder/up-convolution weights, and alternates freezing
    of the encoder every other epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.model = UNet_RN18(in_channels=3, out_channels=1)

        self.lr = 1E-4
        self.Lambda_L1 = None  # L1 penalty strength; None disables the term
        self.lossFN = torch.nn.BCEWithLogitsLoss()

        self.train_IoU = torchmetrics.classification.BinaryJaccardIndex()
        self.val_IoU = torchmetrics.classification.BinaryJaccardIndex()

    def forward(self, data):
        data = data.float()
        pred = self.model(data)
        return pred

    def _shared_step(self, batch, val_step: bool = True):
        """Return (loss, target mask, sigmoid probabilities) for one batch.

        The L1 regularization term (decoder / up-convolution weights only) is
        added during training steps (val_step=False) when Lambda_L1 is set.
        """
        img, mask = batch

        mask = mask.float()
        mask = mask.unsqueeze(1)  # (N, H, W) -> (N, 1, H, W) to match logits
        prediction = self(img)

        loss = self.lossFN(prediction, mask)

        predicted_labels = torch.sigmoid(prediction)

        if not val_step:
            if self.Lambda_L1 is not None:
                loss_l1 = 0
                for name, params in self.model.named_parameters():
                    if "decoder" in name or "upConv" in name:
                        loss_l1 += torch.norm(params, p=1)

                loss += self.Lambda_L1 * loss_l1

        return loss, mask, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch, False)

        self.train_IoU.update(predicted_labels, true_labels)

        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc_IoU", self.train_IoU, on_epoch=True, on_step=False)

        return loss

    def on_after_backward(self):
        """Log the global gradient L2 norm and flag suspected vanishing gradients."""
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5

        self.log("grad_norm", total_norm)

        if total_norm < 1e-8:
            # BUG FIX: Lightning's self.log() only accepts numeric values;
            # logging the string "Vanishing gradient suspected!" raised an
            # error here. Log a numeric flag instead.
            self.log("grad_warning", 1.0)

    def on_train_epoch_start(self):
        # Alternate freezing/unfreezing of the pretrained encoder each epoch.
        for parms in self.model.encoder_base.parameters():
            parms.requires_grad = (self.current_epoch % 2 == 0)

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.val_IoU.update(predicted_labels, true_labels)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_acc_IoU", self.val_IoU, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        return [optimizer]

class RARP_NVB_ResNet50_CAM(L.LightningModule):
    """ResNet-50 binary classifier that also returns its conv feature map (for CAM)."""

    def __init__(self) -> None:
        super().__init__()

        self.model = torchvision.models.resnet50()
        fc_in = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=fc_in, out_features=1)

        # All layers except avgpool and fc -> spatial (N, 2048, H, W) maps.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        fmap = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=fmap, output_size=(1, 1))
        flat = torch.flatten(pooled, 1)
        logits = self.model.fc(flat)
        return logits, fmap

class RARP_NVB_VAN_CAM(L.LightningModule): #TODO
    """VAN-B2 classifier intended to combine Grad-CAM with prediction.

    NOTE(review): this class is marked TODO and looks unfinished —
    GradCAM.generate_cam returns a numpy heat map, which `forward` then feeds
    straight back into `self.model` (which expects an image tensor), and
    `self.feature_map` is built but never used. Verify before using.
    """
    def __init__(self) -> None:
        super().__init__()
        
        self.model = van.van_b2(pretrained = True)
        tempFC_ft = self.model.head.in_features
        self.model.head = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])
        
        
    def forward(self, data, label:torch.Tensor):
        
        # Build a Grad-CAM over the last block and run the model on the CAM.
        cams = GradCAM(self.model, self.model.block4[-1])
        featureMap = cams.generate_cam(data, label)
              
        pred = self.model(featureMap)
        
        
        return pred, featureMap
    
class RARP_NVB_ResNet18_CAM(L.LightningModule):
    """ResNet-18 binary classifier that also returns its conv feature map (for CAM)."""

    def __init__(self) -> None:
        super().__init__()

        self.model = torchvision.models.resnet18()
        tempFC_ft = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)

        # All layers except avgpool and fc -> spatial (N, 512, H, W) maps.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        featureMap = self.feature_map(data)
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1))
        # BUG FIX: flatten from dim 1 to keep the batch dimension; the previous
        # torch.flatten(Cont_Net) collapsed the batch axis and broke any batch
        # size > 1 (the sibling ResNet-50 variant already used start_dim=1).
        Cont_Net = torch.flatten(Cont_Net, 1)

        pred = self.model.fc(Cont_Net)

        return pred, featureMap
    
class RARP_NVB_MobileNetV2_CAM(L.LightningModule):
    """MobileNetV2 binary classifier that also returns its conv feature map (for CAM)."""

    def __init__(self) -> None:
        super().__init__()

        self.model = torchvision.models.mobilenet_v2()
        head_in = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=1)

        # Convolutional trunk; produces the spatial maps used for CAM.
        self.feature_map = self.model.features

    def forward(self, data):
        fmap = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(fmap, (1, 1))
        flat = torch.flatten(pooled, 1)
        logits = self.model.classifier(flat)
        return logits, fmap
    
class RARP_NVB_EfficientNetV2_CAM(L.LightningModule):
    """EfficientNetV2-S binary classifier that also returns its conv feature map (for CAM)."""

    def __init__(self) -> None:
        super().__init__()

        self.model = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        head_in = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=1)

        # Convolutional trunk; produces the spatial maps used for CAM.
        self.feature_map = self.model.features

    def forward(self, data):
        fmap = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(fmap, (1, 1))
        flat = torch.flatten(pooled, 1)
        logits = self.model.classifier(flat)
        return logits, fmap

class RARP_NVB_Model_BCEWithLogitsLoss(L.LightningModule):
    """Binary-classification LightningModule skeleton using BCE-with-logits.

    `self.model` is left as None and must be assigned a backbone by the
    caller before training. Metrics are torchmetrics objects logged directly.
    """

    def __init__(self, x=None) -> None:
        super().__init__()

        self.model = None  # backbone injected after construction

        self.lossFN = torch.nn.BCEWithLogitsLoss() #pos_weight=torch.tensor([2.73])

        # Per-split accuracy / F1 trackers.
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        return self.model(data.float())

    def _shared_step(self, batch):
        """Run one batch; return (loss, float targets, sigmoid probabilities)."""
        images, targets = batch
        targets = targets.float()
        logits = self.forward(images).flatten()
        loss = self.lossFN(logits, targets)

        probabilities = torch.sigmoid(logits)

        return loss, targets, probabilities

    def training_step(self, batch, batch_idx):
        loss, targets, probabilities = self._shared_step(batch)

        self.log("train_loss", loss)
        self.train_acc.update(probabilities, targets)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, targets, probabilities = self._shared_step(batch)
        self.log("val_loss", loss)
        self.val_acc(probabilities, targets)
        self.f1Score(probabilities, targets)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, targets, probabilities = self._shared_step(batch)
        self.test_acc(probabilities, targets)
        self.f1ScoreTest(probabilities, targets)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        return [torch.optim.Adam(self.model.parameters(), lr=1e-4)]

class RARP_NVB_FOCAL_loss(torch.nn.Module):
    """nn.Module wrapper around torchvision's sigmoid focal loss."""

    def __init__(self, alpha: float = 0.25, gamma: float = 2, reduction: str = "mean"):
        super().__init__()
        # Hyper-parameters forwarded verbatim on every call.
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torchvision.ops.focal_loss.sigmoid_focal_loss(
            input,
            target,
            self.alpha,
            self.gamma,
            self.reduction,
        )

class FocalLoss(torch.nn.Module):
    """Multi-class focal loss on raw logits with integer class targets.

    loss = -alpha * (1 - p)^gamma * one_hot(target) * log_softmax(inputs),
    reduced per `reduction` ('mean', 'sum', or anything else for element-wise).
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        num_classes = inputs.size(1)
        target_one_hot = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()

        probs = torch.nn.functional.softmax(inputs, dim=1)
        log_probs = torch.nn.functional.log_softmax(inputs, dim=1)

        # Down-weight well-classified examples by (1 - p)^gamma.
        modulating = (1 - probs) ** self.gamma
        elementwise = -self.alpha * modulating * target_one_hot * log_probs

        if self.reduction == 'sum':
            return elementwise.sum()
        if self.reduction == 'mean':
            return elementwise.mean()
        return elementwise
#TODO
class RARP_NVB_MultiClassModel(L.LightningModule):
    """Multi-class classifier LightningModule using the local FocalLoss.

    Adds an L1 penalty over all model parameters during training steps only
    (gated by the `val_loop` flag) and can attach a cosine-annealing LR
    schedule. InitWeight is stored but not applied to the loss here.
    """

    def __init__(self, 
                 InitWeight = torch.tensor([1,1]),  
                 schedulerLR:bool=False, 
                 lr:float = 1e-4,
                 Model:torch.nn.Module = None,
                 Num_Classes:int = 2,
                 L1:float = 1.31E-04,
                 L2:float = 0
        ) -> None:
        super().__init__()

        self.model = Model
        self.lossFN = FocalLoss()  # alternative: torch.nn.CrossEntropyLoss()
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = L1  # L1 penalty strength (None disables it)
        self.Lambda_L2 = L2  # handed to Adam as weight_decay
        self.num_classes = Num_Classes

        self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.test_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.f1ScoreTest = torchmetrics.F1Score("multiclass", num_classes=Num_Classes)

        # True while in validation/test; suppresses the L1 term.
        self.val_loop = False

    def forward(self, data):
        return self.model(data.float())

    def _shared_step(self, batch):
        """Return (loss, targets, softmax probabilities) for one batch."""
        images, targets = batch
        logits = self(images)
        probabilities = torch.softmax(logits, dim=1)
        loss = self.lossFN(logits, targets)

        if self.Lambda_L1 is not None and not self.val_loop:
            l1_term = sum(weights.abs().sum() for weights in self.model.parameters())
            loss = loss + self.Lambda_L1 * l1_term

        return loss, targets, probabilities

    def training_step(self, batch, batch_idx):
        self.val_loop = False
        loss, targets, probabilities = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(probabilities, targets)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        self.val_loop = True
        loss, targets, probabilities = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(probabilities, targets)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        self.val_loop = True
        loss, targets, probabilities = self._shared_step(batch)
        self.test_acc.update(probabilities, targets)
        self.f1ScoreTest.update(probabilities, targets)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if not self.scheduler:
            return [optimizer]
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler},
        }

class RARP_NVB_Model(L.LightningModule):
    """Binary classifier LightningModule with selectable loss and LR schedule.

    `self.model` is left as None and must be assigned a backbone after
    construction. NOTE(review): `_shared_step` re-assigns `lossFN.weight`
    from InitWeight per batch — that matches BCE's per-element weight, but
    CrossEntropyLoss expects a fixed per-class weight at construction; also
    the forward output is sliced as prediction[:, 0] and compared against
    float labels, and label_smoothing=0.5 is a very strong setting. Confirm
    the intended loss configuration.
    """
    def __init__(self, 
                 InitWeight = torch.tensor([1,1]), 
                 typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy, 
                 schedulerLR:bool=True, 
                 lr:float = 1e-4
                ) -> None:
        super().__init__()

        self.model = None
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = None #1.31E-04 #  L1 penalty strength; None disables it
        self.Lambda_L2 = 0
        
        print (f"LR= {self.lr}, L1= {self.Lambda_L1}")

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        #self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        data = data.float()
        pred = self.model(data)
        return pred
    
    def _shared_step(self, batch):
        """Run one batch; return (loss, float labels, sigmoid probabilities)."""
        img, label = batch
        # Per-batch class weighting indexed by the integer labels.
        if self.InitWeight is not None:
            self.lossFN.weight = self.InitWeight[label]
        
        label = label.float()
        prediction = self(img)[:,0] #.flatten()
        loss = self.lossFN(prediction, label)
        
        # Optional L1 penalty over every trainable parameter.
        if self.Lambda_L1 is not None:
            loss_l1 = 0
            for params in self.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        
        predicted_labels = torch.sigmoid(prediction)
        
        return loss, label, predicted_labels
    
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        #self.f1Score.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        # hp_metric feeds the TensorBoard hyper-parameter dashboard.
        self.log("hp_metric", self.val_acc, on_step=False, on_epoch=True)
        #self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)
        
    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2) 
        if self.scheduler:
            #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1, verbose=True, factor=0.1)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return optimizer
    
class RARP_Ensemble(L.LightningModule):
    """Stacking ensemble: each sub-model emits one score per sample; the
    scores are concatenated and fed to a small MLP head that outputs a single
    sigmoid probability.

    NOTE(review): ``ListModels`` is stored as a plain Python list, so the
    sub-models are NOT registered as submodules: they do not follow
    ``.to(device)`` and their parameters are excluded from
    ``self.parameters()`` (and therefore from the optimizer).  Wrap it in
    ``torch.nn.ModuleList`` if that is unintended -- left unchanged here
    because it would alter which parameters get trained.
    """

    def __init__(self, ListModels,
                 InitWeight=torch.tensor([1, 1]),
                 typeLossFN: TypeLossFunction = TypeLossFunction.CrossEntropy,
                 schedulerLR: bool = False,
                 lr: float = 1e-4) -> None:
        """Args:
            ListModels: sequence of callables mapping an image batch to a
                ``(batch, 1)`` score tensor.
            InitWeight: per-class loss weights indexed by the integer label;
                ``None`` disables the re-weighting.
            typeLossFN: CrossEntropy -> ``CrossEntropyLoss(label_smoothing=0.5)``,
                anything else -> ``BCELoss(reduction="sum")``.
            schedulerLR: when True, attach a cosine-annealing LR schedule.
            lr: Adam learning rate.
        """
        super().__init__()

        self.ListModels = ListModels
        input_p = len(self.ListModels)

        # Meta-classifier over the concatenated sub-model scores.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_p, out_features=128),
            torch.nn.SiLU(),
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid()
        )

        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCELoss(reduction="sum")
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        """Stack the sub-model scores along dim 1 and classify them."""
        data = data.float()
        scores = torch.cat([m(data) for m in self.ListModels], dim=1)
        return self.classifier(scores)

    def _shared_step(self, batch):
        """Shared loss/prediction computation for train/val/test steps."""
        img, label = batch
        # Per-sample loss weights selected by the integer class label.
        if self.InitWeight is not None:
            self.lossFN.weight = self.InitWeight[label]

        label = label.float()
        prediction = self(img)[:, 0]  # squeeze the (batch, 1) output
        loss = self.lossFN(prediction, label)

        # The sigmoid head already outputs probabilities in [0, 1].
        return loss, label, prediction

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.f1Score.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Adam over all registered parameters; optional cosine-annealing LR.

        Fixed: dropped the deprecated ``verbose=True`` flag, which newer
        PyTorch releases no longer accept on LR schedulers.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        if self.scheduler:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=1e-8)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                }
            }
        else:
            return [optimizer]
    
class RARP_NVB_ResNet18(RARP_NVB_Model):
    """Binary NVB classifier built on an ImageNet-pretrained ResNet-18."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        # Swap the 1000-way ImageNet head for a single-logit output.
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=1)
        self.model = backbone

class RARP_NVB_DaVit(RARP_NVB_Model):
    """Binary NVB classifier using timm's DaViT-small pretrained on ImageNet-1k."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # timm builds the single-logit head for us via num_classes=1.
        self.model = timm.create_model("davit_small.msft_in1k", pretrained=True, num_classes=1)

class RARP_NVB_EfficientNetV2_Deep(RARP_NVB_Model):
    """EfficientNetV2-S with a deeper 128 -> 8 -> 1 classification head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        net = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        in_feats = net.classifier[1].in_features

        # Replace the stock linear head, then extend the classifier into a
        # small MLP that ends in a single logit.
        net.classifier[1] = torch.nn.Linear(in_features=in_feats, out_features=128)
        for extra in (torch.nn.SiLU(True), torch.nn.Linear(128, 8),
                      torch.nn.SiLU(True), torch.nn.Linear(8, 1)):
            net.classifier.append(extra)

        self.model = net

class RARP_NVB_ResNet50_Deep(RARP_NVB_Model):
    """ResNet-50 backbone with a dropout-regularized 128 -> 8 -> 1 head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # The RHS reads the original fc's in_features before replacing it.
        net.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=net.fc.in_features, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        )
        self.model = net

    def forward(self, img):
        """Cast the batch to float32 and return raw (pre-sigmoid) logits."""
        return self.model(img.float())
    
class RARP_NVB_ResNet50_Deep_OPTuna(RARP_NVB_Model):
    """ResNet-50 with an Optuna-tunable MLP head.

    Each entry of ``output_dims`` adds a Linear + SiLU + Dropout block; a
    final Linear maps to a single logit.

    Args:
        output_dims: hidden widths of the head, one block per entry.
            (Fixed: the default was a shared mutable list ``[128, 8]``; it is
            now an immutable tuple with the same values.)
        dropuot: dropout probability for every hidden block.  The misspelled
            name is kept on purpose for backward compatibility with existing
            Optuna studies and checkpoints.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, output_dims=(128, 8), dropuot=0.2, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)

        layers = []
        inputDim = self.model.fc.in_features
        for outputDim in output_dims:
            layers.append(torch.nn.Linear(inputDim, outputDim))
            layers.append(torch.nn.SiLU(True))
            layers.append(torch.nn.Dropout(dropuot))
            inputDim = outputDim

        layers.append(torch.nn.Linear(inputDim, 1))  # single-logit output

        self.model.fc = torch.nn.Sequential(*layers)

    def forward(self, img):
        """Cast the batch to float32 and return raw (pre-sigmoid) logits."""
        img = img.float()
        return self.model(img)

class RARP_NVB_ResNet50_V2(RARP_NVB_Model):
    """ResNet-50 variant that fuses 8 image features with extra tabular
    inputs inside a final linear layer."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # The RHS reads the original fc's in_features before replacing it.
        net.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=net.fc.in_features, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
        )
        # Fusion layer: 8 CNN features concatenated with the extra inputs.
        net.fc2 = torch.nn.Linear(8 + InputNeurons, 1)
        self.model = net

    def forward(self, data):
        """Expects ``(image_batch, extra_features)``; returns one logit per sample."""
        img, extra = data
        features = torch.nn.functional.silu(self.model(img.float()), True)
        fused = torch.concat((features, extra), dim=1)
        return self.model.fc2(fused)
    
class RARP_NVB_ResNet50_V3(RARP_NVB_Model):
    """ResNet-50 variant with parallel 128-d projections of the image and the
    extra features, concatenated before the final classifier."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        net.fc = torch.nn.Linear(in_features=net.fc.in_features, out_features=128)
        net.extraFC = torch.nn.Linear(InputNeurons, 128)  # projects the extra inputs
        net.fc2 = torch.nn.Linear(256, 1)                 # fuses both 128-d branches
        self.model = net

    def forward(self, data):
        """Expects ``(image_batch, extra_features)``; returns one logit per sample."""
        img, extra = data
        img_branch = torch.nn.functional.silu(self.model(img.float()), True)
        extra_branch = torch.nn.functional.silu(self.model.extraFC(extra), True)
        fused = torch.concat((img_branch, extra_branch), dim=1)
        return self.model.fc2(fused)

class RARP_NVB_ResNet50_V1(RARP_NVB_Model):
    """ResNet-50 variant that appends the extra features to the pooled
    convolutional embedding right before the final linear layer."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # The head takes the pooled CNN features plus the extra inputs; the
        # RHS reads the original fc's in_features before replacing it.
        self.model.fc = torch.nn.Linear(
            in_features=self.model.fc.in_features + InputNeurons, out_features=1
        )
        # All layers except avgpool and fc; despite the attribute name this
        # acts as the feature extractor.
        self.Decoder = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        """Expects ``(image_batch, extra_features)``; returns one logit per sample."""
        img, extra = data
        fmap = self.Decoder(img.float())
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=fmap, output_size=(1, 1))
        pooled = torch.flatten(pooled, 1)

        fused = torch.concat((pooled, extra), dim=1)
        return self.model.fc(fused)
        
class RARP_NVB_VAN(RARP_NVB_Model):
    """Binary NVB classifier on a pretrained VAN-B2 backbone."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        self.model = van.van_b2(pretrained=True)
        # Swap the classification head for a single-logit output.
        self.model.head = torch.nn.Linear(in_features=self.model.head.in_features, out_features=1)
        
class RARP_NVB_ResNet50(RARP_NVB_Model):
    """Binary NVB classifier built on an ImageNet-pretrained ResNet-50."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # Swap the 1000-way ImageNet head for a single-logit output.
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=1)
        self.model = backbone

class RARP_NVB_MLP(torch.nn.Module):
    """DINO-style projection head: an MLP down to a bottleneck, L2
    normalization, then a weight-normalized linear layer that produces
    ``out_dim`` logits."""

    def __init__(self, in_dim, out_dim, hidden_dim=512, bottleneck=256, n_layers=3, norm_last_layer=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.activationFN = torch.nn.GELU()

        if n_layers == 1:
            self.mlp = torch.nn.Linear(in_dim, bottleneck)
        else:
            stack = [
                torch.nn.Linear(in_dim, hidden_dim),
                self.activationFN,
                torch.nn.Dropout(0.30),
            ]
            for _ in range(n_layers - 2):
                stack += [torch.nn.Linear(hidden_dim, hidden_dim), self.activationFN]
            stack.append(torch.nn.Linear(hidden_dim, bottleneck))
            self.mlp = torch.nn.Sequential(*stack)

        self.apply(self._init_weights)

        # Weight-normalized output layer with its gain fixed to 1.
        self.last_layer = torch.nn.utils.weight_norm(
            torch.nn.Linear(bottleneck, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Normal(0, 0.02) init for Linear weights; zeros for biases."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project, L2-normalize along the last dim, then emit the logits."""
        projected = self.mlp(x)
        normalized = torch.nn.functional.normalize(projected, dim=-1, p=2)
        return self.last_layer(normalized)

class RARP_NVB_DINO_Wrapper(torch.nn.Module):
    """Couples a backbone with a projection head and handles multi-crop
    input: a list of crops is concatenated into a single batch, and the head
    output is chunked back into one tensor per crop."""

    def __init__(self, backbone:torch.nn.Module, new_head:torch.nn.Module, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.backbone = backbone
        self.head = new_head

    def forward(self, x):
        """Accepts a tensor or a list of crop tensors; returns a tuple with
        one logits tensor per crop."""
        if isinstance(x, list):
            batch, n_crops = torch.cat(x, dim=0), len(x)
        else:
            batch, n_crops = x, 1

        logits = self.head(self.backbone(batch))
        return logits.chunk(n_crops)
    
class RARP_NVB_DINO_Loss(torch.nn.Module):
    """DINO cross-entropy between temperature-sharpened teacher and student
    distributions, with an EMA-updated center subtracted from the teacher
    logits to avoid collapse."""

    def __init__(self, out_dim:int, teacher_Thao:float = 0.04, student_Thao:float = 0.1, center_momentum:float = 0.9, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.S_Thao = student_Thao
        self.T_Thao = teacher_Thao
        self.C_Momentum = center_momentum
        # Running center of the teacher outputs (buffer: saved, not trained).
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, s_Output, t_Output, validation_step:bool = False):
        """Average cross-entropy over every teacher/student crop pair,
        skipping same-index pairs when more than one teacher crop exists."""
        student_logp = [
            torch.nn.functional.log_softmax(s / self.S_Thao, dim=-1) for s in s_Output
        ]
        teacher_p = [
            torch.nn.functional.softmax((t - self.center) / self.T_Thao, dim=-1).detach()
            for t in t_Output
        ]

        total_loss, n_loss_terms = 0, 0
        for t_ix, t in enumerate(teacher_p):
            for s_ix, s in enumerate(student_logp):
                # Never compare a crop against itself in the multi-crop case.
                if t_ix == s_ix and len(teacher_p) > 1:
                    continue
                total_loss += torch.sum(-t * s, dim=-1).mean()
                n_loss_terms += 1

        total_loss /= n_loss_terms

        # The center tracks the teacher only during training.
        if not validation_step:
            self.update_center(t_Output)

        return total_loss

    @torch.no_grad()
    def update_center(self, t_output):
        """EMA update of the center from the batch-mean teacher output."""
        batch_mean = torch.cat(t_output).mean(dim=0, keepdim=True)
        self.center = self.center * self.C_Momentum + batch_mean * (1 - self.C_Momentum)
        
class RARP_NVB_DINO_RestNet50_Deep(L.LightningModule):
    """Multi-task DINO + pseudo-label distillation on ResNet-50.

    Three networks cooperate:
      * ``teacher_Labels``: a frozen pre-trained classifier whose thresholded
        sigmoid output provides pseudo-labels.
      * ``teacher_Features``: a frozen DINO feature teacher, refreshed by an
        EMA of the student after every training batch.
      * ``student``: trainable ResNet-50 + DINO projection head; its projected
        features also feed the binary head ``clasiffier``.
    """
    def __init__(
        self, 
        PseudoEstimator: str = None, 
        threshold: float = 0.5, 
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
    ) -> None:
        """Args:
            PseudoEstimator: checkpoint path of the pseudo-label teacher; a
                freshly built RARP_NVB_ResNet50_Deep is used when None.
            threshold: sigmoid cut-off used to binarize teacher outputs.
            TypeLoss: CrossEntropy -> CrossEntropyLoss(label_smoothing=0.5),
                anything else -> BCEWithLogitsLoss.
            momentum_teacher: EMA momentum for the DINO feature teacher.
            lr: AdamW learning rate (student parameters only).
            L1: optional L1 penalty weight on the student; disabled when None.
            L2: weight decay handed to AdamW.
        """
        super().__init__()
    
        self.lr =  lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.threshold = threshold
        self.momentum_teacher = momentum_teacher
        self.out_dim = 512   # DINO projection-head output size
        self.in_dim = 2048   # ResNet-50 pooled embedding size

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        
        # Pseudo-label teacher (restored from a checkpoint when provided).
        self.teacher_Labels = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False) if PseudoEstimator is not None else RARP_NVB_ResNet50_Deep()
        self.student = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT) #torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        self.teacher_Features = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        
        # Drop the ImageNet heads so both nets emit raw 2048-d embeddings.
        self.student.fc = torch.nn.Identity()
        self.teacher_Features.fc = torch.nn.Identity()
        
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        
        # Both teachers are frozen w.r.t. the optimizer; the feature teacher
        # only moves via the EMA update in on_train_batch_end.
        for parms in self.teacher_Labels.model.parameters():
            parms.requires_grad = False
            
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, 0.04, 0.1, momentum_teacher)
        self.lossFN_KD = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        
        #self.lossFH_KLDiv = torch.nn.KLDivLoss(reduction="batchmean")
        
        # Binary classification head on top of the student's DINO features.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(self.out_dim, 128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1)
        )
        
    def forward(self, data, val_step:bool = True):
        """Run the student, the feature teacher and the label teacher.

        Args:
            data: a single image batch when ``val_step`` is True; otherwise a
                list of crops -- assumes data[0] is the classificator view and
                data[1:3] are the teacher (global) crops. TODO confirm against
                the dataloader.
            val_step: selects the single-view or the multi-crop path.

        Returns:
            ``((student_logits, pseudo_labels, teacher_logits),
               (teacher_dino_out, student_dino_out))``
        """
        if val_step:
            data = data.float()
            dataClassificator, dataTeacher, dataStudent = data, data, data
        else:
            data = [d.float() for d in data]
            dataClassificator, dataTeacher, dataStudent = data[0], data[1:3], data

        TeacherDino = self.teacher_Features(dataTeacher)
        TeacherLabels = self.teacher_Labels(dataClassificator)
        Student = self.student(dataStudent)
        
        # here every student output would be evaluated (kept for reference)
        #if isinstance(dataStudent, list):
        #    #index = np.random.randint(0, len(dataStudent))
        #    temp = self.student(dataStudent)
        #    CatS_Classifier = torch.cat(temp, dim=0)
        #    meanS_Classifier = torch.zeros(self.in_dim)
        #    for dataS in temp:
        #        meanS_Classifier += dataS
        #    #S_Classifier = self.student(dataStudent[index])
        #    S_Classifier = meanS_Classifier / len(dataStudent)
        #else:
        #    S_Classifier = Student
            
        # Re-assemble the per-crop chunks into one batch for the classifier.
        Cont_Net = torch.cat(Student, dim=0)
                
        pred = self.clasiffier(Cont_Net)
        
        if not val_step:
            # NOTE(review): re-runs the label teacher once per crop on the SAME
            # view so its output length matches the concatenated student batch;
            # with BatchNorm in train mode this repeatedly updates running
            # stats -- confirm this is intended.
            TeacherLabels = [self.teacher_Labels(dataClassificator) for _ in range(len(dataStudent))]
            TeacherLabels = torch.cat(TeacherLabels, dim=0)
        
        TeacherLabelsPred = torch.sigmoid(TeacherLabels.flatten())
        PseudoLabels = (TeacherLabelsPred > self.threshold) * 1.0
                        
        return (pred.flatten(), PseudoLabels, TeacherLabels.flatten()), (TeacherDino, Student)
        
    def _shared_step(self, batch, val_step:bool = False):
        """Combined loss: weighted pseudo-label + ground-truth KD terms, plus
        the DINO loss (training only) and an optional L1 penalty."""
        img, label = batch
        
        if not val_step:
            # Replicate the labels to align with the crop-concatenated batch.
            label = torch.cat([label for _ in range(len(img))], dim=0)
        
        label = label.float()
        KD_Prediction, DINO_Loss = self(img, val_step)
        TeacherF, StudentF = DINO_Loss
        prediction, PseudoLabels, teacherOutputs = KD_Prediction
        
        predicted_labels = torch.sigmoid(prediction)
        
        ## version 1: fixed weighting of pseudo-label vs. ground-truth terms
        W_Alpha, W_Beta = (1, 0.5)#(1, 0.5)
        loss = W_Alpha * self.lossFN_KD(prediction, PseudoLabels) + W_Beta * self.lossFN_KD(prediction, label)
        
        # version 2: soft-label distillation (kept for reference)
        #thao_KD = 1#5.0
        #W_Alpha, W_Beta = (0.6, 0.4)
        
        #softTeacher = torch.sigmoid(teacherOutputs/thao_KD)
        #softStudent = torch.sigmoid(prediction/thao_KD)
        
        #loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
        #loss_hl = self.lossFN_KD(prediction, label)
        
        #loss = W_Alpha * loss_hl + W_Beta * loss_sl
        
        #loss = W_Alpha * self.lossFN_KD(prediction, label) + W_Beta * (self.lossFH_KLDiv(softStudent, softTeacher) * (thao_KD ** 2))
        # The DINO self-distillation term applies only during training.
        loss += (self.lossFN_DINO(StudentF, TeacherF) if not val_step else 0)
        
        if not val_step:
            # Histograms of the raw projection outputs for TensorBoard.
            self.logger.experiment.add_histogram ("Teacher", TeacherF[0])
            self.logger.experiment.add_histogram ("Student", StudentF[1])
                
        if self.Lambda_L1 is not None:
            loss_l1 = 0
            for params in self.student.parameters(): # student-only L1 penalty
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        
        #return loss, PseudoLabels, predicted_labels
        return loss, label, predicted_labels
        
    def training_step(self, batch, batch_idx):
        """Multi-crop training step; logs loss/accuracy and returns the loss."""
        loss, true_labels, predicted_labels = self._shared_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss
    
    def on_train_batch_end(self, outputs, batch, batch_idx):
        """EMA update of the feature teacher from the student after each batch."""
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
            
            self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center)
            
                
    def validation_step(self, batch, batch_idx):
        """Single-view validation; logs epoch-level loss and accuracy."""
        loss, true_labels, predicted_labels = self._shared_step(batch, True)
        
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        
    def test_step(self, batch, batch_idx):
        """Accumulates test accuracy and F1 over the test set."""
        _, true_labels, predicted_labels = self._shared_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    
    def configure_optimizers(self):
        """AdamW over the student only; both teachers stay untouched."""
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr, weight_decay=self.Lambda_L2) 
        
        return [optimizer]

class RARP_NVB_DINO_VAN(RARP_NVB_DINO_RestNet50_Deep):
    """DINO multi-task model whose student/feature-teacher backbones are
    VAN-B2 instead of the ResNet-50 pair built by the parent constructor."""

    def __init__(
        self, 
        PseudoEstimator: str = None, 
        threshold: float = 0.5, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        momentum_teacher: float = 0.9995, 
        lr: float = 0.0001, 
        L1: float = None, 
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)

        # VAN-B2 embeddings are 512-d, so the projection heads are rebuilt.
        self.in_dim = 512

        student_backbone = van.van_b2(pretrained=True, num_classes=-1)
        teacher_backbone = van.van_b2(pretrained=True, num_classes=-1)

        self.student = RARP_NVB_DINO_Wrapper(
            student_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )

        # The feature teacher is EMA-updated only, never by the optimizer.
        for frozen in self.teacher_Features.parameters():
            frozen.requires_grad = False

class RARP_NVB_DINO_ViT(RARP_NVB_DINO_RestNet50_Deep):
    """DINO multi-task model whose student/feature-teacher backbones are
    ViT-B/16 instead of the ResNet-50 pair built by the parent constructor."""

    def __init__(
        self, 
        PseudoEstimator: str = None, 
        threshold: float = 0.5, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        momentum_teacher: float = 0.9995, 
        lr: float = 0.0001, 
        L1: float = None, 
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)

        # ViT-B/16 embeddings are 768-d, so the projection heads are rebuilt.
        self.in_dim = 768

        student_backbone = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        teacher_backbone = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)

        # Strip the classification heads so both emit raw embeddings.
        student_backbone.heads = torch.nn.Identity()
        teacher_backbone.heads = torch.nn.Identity()

        self.student = RARP_NVB_DINO_Wrapper(
            student_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )

        # The feature teacher is EMA-updated only, never by the optimizer.
        for frozen in self.teacher_Features.parameters():
            frozen.requires_grad = False

class RARP_NVB_Classification_Head(torch.nn.Module):
    """Configurable classification head: a single linear layer, or an MLP
    with activation + dropout after every hidden layer.

    Args:
        in_features: size of the incoming feature vector.
        out_features: number of output logits.
        layer: hidden layer widths; empty/None means a plain
            ``Linear(in_features, out_features)``.
            (Fixed: the default was a shared mutable ``[]``; it is now a
            ``None`` sentinel resolved inside the constructor.)
        activation_fn: activation inserted after each hidden linear layer.
    """

    def __init__(self, in_features:int, out_features:int, layer:list=None, activation_fn:torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = activation_fn
        layer = [] if layer is None else layer

        if len(layer) == 0:
            self.head = torch.nn.Linear(in_features, out_features)
        else:
            temp_head = []
            next_input = in_features
            for num in layer:
                temp_head.append(torch.nn.Linear(next_input, num))
                temp_head.append(self.activation)
                temp_head.append(torch.nn.Dropout(0.4))
                next_input = num

            # Lighter dropout right before the output layer.
            temp_head[-1] = torch.nn.Dropout(0.2)
            temp_head.append(torch.nn.Linear(next_input, out_features))

            self.head = torch.nn.Sequential(*temp_head)

    def forward(self, x):
        """Run the head; returns raw logits."""
        return self.head(x)

class RARP_Encoder_DINO(L.LightningModule):
    """Self-supervised DINO pre-training of a VAN-B1 encoder.

    A student and a teacher share the same architecture (VAN-B1 backbone plus
    a DINO projection head).  In training the teacher receives the first
    three crops and the student all crops; after every training batch the
    teacher weights are moved toward the student by an exponential moving
    average with momentum ``momentum_teacher``.
    """

    def __init__(self, 
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        Teacher_T:float = 0.04,
        Student_T:float = 0.1,
        max_epochs:int = 100,
        total_steps:int = None
    ) -> None:
        """Args:
            momentum_teacher: EMA momentum of the teacher update.
            lr: AdamW learning rate (student parameters only).
            Teacher_T / Student_T: softmax temperatures for the DINO loss.
            max_epochs, total_steps: kept in hparams for schedule experiments;
                currently unused at runtime.
        """
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        self.momentum_teacher = momentum_teacher
        self.out_dim = 65536   # DINO prototype count
        self.in_dim = 512      # VAN-B1 embedding size

        # Randomly initialized backbones (num_classes=0 removes the head);
        # construction order preserved so RNG consumption matches checkpoints.
        student_backbone = van.van_b1(num_classes=0)
        teacher_backbone = van.van_b1(num_classes=0)

        self.student = RARP_NVB_DINO_Wrapper(
            student_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True)
        )
        self.teacher = RARP_NVB_DINO_Wrapper(
            teacher_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True)
        )

        # Teacher starts as an exact copy of the student and is never trained
        # by the optimizer -- only via the EMA update in on_train_batch_end.
        self.teacher.load_state_dict(self.student.state_dict())
        for parms in self.teacher.parameters():
            parms.requires_grad = False

        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)

    def forward(self, data, val_step=False):
        """Return ``(teacher_features, student_features)``.

        In validation the same single view feeds both networks; in training
        the teacher sees the first three (global) crops and the student all
        crops.
        """
        if val_step:
            dataTeacher, dataStudent = data.float(), data.float()
        else:
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[0:3], data

        return self.teacher(dataTeacher), self.student(dataStudent)

    def _shared_step(self, batch, val_step=False):
        """Compute the DINO loss for one batch (labels are ignored)."""
        img, _ = batch
        t, s = self(img, val_step)
        return self.lossFN_DINO(s, t, validation_step=val_step)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch, False)
        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """EMA update of the teacher from the student after every batch."""
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1 - self.momentum_teacher) * student_ps.detach().data)

    def on_after_backward(self):
        """Log the global L2 gradient norm and flag suspected vanishing
        gradients with a numeric marker."""
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm(2).item() ** 2
        total_norm = total_norm ** 0.5

        self.log("grad_norm", total_norm)

        # Fixed: LightningModule.log only accepts numeric/tensor values --
        # logging the previous warning *string* here raised at runtime.
        if total_norm < 1e-8:
            self.log("grad_warning", 1.0)

    def configure_optimizers(self):
        """AdamW over the student only; the teacher is updated by EMA."""
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr)
        return [optimizer]



class RARP_NVB_DINO_MultiTask(L.LightningModule):
    """DINO-style teacher/student multi-task model (VAN-b2 backbones).

    Jointly optimizes three objectives whose relative weights are re-balanced
    by SoftAdapt every other epoch:
      * DINO self-distillation between an EMA teacher and the student,
      * reconstruction of the input image from the concatenated
        student+teacher block-4 feature maps via ``DynamicDecoder``,
      * binary classification from the pooled concatenated features.

    Bug fix vs. the previous revision: ``on_after_backward`` no longer passes
    a string to ``self.log`` (Lightning only accepts numeric values/Metrics).
    """

    # --- forward hooks: keep the latest block-4 feature map of each network ---
    def _hook_fn_Student(self, module, input, output):
        self.last_conv_output_S = output
        
    def _hook_fn_Teacher(self, module, input, output):
        self.last_conv_output_T = output
    
    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
        std: float = None,
        mean: float = None,
        SoftAdptAlgo:int = 0,
        SoftAdptBeta:float = 0.1,
        Teacher_T:float = 0.04,
        Student_T:float = 0.1,
        intermittent:bool = False
    ) -> None:
        """Args:
            TypeLoss: classification loss selector (CrossEntropy or BCE-with-logits).
            momentum_teacher: EMA momentum for the teacher update.
            lr: AdamW learning rate.
            L1: optional L1 penalty weight applied to the classifier (None = off).
            L2: optional L2 penalty weight applied to the classifier (0 = off).
            std / mean: per-channel image stats, used only to de-normalize
                reconstructions for TensorBoard previews (None = no previews).
            SoftAdptAlgo: 1 -> NormalizedSoftAdapt, otherwise LossWeightedSoftAdapt.
            SoftAdptBeta: SoftAdapt beta hyper-parameter.
            Teacher_T / Student_T: DINO softmax temperatures.
            intermittent: alternate training of backbone vs. decoder+classifier
                per epoch (see ``on_train_epoch_start``).
        """
        super().__init__()
        
        self.intermittent_train = intermittent
        
        # Channel-wise stats kept only for de-normalizing preview images.
        self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
        self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
    
        self.lr =  lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.Teacher_t = Teacher_T
        self.Studet_T = Student_T  # NOTE(review): typo ("Studet") kept so existing checkpoints keep loading
        self.momentum_teacher = momentum_teacher
        self.out_dim = 1024  # DINO projection head output dim
        self.in_dim = 512    # VAN-b2 feature dim
        self.weights = torch.tensor([1,1,1])  # [DINO, reconstruction, binary]
        
        self.softAdapt = NormalizedSoftAdapt(SoftAdptBeta) if SoftAdptAlgo == 1 else LossWeightedSoftAdapt(SoftAdptBeta)
        # Per-step loss values accumulated for SoftAdapt re-weighting.
        self.loss_history = {
            'loss_DINO': [], 
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        
        self.student = van.van_b2(pretrained = True, num_classes = 0)
        self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0)
        
        # 1024 = student (512) + teacher (512) concatenated feature channels.
        self.decoder = DynamicDecoder(input_channels=1024) 
               
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        
        # The teacher starts as an exact copy of the student and is never
        # trained by gradients -- only by the EMA update in on_train_batch_end.
        self.teacher_Features.load_state_dict(self.student.state_dict())
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.ReconstructionLoss = torch.nn.MSELoss()
        
        self.last_conv_output_T = None
        self.last_conv_output_S = None
        
        # Hooks populate last_conv_output_S / _T on every forward pass.
        self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher)
        self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
        
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 128),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.4),
            
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.2),
            
            torch.nn.Linear(8, 1)
        )
                
        print(f"lr={self.lr}, L1={self.Lambda_L1}")
        
    def _denormalize(self, tensor:torch.Tensor):
        """Undo dataset normalization so images can be rendered for logging."""
        # Move mean and std to the same device as the input tensor.
        mean = self.mean_IMG.to(tensor.device)
        std = self.std_IMG.to(tensor.device)
        return tensor * std + mean
    
    def _calc_L1(self, params):
        """L1 penalty (sum of absolute values) scaled by Lambda_L1."""
        l1 = 0
        for p in params:
            l1 += torch.sum(torch.abs(p))
        return self.Lambda_L1 * l1
    
    def _calc_weights(self, log_weights:bool = True):
        """Ask SoftAdapt for fresh component weights, then reset the history.

        The last history sample is dropped when the length is even --
        presumably SoftAdapt wants an odd number of points; TODO confirm.
        """
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]),
            torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )

        if log_weights:
            self.log("W_loss_img", self.weights[1], on_epoch=True, on_step=False)
            self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[2], on_epoch=True, on_step=False)

        self.loss_history = {
            'loss_DINO': [], 
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }
        
    def forward(self, data, val_step:bool = True):
        """Run teacher + student plus the reconstruction/classification heads.

        In validation the same full image goes to both networks.  In training
        ``data`` is a multi-crop list: the teacher sees the two global views
        (data[1:3]) while the student sees every crop; the student's hooked
        feature map is then reduced to the two global views so it can be
        concatenated channel-wise with the teacher's.
        """
        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[1:3], data

        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)
                       
        if not val_step:
            # Keep only the student's global views (crops 1 and 2).
            NumChunks = len(dataStudent)
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
        
        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
        
        reconstructed_image = self.decoder(cat_features)
        
        # Global-average-pool the concatenated maps for the classifier.
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)
                        
        return pred, (Student, TeacherDino), reconstructed_image
        
    def _shared_step(self, batch, val_step:bool = False):
        """Compute the weighted multi-task loss for one batch.

        Returns:
            (total_loss, labels, sigmoid_predictions,
             (weighted_DINO, weighted_binary, weighted_reconstruction, reconstruction))
        """
        img, label = batch
        
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        
        # Single-logit heads produce shape [B, 1]; flatten to [B].
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, (NOAH, RARP_NVB_Classification_Head)):
            prediction = prediction.flatten()        

        predicted_labels = torch.sigmoid(prediction)
        
        # In training, targets are tiled once per teacher view.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        
        # DINO loss (training only; zero during validation)
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        # Classification loss
        loss_HL = self.lossFN(prediction, label)
        # Reconstruction loss
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())
                
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            
            # Record raw (unweighted) values for the next SoftAdapt update.
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
              
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)
        
    def on_train_epoch_start(self):
        """Refresh SoftAdapt weights every other epoch; optionally alternate
        which sub-networks train when ``intermittent`` is enabled."""
        if self.current_epoch % 2 == 0 and self.current_epoch != 0:
            self._calc_weights()
            
        if self.intermittent_train and self.current_epoch != 0:
            
            # Even epochs: train the backbone; odd epochs: decoder + classifier.
            par_epoch = (self.current_epoch % 2 == 0)
            
            for parms in self.student.backbone.parameters():
                parms.requires_grad = par_epoch
                
            for parms in self.decoder.parameters():
                parms.requires_grad = not par_epoch
                
            for parms in self.clasiffier.parameters():
                parms.requires_grad = not par_epoch
                
    
    def training_step(self, batch, batch_idx):
        """One optimization step + metric logging + periodic image previews."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("train_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)
        
        # Every 50 batches, push a de-normalized BGR->RGB reconstruction grid
        # to the logger (requires mean/std stats and a TensorBoard-style logger).
        if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)

        return loss
    
    def on_train_batch_end(self, outputs, batch, batch_idx):
        """EMA-update the frozen teacher from the student after every batch."""
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
    
    def on_after_backward(self):
        """Log the global L2 gradient norm; flag suspected vanishing gradients.

        Bug fix: this used to pass a *string* to ``self.log``, which Lightning
        rejects (only numeric values/Metrics can be logged); a numeric flag is
        logged instead.
        """
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        
        self.log("grad_norm", total_norm)
        
        if total_norm < 1e-8:
            self.log("grad_warning", 1.0)  # vanishing gradient suspected
                
    def validation_step(self, batch, batch_idx):
        """Validation loss + accuracy (the DINO term is zero here)."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, True)
        
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        
        self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("val_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)
        
    def test_step(self, batch, batch_idx):
        """Test accuracy/F1 and a side-by-side original/reconstruction grid."""
        _, true_labels, predicted_labels, losses = self._shared_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
        
        if self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            
            imgOrig = torch.clip(self._denormalize(batch[0])/255, 0, 1)
            imgOrig = imgOrig[:, [2, 1, 0], :, :]
            
            imgReconstruction = torch.cat((imgOrig, imgReconstruction), dim=0)
            
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
    
    def configure_optimizers(self):
        """Single AdamW over *all* parameters (teacher params are frozen)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)  #, weight_decay=self.Lambda_L2
        
        return [optimizer]
    
class RARP_Hybrid_TS_LR(torch.nn.Module):
    """Left/right wrapper around a trained RARP_NVB_DINO_MultiTask checkpoint.

    Splits each input frame down the middle, feeds the left and the right half
    (either cropped, or zero-masked back to full width) through the frozen base
    model, and returns the two raw predictions concatenated on the last dim.
    """

    def __init__(
        self, 
        base_TS_Model:str = "",
        std: float = None,
        mean: float = None,
        stretch: bool = False,
        masked: bool = False
    ):
        super().__init__()
        self.mean = mean
        self.std = std
        self.stretch = stretch
        self.masked = masked

        # Set on every forward pass from the incoming width (W // 2).
        self.mid_length = 0

        self.labels = ["L", "R"]

        # Frozen pre-trained teacher/student model used for both halves.
        self.baseModel = RARP_NVB_DINO_MultiTask.load_from_checkpoint(base_TS_Model)
        self.baseModel.eval()

    def _mask_LR(self, image:torch.Tensor, Left:bool= True):
        """Keep one half of the frame and zero-pad the other back to full width."""
        kept = image[..., :self.mid_length] if Left else image[..., self.mid_length:]
        padding = torch.zeros_like(kept)
        halves = (kept, padding) if Left else (padding, kept)
        return torch.cat(halves, dim=-1)

    def _crop_LR(self, image:torch.Tensor, Left:bool = True):
        """Crop one half; optionally stretch it back to 224x224 with bicubic."""
        half = image[..., :self.mid_length] if Left else image[..., self.mid_length:]
        if not self.stretch:
            return half
        return torch.nn.functional.interpolate(
            half,
            size=(224, 224),
            mode='bicubic',
            align_corners=False
        )

    def forward(self, x):
        _, _, _, width = x.shape  # [B, C, H, W]
        self.mid_length = width // 2

        split = self._mask_LR if self.masked else self._crop_LR
        halves = {"L": split(x, True), "R": split(x, False)}

        predictions = []
        for side in self.labels:
            with torch.no_grad():
                raw_pred, _, _ = self.baseModel(halves[side])
                predictions.append(raw_pred)

        return torch.cat(predictions, dim=-1)
                

#Ablation Models
"""T-S Multi-task model With out Recostruccion (V3R1_A1)

Returns:
    LightningModule
"""
class RARP_NVB_DINO_MultiTask_A1(RARP_NVB_DINO_MultiTask):
    """Ablation A1: the reconstruction branch is removed.

    The decoder is replaced by an identity module, so only the DINO loss and
    the binary-classification loss are trained (weights still balanced by
    SoftAdapt, now over two components instead of three).
    """
    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
        
        # Disable reconstruction: identity "decoder", no reconstruction loss.
        self.decoder = torch.nn.Identity()
        del self.ReconstructionLoss
        
        # No image stats -> the base class skips reconstruction previews.
        self.mean_IMG = None
        self.std_IMG = None
        
        # Only two loss components remain: [DINO, binary].
        self.weights = torch.tensor([1,1])
        self.loss_history = {
            'loss_DINO': [], 
            'loss_Binary': [],
        }
    
    def _calc_weights(self, log_weights:bool = True):
        """Recompute the two SoftAdapt weights, log them, reset the history.

        The last history sample is dropped when the length is even --
        presumably SoftAdapt wants an odd number of points; TODO confirm.
        """
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )

        if log_weights:
            # Reconstruction weight logged as 0 to keep dashboards aligned
            # with the base model's three weight curves.
            self.log("W_loss_img", 0, on_epoch=True, on_step=False)
            self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False)

        self.loss_history = {
            'loss_DINO': [], 
            'loss_Binary': [],
        }
        
    def _shared_step(self, batch, val_step:bool = False):
        """DINO + classification loss only (no reconstruction term).

        Returns the same 4-tuple layout as the base class, with 0 in the
        reconstruction slot so the base-class logging keeps working.
        """
        img, label = batch
        
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        
        # Single-logit heads produce shape [B, 1]; flatten to [B].
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, NOAH):
            prediction = prediction.flatten()        

        predicted_labels = torch.sigmoid(prediction)
        
        #orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        # In training, targets are tiled once per teacher view.
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        
        # DINO loss (training only; zero during validation)
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        # Classification loss
        loss_HL = self.lossFN(prediction, label)
        #loss_img = self.ReconstructionLoss(new_image, orignalImg)
        #loss_img = loss_img.float()

        if not val_step:
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())
                
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            
            # Record raw (unweighted) values for the next SoftAdapt update.
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
              
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_HL
        
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[1] * loss_HL, 0, new_image)
        
"""T-S Multi-task model With out SoftAdadapt, Fix loss wegth 0.333_  (V3R1_A2)

Returns:
    LightningModule
"""
class RARP_NVB_DINO_MultiTask_A2(RARP_NVB_DINO_MultiTask):
    """Ablation A2: equal fixed loss weights (1/3 each) instead of SoftAdapt."""

    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)

        # SoftAdapt is unused in this ablation; drop it and pin equal weights.
        del self.softAdapt
        self.weights = [1/3] * 3

    def _calc_weights(self):
        """Keep the weights fixed and clear the accumulated loss history."""
        self.weights = [1/3] * 3
        self.loss_history = {key: [] for key in ('loss_DINO', 'loss_Reconstruction', 'loss_Binary')}

"""S Multi-task model No Dino base encoder VAN_b2, (V3R1_A3_1)

Returns:
    LightningModule
"""
class RARP_NVB_DINO_MultiTask_A3(RARP_NVB_DINO_MultiTask):
    """Ablation A3_1: student-only model (no DINO loss), VAN-b2 encoder.

    The teacher is reduced to identity stubs and the student hook fabricates an
    all-zero "teacher" feature map, so the base-class pipeline (feature concat
    -> decoder -> classifier) runs unchanged.  Only the reconstruction and
    binary-classification losses are optimized.

    Bug fix vs. the previous revision: ``_calc_weights`` no longer indexes the
    non-existent 'loss_DINO' history key (KeyError on odd-length histories).
    """
    def _hook_fn_Student(self, module, input, output):
        """Capture student features and fake a zero teacher map of equal shape."""
        self.last_conv_output_S = output
        self.last_conv_output_T = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
        if not self.val_phace:
            self.last_conv_output_T = self.last_conv_output_T[:16] # 2 global views x batch of 8 -- NOTE(review): hard-coded batch size, confirm
        
    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
               
        # Student keeps the VAN-b2 backbone; the DINO projection head is dropped.
        self.student = RARP_NVB_DINO_Wrapper(
            van.van_b2(pretrained = True, num_classes = 0),
            torch.nn.Identity()
        )
        
        # Teacher is a stub; its "features" are the zeros produced by the hook.
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            torch.nn.Identity(),
            torch.nn.Identity()
        )
        
        self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
        
        # True while validating; set by _shared_step, read by the hook above.
        self.val_phace = True
                
        # Two loss components remain: [reconstruction, binary].
        self.weights = torch.tensor([1,1])
        self.loss_history = {
            'loss_Reconstruction': [], 
            'loss_Binary': [],
        }
    
    def _calc_weights(self, log_weights:bool = True):
        """Recompute the two SoftAdapt weights, log them, reset the history.

        Bug fix: the odd-length fallback previously read the non-existent
        'loss_DINO' key and raised KeyError; it now uses 'loss_Reconstruction'.
        """
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )

        if log_weights:
            # DINO weight logged as 0 to keep dashboards aligned with the base model.
            self.log("W_loss_img", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_DINO", 0, on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False)

        self.loss_history = {
            'loss_Reconstruction': [], 
            'loss_Binary': [],
        }
        
    def _shared_step(self, batch, val_step:bool = False):
        """Reconstruction + classification loss (no DINO term).

        Returns the same 4-tuple layout as the base class, with 0 in the
        DINO slot so the base-class logging keeps working.
        """
        self.val_phace = val_step
        
        img, label = batch
                
        prediction, features, new_image = self(img, val_step)
        _, TeacherF = features
        
        # Single-logit heads produce shape [B, 1]; flatten to [B].
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, NOAH):
            prediction = prediction.flatten()        

        predicted_labels = torch.sigmoid(prediction)
        
        # In training, targets are tiled once per teacher view.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        
        # Classification loss
        loss_HL = self.lossFN(prediction, label)
        # Reconstruction loss
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())
                
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            
            # Record raw (unweighted) values for the next SoftAdapt update.
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
              
        loss = self.weights[0] * loss_img + self.weights[1] * loss_HL
        
        return loss, label, predicted_labels, (0, self.weights[1] * loss_HL, self.weights[0] * loss_img, new_image)
        
"""S Multi-task model No Dino base encoder RN50, (V3R1_A3_2)

Returns:
    LightningModule
"""
class RARP_NVB_DINO_MultiTask_A3_RN50(RARP_NVB_DINO_MultiTask):
    """Ablation A3_2: student-only model (no DINO loss), ResNet-50 encoder.

    The teacher is reduced to identity stubs; the student hook fabricates a
    zero-channel "teacher" map so the base-class channel concat passes only
    the 2048 student channels to the decoder and classifier.

    Bug fix vs. the previous revision: ``_calc_weights`` no longer indexes the
    non-existent 'loss_DINO' history key (KeyError on odd-length histories).
    """
    def _hook_fn_Student(self, module, input, output):
        """Capture student features; fake a 0-channel teacher map."""
        self.last_conv_output_S = output
        self.last_conv_output_T = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
        if not self.val_phace:
            self.last_conv_output_T = self.last_conv_output_T[:16] # 2 global views x batch of 8 -- NOTE(review): hard-coded batch size, confirm
            
        # Zero channels: the concat in forward() then yields student-only features.
        self.last_conv_output_T = self.last_conv_output_T[:, :0, :, :]
        
    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
               
        # ImageNet-pretrained ResNet-50 backbone; no DINO projection head.
        self.student = RARP_NVB_DINO_Wrapper(
            torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT),
            torch.nn.Identity()
        )
        
        # Teacher is a stub; its "features" come from the hook above.
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            torch.nn.Identity(),
            torch.nn.Identity()
        )
        
        self.student.backbone.layer4.register_forward_hook(self._hook_fn_Student)
        
        # True while validating; set by _shared_step, read by the hook above.
        self.val_phace = True
                
        # Two loss components remain: [reconstruction, binary].
        self.weights = torch.tensor([1,1])
        self.loss_history = {
            'loss_Reconstruction': [], 
            'loss_Binary': [],
        }
        
        # ResNet-50 layer4 emits 2048 channels (vs. 1024 in the base model).
        self.decoder = DynamicDecoder(input_channels=2048)
        
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(2048, 128),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.4),
            
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.2),
            
            torch.nn.Linear(8, 1)
        )
    
    def _calc_weights(self, log_weights:bool = True):
        """Recompute the two SoftAdapt weights, log them, reset the history.

        Bug fix: the odd-length fallback previously read the non-existent
        'loss_DINO' key and raised KeyError; it now uses 'loss_Reconstruction'.
        """
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )

        if log_weights:
            # DINO weight logged as 0 to keep dashboards aligned with the base model.
            self.log("W_loss_img", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_DINO", 0, on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False)

        self.loss_history = {
            'loss_Reconstruction': [], 
            'loss_Binary': [],
        }
        
    def _shared_step(self, batch, val_step:bool = False):
        """Reconstruction + classification loss (no DINO term).

        Returns the same 4-tuple layout as the base class, with 0 in the
        DINO slot so the base-class logging keeps working.
        """
        self.val_phace = val_step
        
        img, label = batch
                
        prediction, features, new_image = self(img, val_step)
        _, TeacherF = features
        
        # Single-logit heads produce shape [B, 1]; flatten to [B].
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, NOAH):
            prediction = prediction.flatten()        

        predicted_labels = torch.sigmoid(prediction)
        
        # In training, targets are tiled once per teacher view.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        
        # Classification loss
        loss_HL = self.lossFN(prediction, label)
        # Reconstruction loss
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())
                
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            
            # Record raw (unweighted) values for the next SoftAdapt update.
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
              
        loss = self.weights[0] * loss_img + self.weights[1] * loss_HL
        
        return loss, label, predicted_labels, (0, self.weights[1] * loss_HL, self.weights[0] * loss_img, new_image)

"""T-S Multi-task model, classification head layer change, (V3R1_A4)

Returns:
    LightningModule
"""       
class RARP_NVB_DINO_MultiTask_A4(RARP_NVB_DINO_MultiTask):
    """Ablation A4: swap the Sequential classifier for a configurable head."""

    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
        # Variants tried: [128, 8] (L3, original), [256, 128, 8] (L4), [8] (L2).
        hidden_sizes = []  # L1: single linear layer, no hidden layers
        self.clasiffier = RARP_NVB_Classification_Head(1024, 1, hidden_sizes, torch.nn.SiLU(True))
 
class DWConvBlock(torch.nn.Module):
    """Depthwise-separable conv block.

    3x3 depthwise conv -> BN -> ReLU -> 1x1 pointwise conv -> BN -> ReLU,
    optionally followed by spatial dropout (Dropout2d) when p_drop > 0.
    Spatial resolution is preserved; channels go in_c -> out_c.
    """

    def __init__(self, in_c, out_c, p_drop=0.1):
        super().__init__()
        layers = [
            torch.nn.Conv2d(in_c, in_c, 3, padding=1, groups=in_c, bias=False),  # depthwise 3x3
            torch.nn.BatchNorm2d(in_c),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_c, out_c, 1, bias=False),                         # pointwise 1x1
            torch.nn.BatchNorm2d(out_c),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout2d(p_drop) if p_drop > 0 else torch.nn.Identity(),
        ]
        self.block = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class GatedSkip(torch.nn.Module):
    """Learnable scalar gate for a skip connection.

    Computes ``x + sigmoid(alpha) * skip`` where ``alpha`` is a learnable
    scalar, so the effective gate lives in (0, 1); pick init_alpha < 1 to
    avoid over-reliance on the skip early in training.
    """

    def __init__(self, init_alpha=0.25):
        super().__init__()
        self.alpha = torch.nn.Parameter(torch.tensor(init_alpha).float())

    def forward(self, x, skip):
        gate = torch.sigmoid(self.alpha)
        return x + gate * skip

class ReconDecoderLite(torch.nn.Module):
    """
    Lightweight reconstruction decoder over an encoder feature pyramid.

    Expects feature maps ordered low-to-high resolution: [C2, C3, C4, C5],
    e.g. C2:[B,64,56,56], C3:[B,128,28,28], C4:[B,320,14,14], C5:[B,512|1024,7,7];
    set ``in_channels`` accordingly.  Each stage upsamples, optionally mixes in
    a gated encoder skip, and refines with a depthwise-separable block.
    """
    def __init__(self, in_channels=[64,128,320,512], base=96, out_channels=3,
                 p_drop=0.1, use_skips=(True, True, False, False)):
        super().__init__()
        c2, c3, c4, c5 = in_channels

        # 1x1 projections squeeze encoder channels to a small uniform width.
        self.proj5 = torch.nn.Conv2d(c5, base*4, 1, bias=False)
        self.proj4 = torch.nn.Conv2d(c4, base*4, 1, bias=False)
        self.proj3 = torch.nn.Conv2d(c3, base*2, 1, bias=False)
        self.proj2 = torch.nn.Conv2d(c2, base,   1, bias=False)

        # Refinement stages; upsampling happens in forward().
        self.up54 = DWConvBlock(base*4, base*4, p_drop)
        self.up43 = DWConvBlock(base*4, base*2, p_drop)
        self.up32 = DWConvBlock(base*2, base,   p_drop)
        self.up21 = DWConvBlock(base,   base//2, p_drop)

        # Gated skips, kept light to fight overfitting; use_skips is read
        # back-to-front (index 3 gates the C5->C4 stage).
        self.use_skips = use_skips
        self.g54 = GatedSkip(0.15) if use_skips[3] else None  # C5→C4 (usually off)
        self.g43 = GatedSkip(0.25) if use_skips[2] else None
        self.g32 = GatedSkip(0.35) if use_skips[1] else None
        self.g21 = GatedSkip(0.50) if use_skips[0] else None  # NOTE(review): created but never used in forward

        self.head = torch.nn.Sequential(
            torch.nn.Conv2d(base//2, base//2, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(base//2, out_channels, 1)
        )

    def forward(self, feats):
        """Decode [C2, C3, C4, C5] (low-to-high resolution) into an image."""
        c2, c3, c4, c5 = feats

        def _match(t, ref):
            # Upsample t to ref's spatial size (exact nearest-neighbour).
            return torch.nn.functional.interpolate(t, size=ref.shape[-2:], mode='nearest-exact')

        x = self.up54(self.proj5(c5))               # refine at the coarsest scale

        x = _match(x, c4)
        if self.g54 is not None:
            x = self.g54(x, self.proj4(c4))
        x = self.up43(x)

        x = _match(x, c3)
        if self.g43 is not None:
            x = self.g43(x, self.proj3(c3))
        x = self.up32(x)

        x = _match(x, c2)
        if self.g32 is not None:
            x = self.g32(x, self.proj2(c2))
        x = self.up21(x)

        # Output stays at C2 resolution; upsample outside if a larger target is needed.
        return self.head(x)
       
class RARP_NVB_DINO_MultiTask_A5_MAE(RARP_NVB_DINO_MultiTask):
    """Multi-task DINO student/teacher model extended with an MAE-style
    image-reconstruction branch (ablation A5).

    Forward hooks capture intermediate student-backbone feature maps; a
    `ReconDecoderLite` decodes them (plus the concatenated student/teacher
    final features) back into the input image.  Training combines the DINO
    loss, the binary-classification loss and the reconstruction loss using
    the SoftAdapt weights maintained by the parent class.
    """

    def _encoder_hool_fn(self, module, input, output):
        # Forward hook: stash the hooked block's output feature map so
        # `forward` can hand the multi-scale pyramid to the decoder.
        self.feature_maps.append(output)
    
    def _register_encoder_hooks(self, block_list:list):
        # Attach the capture hook to every listed backbone block; handles are
        # kept in self.hooks so they could later be removed.
        for layer in block_list:
            self.hooks.append(layer.register_forward_hook(self._encoder_hool_fn))
    
    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        """See the parent class for the shared hyper-parameters; this variant
        only adds the reconstruction decoder and its bookkeeping."""
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
        
        # Epoch at which the decoder stops training: its LR is zeroed in
        # configure_optimizers and the reconstruction loss is detached.
        self.E_FREEZE = 100
        # Optional L1 penalty on decoder weights; disabled by default.
        self.Lambda_L1_Decoder = None #1e-6 
                
        self.ReconstructionLoss = ReconstructionLoss(1, 1e-4)
        
        self.hooks = []
        self.feature_maps = []
        
        # Last block of each of the first three student backbone stages;
        # their outputs feed the decoder as a feature pyramid.
        self.list_blocks = [
            self.student.backbone.block1[-1],
            self.student.backbone.block2[-1],
            self.student.backbone.block3[-1],
        ]
        
        self._register_encoder_hooks(self.list_blocks)
        
        # [64, 128, 320, 1024]: presumably the channel counts of the three
        # hooked stages plus the concatenated final features — confirm
        # against ReconDecoderLite's signature.
        self.decoder = ReconDecoderLite([64, 128, 320, 1024], use_skips=(True, True, False, False))
        
    def _is_norm(self, param_name):
        # crude but effective: catches BatchNorm/LayerNorm/GroupNorm/InstanceNorm weight/bias
        return any(k in param_name.lower() for k in ["norm", "bn", "running_", "bias"])
    
    def _calc_L1_2D(self, params:torch.nn.Module):
        """L1 penalty over a module's conv/linear weights, skipping frozen
        parameters, norm layers and biases; scaled by Lambda_L1_Decoder."""
        l1 = 0.0
        for name, p in params.named_parameters():
            if not p.requires_grad:
                continue
            if self._is_norm(name):   # skip BN/LN/bias
                continue
            l1 = l1 + p.abs().sum()
            
        return self.Lambda_L1_Decoder * l1
        
    def _shared_step(self, batch, val_step:bool = False):
        """Compute the three task losses for one batch.

        Returns (total_loss, labels, sigmoid_probs,
        (weighted_dino, weighted_binary, weighted_recon, recon_image)).
        """
        img, label = batch
        
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        
        # single-logit heads produce [B, 1]; flatten to [B] for the BCE loss
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, (NOAH, RARP_NVB_Classification_Head)):
            prediction = prediction.flatten()        

        predicted_labels = torch.sigmoid(prediction)
        
        # During training the student sees multiple crops, so targets and the
        # reconstruction reference are tiled to match.
        # NOTE(review): img[0] is presumably the first (global) crop — confirm.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        
        #DINO Loss (training only; validation has no multi-crop views)
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        #Clasificator
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction (detached once the decoder is frozen so it is only logged)
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float().detach() if not val_step and self.current_epoch >= self.E_FREEZE else loss_img.float()
        

        if not val_step:
            # optional L1/L2 regularization on classifier / decoder weights
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())
                
            if self.Lambda_L1_Decoder is not None:
                loss_img += self._calc_L1_2D(self.decoder)
                
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            
            # history feeds the SoftAdapt weight update in on_train_epoch_start
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
              
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)
    
    def on_fit_start(self):
        """When resuming from a checkpoint, reset EarlyStopping state and mark
        the SoftAdapt bookkeeping for re-initialization."""
        chpt_file = getattr(self.trainer, "ckpt_path", None)
        
        if chpt_file is not None:
            self.re_init = True
            for cb in self.trainer.callbacks:
                if isinstance(cb, callbk.EarlyStopping):
                    # NOTE(review): best_score = +inf assumes the monitored
                    # metric is minimized — confirm the callback's mode.
                    cb.wait_count = 0
                    cb.best_score = torch.tensor(float("inf"), device=self.device)
                    cb.stopped_epoch = 0
                    cb._should_stop = False
                    print("✅ EarlyStopping reset!")
    
    def on_train_epoch_start(self):
        """Refresh the SoftAdapt loss weights every other epoch; after a
        resume, clear the loss history and re-initialize the classifier
        instead.

        NOTE(review): self.re_init is only assigned here and in on_fit_start
        when a checkpoint path exists — presumably the parent class
        initializes it otherwise; confirm.
        """
        if self.current_epoch % 2 == 0 and self.current_epoch != 0 and not self.re_init:
            self._calc_weights()
            
        if self.re_init:
            self.re_init = False
            print ("re-init SoftAdapt loss history")
            self.loss_history = {
                'loss_DINO': [], 
                'loss_Reconstruction': [],
                'loss_Binary': [],
            }
            print ("re-init Classifier")
            for m in self.clasiffier.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)
            
    #    freeze = self.current_epoch < self.E_FREEZE
    #    
    #    for p in self.decoder.parameters():
    #        p.requires_grad = freeze
            
    
            
    def forward(self, data, val_step:bool = True):
        """Run teacher + student, assemble the decoder's feature pyramid and
        return (logits, (student_out, teacher_out), reconstructed_image).

        val_step=True: `data` is a plain image batch.
        val_step=False: `data` is the DINO multi-crop list; slices [1:3] are
        presumably the two global crops — confirm against the dataloader.
        """
        self.feature_maps = []   # repopulated by the encoder hooks below
                
        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[1:3], data

        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)
        
        _temp = []
        NumChunks = len(dataStudent)
        num_blocks = len(self.list_blocks)
                              
        if not val_step:    
            # keep only the global-crop slices of the hooked student features
            # so they line up with the teacher features batch-wise
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
            for i in range(num_blocks):
                S_GlogalViews = torch.cat(self.feature_maps[i].chunk(NumChunks)[1:3], dim=0)
                _temp.append(S_GlogalViews)
                
            self.feature_maps = _temp
        
        # student + teacher final maps concatenated on channels; this is both
        # the decoder's deepest input and the classifier's input
        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
        
        self.feature_maps.append(cat_features)
        
        # once the decoder is frozen, reconstruct without building its graph
        if not val_step and self.current_epoch >= self.E_FREEZE: 
            with torch.no_grad():            
                reconstructed_image = self.decoder(self.feature_maps)
        else:            
            reconstructed_image = self.decoder(self.feature_maps)
            
        reconstructed_image = torch.nn.functional.interpolate(reconstructed_image, size=(224, 224), mode="nearest-exact")
        
        # classifier consumes the globally pooled concatenated features
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)
                        
        return pred, (Student, TeacherDino), reconstructed_image
    
    def configure_optimizers(self):
        """AdamW over student/classifier/decoder with a per-group schedule
        that zeroes the decoder's LR from epoch E_FREEZE onward."""
        opt = torch.optim.AdamW([
            {"params": self.student.parameters(),"lr": self.lr},
            {"params": self.clasiffier.parameters(), "lr": self.lr},
            {"params": self.decoder.parameters(), "lr": self.lr},
        ])

        # Lambda per param group: keep last one (decoder) at 0 after E_FREEZE
        def lam_enc(epoch):  return 1.0
        def lam_cls(epoch):  return 1.0
        def lam_dec(epoch):  return 1.0 if epoch < self.E_FREEZE else 0.0

        sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=[lam_enc, lam_cls, lam_dec])
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sch, "interval": "epoch"}}

        
        

class Scalar_AttnPooling(torch.nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        
        self.proj = torch.nn.Linear(1, hidden_dim)
        self.attn = torch.nn.Linear(hidden_dim, 1)
        self.pooler_cls = torch.nn.Linear(1, 1)
        
    def forward(self, logits_bt):
        x = logits_bt.unsqueeze(-1)     #[B, T, 1]
        
        h = torch.tanh(self.proj(x))    #[B, T, hidden_dim]
        a = self.attn(h).squeeze(-1)    #[B, T]
        
        w = torch.nn.functional.softmax(a, 1) # [B, T] Weights for each T, to idetnfy the best frame for classification 
        z = (w.unsqueeze(-1) * x).sum(dim=1) # [B, 1]
        z = self.pooler_cls(z)
        
        return z

class Chomp1d(torch.nn.Module):
    """Trim the trailing `chomp_size` time steps of a [B, C, T] tensor.

    Causal Conv1d layers pad symmetrically; chopping the right-hand overhang
    restores causality so the output length matches the input length.
    """
    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # x: [B, C, T]; no-op when no padding was added
        if self.chomp_size:
            return x[:, :, :-self.chomp_size]
        return x
    
class TemporalBlock(torch.nn.Module):
    """Residual TCN block: two dilated causal Conv1d stages plus a skip path.

    A 1×1 `downsample` conv aligns channel counts for the residual sum when
    in_channels != out_channels; otherwise the identity is used.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, padding, dropout=0.0):
        super().__init__()
        # first dilated conv + causal trim + activation + dropout
        self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
                                     stride=stride, padding=padding,
                                     dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(dropout)

        # second conv stack at the same dilation
        self.conv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size,
                                     stride=stride, padding=padding,
                                     dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(dropout)

        # channel-matching shortcut (None means identity)
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = torch.nn.Conv1d(in_channels, out_channels, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Map [B, in_channels, T] → [B, out_channels, T]."""
        y = self.dropout1(self.relu1(self.chomp1(self.conv1(x))))
        y = self.dropout2(self.relu2(self.chomp2(self.conv2(y))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(y + shortcut)
    
class TemporalConvNet(torch.nn.Module):
    """Stack of causal TemporalBlocks with exponentially growing dilation.

    Args:
        num_inputs: number of input channels (features).
        num_channels: output channels per level, e.g. [64, 64, 128].
        kernel_size: conv kernel size per block.
        dropout: dropout rate inside each block.
    """
    def __init__(self, num_inputs, num_channels, kernel_size=3, dropout=0.0):
        super().__init__()
        blocks = []
        prev_ch = num_inputs
        for level, ch in enumerate(num_channels):
            dil = 2 ** level
            # causal padding keeps the output length equal to T
            blocks.append(
                TemporalBlock(prev_ch, ch,
                              kernel_size=kernel_size,
                              stride=1,
                              dilation=dil,
                              padding=(kernel_size - 1) * dil,
                              dropout=dropout)
            )
            prev_ch = ch
        self.network = torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Map [B, T, C_in] → [B, T, C_out_last]."""
        # Conv1d wants channels-first; swap in, run the stack, swap back
        out = self.network(x.transpose(1, 2))
        return out.transpose(1, 2)

class Scalar_TCN(torch.nn.Module):
    def __init__(self, hidden_dim=64, layers=3):
        super().__init__()
        
        self.in_proj = torch.nn.Conv1d(1, hidden_dim, kernel_size=1)
        
        blocks = []
        for i in range(layers):
            d = 2 ** i
            blocks += [
                torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=d, dilation=d),
                torch.nn.ReLU(),
                torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1),
                torch.nn.ReLU(),
            ]
        
        self.tcn = torch.nn.Sequential(*blocks)
        self.head = torch.nn.Linear(hidden_dim, 1)
        
    def forward (self, logits_bt):
        x = logits_bt.unsqueeze(1)  #[B, 1, T] to do the Conv over the channel or the logits values
        x = self.in_proj(x)         #[B, hidden_dim, T]
        x = self.tcn(x)             #[B, hidden_dim, T]
        x = x.mean(dim=2)           #[B, hidden_dim] global average over T
        
        x = self.head(x)            #[B, 1]
        
        return x

class ModuleWrapper(torch.nn.Module):
    """Pass-through wrapper whose forward takes a mandatory `dummy_arg`.

    torch.utils.checkpoint needs at least one input that requires grad; the
    dummy tensor satisfies that without changing the wrapped module's
    computation (see its use in the A6 video model).
    """
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, dummy_arg=None):
        # the dummy must always be supplied; it is otherwise unused
        assert dummy_arg is not None
        return self.module(x)
        
"""T-S Multi-task Video input freez encoder (A0_A5)

Returns:
    LightningModule
"""  
class RARP_NVB_DINO_MultiTask_A5_Video(L.LightningModule):
    def __init__(
        self, 
        base_model_path = None,
        lr = 0.0001, 
        wd = 0.01,
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None,
        head_type:int = 0, #None = 0, linear = 1, Attn. Pooling = 2, TCN = 3, Replace Head =4
        chunks_loading:int = 50,
    ):
        super().__init__()
        
        self.lr = lr
        self.wd = wd
        self.chunks = chunks_loading
        self.head_type = head_type
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)
        
        if base_model_path is not None: 
            self.check_pt = False                       
            self.base_model = RARP_NVB_DINO_MultiTask.load_from_checkpoint(base_model_path)
            self.base_model.eval()
            
            for param in self.base_model.parameters():
                param.requires_grad = False
        else:
            self.check_pt = True
            self.base_model = van.van_b2(pretrained = True, num_classes = 0)
            self.head_type = 5
            num_features = 512
            
        self.base_model_wrapper = ModuleWrapper(self.base_model)
        
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
                
        match(self.head_type):
            case 1:
                #Linear
                self.head = torch.nn.Linear(600, 1)
            case 2:
                #Attn. pooling
                self.head = Scalar_AttnPooling(128)
            case 3:
                #TCN
                self.head = Scalar_TCN(64, 3)
            case 4:
                #replace the head of Base model for a new temporal head
                self.base_model.clasiffier = torch.nn.Identity()
                self.head = TemporalConvNet(1024, [128, 8, 1])
            case 5:
                self.head = TemporalConvNet(num_features, [128, 8, 1])
            case _:
                self.head = None    
    
    def _shared_video_step(self, batch:list[torch.Tensor], val_step:bool = False):
        video, label = batch
        B, T, C, H, W = video.shape
        video = video.float() #[B, T, C, H, W]
        label = label.float() #[B]
        
        chunk_T = self.chunks
        pred_bt = []
        
        def _fn(inp):
            pred = self.base_model(inp)
            return pred
        
        for t0 in tqdm(range(0, T, chunk_T), desc=f"Video Analysis in {chunk_T} chunk", leave=False):
            t1 = min(T, t0 + chunk_T)
            x = video[:, t0:t1].reshape(-1, C, H, W).contiguous(memory_format=torch.channels_last)
            
            if self.check_pt:
                pred = torch_ckp.checkpoint(_fn, x) if not val_step else _fn(x)
            else:
                pred, *_ = _fn(x)
            pred_bt.append(pred.view(B, t1-t0, -1))
            
        pred_video = torch.cat(pred_bt, dim=1)
        
        pred_video = pred_video.flatten(start_dim=1) if self.head_type in [0, 1, 2, 3] else pred_video
                
        pred_video = self(pred_video, val_step) #[B, 1]
        if self.head_type in [4, 5]:
            pred_video = pred_video.mean(dim=1)
        pred_video = pred_video.flatten() #[B] to match labels shape
                    
        predicted_labels = torch.sigmoid(pred_video)
        loss = self.lossFN(pred_video, label)
        
        return loss, label, predicted_labels
        
    def forward(self, data, val_step:bool = True):
        if self.head is None:
            pred_video = data.mean(dim=1)
        else:
            pred_video = self.head(data)
            
        return pred_video
    
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_video_step(batch, False)
        
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_video_step(batch, True)
        
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        
        
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels = self._shared_video_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
         
    def configure_optimizers(self):
        params = self.head.parameters() if not self.check_pt else self.parameters()
        optimizer = torch.optim.AdamW(params, lr=self.lr, weight_decay=self.wd)  #, weight_decay=self.Lambda_L2
        
        return [optimizer]

class RARP_NVB_DINO_MultiTask_A6_Video(L.LightningModule):
    """End-to-end video classifier (ablation A6): a pretrained VAN-B2 scores
    every frame under non-reentrant gradient checkpointing and a TCN head
    turns the per-frame features into one binary video logit.

    NOTE(review): L1/L2/std/mean are accepted but never used here —
    presumably kept for interface parity with the sibling models; confirm.
    """
    def __init__(
        self, 
        lr = 0.0001, 
        wd = 0.01,
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None,
        head_type:int = 0, #None = 0, linear = 1, Attn. Pooling = 2, TCN = 3, Replace Head =4
        chunks_loading:int = 50,
    ):
        super().__init__()
        
        self.lr = lr
        self.wd = wd
        self.chunks = chunks_loading      # frames per checkpointed chunk
        self.head_type = head_type
            
        self.check_pt = True
        self.base_model = van.van_b2(pretrained = True, num_classes = 0)
        self.num_features_base_model = 512  # VAN-B2 feature dimension
            
        # wrapper adds the dummy grad-requiring input used by checkpointing
        self.base_model_wrapper = ModuleWrapper(self.base_model)
        
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
                
        match(self.head_type):
            case 1:
                self.head = TemporalConvNet(self.num_features_base_model, [128, 8, 1])
            case _:
                # NOTE(review): with head_type != 1, `forward` calls
                # self.head(...) on None and fails — confirm this class is
                # always constructed with head_type=1.
                self.head = None    
    
    def _shared_video_step(self, batch:list[torch.Tensor], val_step:bool = False):
        """Score one video batch and compute the BCE loss.

        Returns (loss, labels, sigmoid probabilities).
        """
        video, label = batch
        
        video = video.float() #[B, T, C, H, W]
        label = label.float() #[B]
  
        pred_video = self(video, val_step) #[B, T, 1]
        pred_video = pred_video.mean(dim=1) #[B, 1] average per-frame logits over time
        pred_video = pred_video.flatten() #[B] to match labels shape
                    
        predicted_labels = torch.sigmoid(pred_video)
        loss = self.lossFN(pred_video, label)
        
        return loss, label, predicted_labels
        
    def forward(self, video:torch.Tensor, val_step:bool = True):
        """Encode a video chunk-by-chunk and apply the temporal head.

        Returns the head's output over per-frame features ([B, T, 1] for the
        TCN head).
        """
        B, T, C, H, W = video.shape
        chunk_T = self.chunks
        pred_bt = []
                
        for t0 in tqdm(range(0, T, chunk_T), desc=f"Video Analysis in {chunk_T} chunk", leave=False): # loop over temporal chunks
            t1 = min(T, t0 + chunk_T)
            x = video[:, t0:t1].reshape(-1, C, H, W).contiguous(memory_format=torch.channels_last) # [B, chunk, C, H, W] → [B*chunk, C, H, W], channels-last for GPU efficiency
            # dummy grad-requiring input keeps checkpointing's backward alive
            dummy = torch.ones((), device=x.device, dtype=x.dtype, requires_grad=True)
            pred = torch_ckp.checkpoint(self.base_model_wrapper, x, dummy, use_reentrant=False) # forward through the CNN, recomputing activations on backward
            pred_bt.append(pred.view(B, t1-t0, -1)) # back to [B, chunk, features]
            
        pred_video = torch.cat(pred_bt, dim=1) # concat chunks → [B, T, features]
        pred_video = self.head(pred_video)
            
        return pred_video
    
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_video_step(batch, False)
        
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_video_step(batch, True)
        
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        
        
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels = self._shared_video_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
         
    def configure_optimizers(self):
        # whole model (backbone + head) is trained end-to-end
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)  #, weight_decay=self.Lambda_L2
        
        return [optimizer]
        
#end Ablation Models

class RARP_CLIP_loss(torch.nn.Module):
    """Symmetric CLIP-style contrastive loss.

    Builds the temperature-scaled similarity matrix between student and
    teacher embeddings and averages the cross-entropy in both directions,
    using the batch diagonal as the matching targets.
    """
    def __init__(self, temperature):
        super().__init__()
        self.temp = temperature

    def forward(self, z_s:torch.Tensor, z_t:torch.Tensor):
        # [B, B] similarity matrix; row i should peak at column i
        sim = torch.matmul(z_s, z_t.t()) / self.temp
        targets = torch.arange(z_s.size(0), device=sim.device)

        loss_student_to_teacher = torch.nn.functional.cross_entropy(sim, targets)
        loss_teacher_to_student = torch.nn.functional.cross_entropy(sim.t(), targets)

        return 0.5 * (loss_student_to_teacher + loss_teacher_to_student)

class RARP_CLIP(L.LightningModule):
    """CLIP-style alignment of a trainable student against a frozen VAN-B2
    teacher: both embed the same image and a symmetric contrastive loss pulls
    matching pairs together within the batch.
    """
    def __init__(
        self,
        student_backbone: str = "",
        teacher_backbone: str = "",
        proj_dim: int = 256,
        embeddings: int = 512,
        temperature: float = 0.07,
        lr: float = 1e-4,
    ):
        super().__init__()
        
        self.save_hyperparameters()
        
        # student backbone selection (only VAN-B1 is wired up so far)
        if student_backbone == "van_b1":
            self.student = van.van_b1(pretrained=False, num_classes=0)
            self.student_dim = 512
        else:
            raise Exception(f"{student_backbone} Not Implemented")
        
        # teacher: custom weights from disk when a path is given, otherwise
        # the pretrained release weights
        if len(teacher_backbone) > 0:
            self.teacher = van.van_b2(pretrained=False, num_classes=0)
            self.teacher.load_state_dict(torch.load(teacher_backbone))
        else:
            self.teacher = van.van_b2(pretrained=True, num_classes=0)
        self.teacher_dim = 512
        
        # teacher stays frozen throughout training
        for frozen_param in self.teacher.parameters():
            frozen_param.requires_grad = False
        
        # projection head mapping student features into the shared space
        self.proj_s = torch.nn.Sequential(
            torch.nn.Linear(self.student_dim, proj_dim),
            torch.nn.LayerNorm(proj_dim),
            torch.nn.GELU(),
            torch.nn.Linear(proj_dim, embeddings)
        )
        
        self.loss_fn = RARP_CLIP_loss(temperature)
        
    def forward(self, data):
        """Return L2-normalized (student, teacher) embeddings for `data`."""
        z_student = torch.nn.functional.normalize(self.proj_s(self.student(data)), dim=-1)
        z_teacher = torch.nn.functional.normalize(self.teacher(data), dim=-1)
        return z_student, z_teacher
            
    def _shared_step(self, batch):
        images, _ = batch
        z_student, z_teacher = self(images)
        return self.loss_fn(z_student, z_teacher)
    
    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train/clip_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        # only the student and its projection head are optimized; the frozen
        # teacher (and any future teacher projection) is deliberately excluded
        trainable = list(self.student.parameters()) + list(self.proj_s.parameters())
        return torch.optim.AdamW(trainable, lr=self.hparams.lr)
        
class DecoderBlock(torch.nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv_expand = torch.nn.Conv2d(in_ch, out_ch*4, 3, padding=1)
        self.pixel_shuffle = torch.nn.PixelShuffle(2)
        self.bn1 = torch.nn.BatchNorm2d(out_ch)
        self.act1 = torch.nn.GELU()
        self.conv_refine = torch.nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(out_ch)
        self.act2 = torch.nn.GELU()

    def forward(self, x):
        x = self.conv_expand(x)
        x = self.pixel_shuffle(x)
        x = self.bn1(x)
        x = self.act1(x)
        
        x = self.conv_refine(x)
        x = self.bn2(x)        
        x = self.act2(x)
        return x
    
class DynamicDecoder_PixelShuffle(torch.nn.Module):
    """Pixel-shuffle upsampling decoder built from `DecoderBlock` stages.

    Each stage doubles the spatial resolution; a final transposed conv doubles
    it once more while mapping to `output_channels`, so the total upscale is
    2 ** (num_blocks + 1).  No output activation is applied — callers
    clamp/normalize the reconstruction as needed.

    Args:
        input_channels: channels of the input feature map.
        output_channels: channels of the reconstructed image (3 for RGB).
        num_blocks: number of DecoderBlock stages; must equal
            len(hidden_channels).
        hidden_channels: output channels of each stage, coarse → fine.
            (FIX: default changed from a mutable list to an equivalent tuple
            to avoid the shared-mutable-default pitfall; it is only read.)
        drop_out: optional dropout probability applied after every stage.
    """
    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=(1024, 512, 256, 64), drop_out: float | None = None):
        super().__init__()

        # Ensure the number of hidden channels matches the number of blocks
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."

        layers = []
        in_channels = input_channels

        # one 2×-upsampling block (plus optional dropout) per hidden width
        for out_channels in hidden_channels:
            layers.append(DecoderBlock(in_channels, out_channels))
            if drop_out is not None:
                layers.append(torch.nn.Dropout(drop_out))
            in_channels = out_channels

        # final 2× upsample mapping to the output image channels
        layers.append(torch.nn.ConvTranspose2d(in_channels, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1))

        # Combine all layers into a Sequential module
        self.decoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Decode feature map `x` into an image-shaped tensor."""
        return self.decoder(x)    

class RARP_NVB_DINO_MultiTask_Pretrain(RARP_NVB_DINO_MultiTask):
    """Multi-task variant whose student/teacher backbones are replaced with
    fresh VAN-B2 networks, optionally warm-starting the student from a saved
    state dict (`pre_train_pth`).
    """
    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        momentum_teacher = 0.9995, 
        lr = 0.0001, 
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None, 
        SoftAdptAlgo = 0, 
        SoftAdptBeta = 0.1, 
        Teacher_T = 0.04, 
        Student_T = 0.1, 
        intermittent = False,
        pre_train_pth:str = "",
        HParameter = {},
    ):
        """pre_train_pth: optional backbone state-dict path for warm-starting
        the student.  HParameter: hyper-parameter overrides — currently
        unused (see the commented lines below)."""
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
        
        #self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04) 
        #self.lr = getNVL(HParameter, "lr", 1.0E-4)
        
        # fresh VAN-B2 backbones for both branches (no ImageNet weights)
        self.student = van.van_b2(pretrained = False, num_classes = 0)
        self.teacher_Features = van.van_b2(pretrained = False, num_classes = 0)
                               
        # wrap backbones with the DINO projection MLP; in_dim/out_dim are
        # presumably set by the parent class — confirm
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        
        if len(pre_train_pth) > 0:
            self.student.backbone.load_state_dict(torch.load(pre_train_pth))
        
        # teacher starts as an exact frozen copy of the student
        self.teacher_Features.load_state_dict(self.student.state_dict())
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
            
        # re-register the final-stage feature hooks on the new backbones
        self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher)
        self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
        
        # NOTE(review): DynamicDecoder is presumably defined earlier in this
        # file (distinct from DynamicDecoder_PixelShuffle) — confirm.
        self.decoder = DynamicDecoder(input_channels=1024)
        
           
class RARP_MAE(L.LightningModule):
    def __init__(
        self, 
        backbone:str, 
        mask_ratio:float = 0.75,
        patch_size: int = 16,
        #img_size: int = 224,
        lr: float = 1e-4,
        hiden_channels = [512, 256, 128, 64, 32],
        img_mean_std = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
        activation_fn = torch.nn.ReLU(True)
    ):
        super().__init__()
        
        self.save_hyperparameters(ignore=["activation_fn", "img_mean_std"])
        
        
        
        if img_mean_std is not None:
            mean, std = img_mean_std
            self.register_buffer("mean_IMG", torch.tensor(mean).view(3, 1, 1))
            self.register_buffer("std_IMG", torch.tensor(std).view(3, 1, 1))
        else:
            self.mean_IMG = None
            self.std_IMG = None
        
        self.loss_fn = torch.nn.L1Loss()
        
        match backbone:
            case "resnet":
                model = torchvision.models.resnet18()
                model.fc = torch.nn.Identity()
                self.encoder = model
                self.encoder_out_dim = 512
            case "van":
                model = van.van_b1()
                model.head = torch.nn.Identity()
                self.encoder = model
                self.encoder_out_dim = 512
            case "levit":
                self.encoder = timm.create_model("levit_192.fb_dist_in1k", pretrained=False, num_classes=0)
                self.encoder_out_dim = 384
                hiden_channels = [384, 256, 128, 64, 32]
            case "van_l1_loss":
                model = van.van_b1()
                model.head = torch.nn.Identity()
                self.encoder = model
                self.encoder_out_dim = 512
                self.loss_fn = torch.nn.L1Loss()
            case "van_2":
                model = van.van_b2(num_classes=0)
                self.encoder = model
                self.encoder_out_dim = 512
                

        in_channel = self.encoder_out_dim
        layers = [
            torch.nn.Linear(in_channel, hiden_channels[0]*7*7),
            torch.nn.Unflatten(1, (hiden_channels[0], 7, 7)),
        ]
        
        for out_channel in hiden_channels[1:]:
            layers.append(torch.nn.ConvTranspose2d(in_channel, out_channel, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(torch.nn.BatchNorm2d(out_channel))
            layers.append(activation_fn)
            in_channel = out_channel
            
        #Last layer
        layers.append(torch.nn.ConvTranspose2d(in_channel, 3, kernel_size=3, stride=2, padding=1, output_padding=1))
                
        self.decoder = torch.nn.Sequential(*layers)
        
        self.angles = [-90, 0, 90, 180]
        self.angles_labels = torch.from_numpy(LabelEncoder().fit_transform(self.angles))
        
        self.classifier = RARP_NVB_Classification_Head(self.encoder_out_dim, len(self.angles), layer=[128], activation_fn=torch.nn.GELU())
        
        self.lossFN_Aux = torch.nn.CrossEntropyLoss()
        
        self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=len(self.angles))
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=len(self.angles))
    
    def _rotation_labels_generator(self, batch_size:int, angles_batch=None, rand:bool=False) -> torch.Tensor: 
        assert len(angles_batch) > 0, "Empty list, angles list shuld have more than 2 values"
        
        if rand:
            return torch.tensor([angles_batch[np.random.randint(len(angles_batch))] for _ in range(batch_size)])
        else:
            return torch.tensor([angles_batch[i % len(angles_batch)] for i in range(batch_size)])
        
    def _rotate_img(self, imgs:torch.Tensor, angles_batch=[]) -> torch.Tensor:
        assert len(angles_batch) > 0, "Empty list, angles list shuld have more than 2 values"
        
        list_imgs = []
        
        for i, angle_idx in enumerate(angles_batch):
            list_imgs.append(torchvision.transforms.functional.rotate(imgs[i, ...], self.angles[angle_idx]))
            
        return torch.stack(list_imgs, dim=0)
    
    def _denormalize(self, tensor:torch.Tensor):
        return tensor * self.std_IMG + self.mean_IMG
    
    def _mask_patches(self, imgs):
        B, C, H, W = imgs.shape
        ph, pw = self.hparams.patch_size, self.hparams.patch_size
        gh, gw = H // ph, W // pw
        
        mask:torch.Tensor = (torch.rand(B, gh * gw, device=imgs.device) >= self.hparams.mask_ratio)
        mask = mask.reshape(B, 1, gh, gw)   # [B,1,gh,gw]
        
        # expand mask to full image
        mask = mask.repeat_interleave(ph, dim=2)    # [B,1,gh*ph,gw]
        mask = mask.repeat_interleave(pw, dim=3)    # [B,1,gh*ph,gw*pw] == [B,1,H,W]
        
        mask = mask.expand(-1, C, -1, -1)
        
        imgs_masked = imgs * mask
        return imgs_masked
    
    def forward(self, data, data_rot, val_step=False):
        if not val_step:
            data = [d.float() for d in data] #[0] original Img, [1] augmented Img
        else:
            data = [data.float(), data.float()]
            
        data_aux = data_rot.float()
        
        imgs_masked = self._mask_patches(data[1])
        
        latent = self.encoder(imgs_masked)   # [B, 512]
        aux_latent = self.encoder(data_aux)
        
        reconstructed_image = self.decoder(latent)  # [B, 3, 224, 224]
        pred_rot = self.classifier(aux_latent)
        
        return reconstructed_image, data[0], pred_rot
    
    def _shared_step(self, batch, val_step=False):
        img, _ = batch
        
        batch_size = img.size(0) if val_step else img[0].size(0)
        current_device = img.device if val_step else img[0].device
        labels = self._rotation_labels_generator(batch_size, self.angles_labels, not val_step).to(current_device)
        rot_img = self._rotate_img(img if val_step else img[0], labels).to(current_device)
        
        res, true_img, aux_pred = self(img, rot_img, val_step)
        
        aux_pred = torch.softmax(aux_pred, dim=1)
        loss_aux = self.lossFN_Aux(aux_pred, labels)
        
        loss_MSE =  self.loss_fn(res, true_img)
        
        loss = loss_MSE + loss_aux
        
        return loss, res, true_img, (aux_pred, labels)
    
    def training_step(self, batch, batch_idx):
        loss, img, _, pred = self._shared_step(batch, False)
        
        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        self.train_acc.update(pred[0], pred[1])
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        
        if batch_idx % 500 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(img), 0, 1)
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
            
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, img, target_img, pred = self._shared_step(batch, True)
                
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.val_acc.update(pred[0], pred[1])
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        
        
        if self.mean_IMG is not None and self.std_IMG is not None:
            img = torch.clip(self._denormalize(img), 0, 1)
            target_img = torch.clip(self._denormalize(target_img), 0, 1)
            
            ssim = piq.ssim(img, target_img, data_range=1.0)
            psnr = piq.psnr(img, target_img, data_range=1.0)
            
            self.log("val_ssim", ssim, on_epoch=True)
            self.log("val_psnr", psnr, on_epoch=True)
            
            if batch_idx % 100 == 0:
                grid = torchvision.utils.make_grid(img)
                self.logger.experiment.add_image('val_reconstructed_images', grid, self.global_step)
            
                
    def on_after_backward(self):
        norms = [p.grad.data.norm(2).item() for p in self.parameters() if p.grad is not None]
        avg_layer_norm = sum(norms) / len(norms)
        
        self.log("grad_norm", avg_layer_norm)
        
        if avg_layer_norm < 1e-8:
            self.log("grad_warning", "Vanishing gradient suspected!")
        
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)  
        
        return [optimizer]

class RARP_NVB_DINO_MultiTask_LeViT(RARP_NVB_DINO_MultiTask):
    """DINO multi-task variant using LeViT-384 for both student and teacher."""

    def _hook_fn_Student(self, module, input, output):
        # LeViT stages emit token sequences [B, N, C]; fold them back into a
        # square spatial map [B, C, H, W] (assumes N is a perfect square --
        # TODO confirm for the hooked stage's resolution).
        B, N, C = output.shape
        H = W = int (N**0.5)
        
        self.last_conv_output_S = output.transpose(1, 2) 
        self.last_conv_output_S = self.last_conv_output_S.contiguous().view(B, C, H, W)
        
    def _hook_fn_Teacher(self, module, input, output):
        # Same token-sequence-to-spatial-map reshape for the teacher branch.
        B, N, C = output.shape
        H = W = int (N**0.5)
        
        self.last_conv_output_T = output.transpose(1, 2) 
        self.last_conv_output_T = self.last_conv_output_T.contiguous().view(B, C, H, W)
        
    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        """Swap the parent's backbones for ImageNet-pretrained LeViT-384 models.

        All arguments are forwarded unchanged to RARP_NVB_DINO_MultiTask.
        """
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
        
        # Projection MLP dimensions used by the DINO wrappers below.
        self.in_dim = 768
        self.out_dim = 2048
        
        self.student = timm.create_model("levit_384.fb_dist_in1k", pretrained=True, num_classes=0)
        self.teacher_Features = timm.create_model("levit_384.fb_dist_in1k", pretrained=True, num_classes=0)
        
        self.decoder = DynamicDecoder(input_channels = 1024) 
               
        # Wrap both backbones with the DINO projection head.
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=3)
        )
        
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=3)
        )
        
        # Rebuild the DINO loss for the new output dimensionality.
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)
        
        # Teacher starts as a frozen, exact copy of the student.
        self.teacher_Features.load_state_dict(self.student.state_dict())
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
            
        # Capture the penultimate stage's token output from both branches.
        self.teacher_Features.backbone.stages[-2].register_forward_hook(self._hook_fn_Teacher)
        self.student.backbone.stages[-2].register_forward_hook(self._hook_fn_Student)
            
        # Single-logit classification head over the 1024-channel fused features.
        self.clasiffier = RARP_NVB_Classification_Head(1024, 1, [128, 8], torch.nn.SiLU(True))
        
class RARP_NVB_DINO_MultiTask_Unet(RARP_NVB_DINO_MultiTask):
    """DINO multi-task variant whose decoder receives U-Net style skip
    connections captured from intermediate encoder blocks via forward hooks."""

    def _encoder_hool_fn(self, module, input, output):
        # Appends each hooked block's output; ordering follows module
        # execution order within a forward pass.
        self.feature_maps.append(output)
    
    def _register_encoder_hooks(self, block_list:list):
        # Attach the capture hook to every listed encoder block; handles are
        # kept in self.hooks so they could be removed later.
        for layer in block_list:
            self.hooks.append(layer.register_forward_hook(self._encoder_hool_fn))
    
    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
        std: float = None,
        mean: float = None,
        SoftAdptAlgo:int = 0,
        SoftAdptBeta:float = 0.1
    ):
        """Arguments are forwarded unchanged to RARP_NVB_DINO_MultiTask."""
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)
        
        self.hooks = []         # forward-hook handles
        self.feature_maps = []  # per-forward capture buffer, reset in forward()
        
        # Student encoder blocks whose outputs feed the U-Net decoder.
        self.list_blocks = [
            self.student.backbone.block1[-1],
            self.student.backbone.block2[-1],
            self.student.backbone.block3[-1],
            self.student.backbone.block4[-2],
        ]
        
        #self.list_blocks_T = [
        #    self.teacher_Features.backbone.block1[-1],
        #    self.teacher_Features.backbone.block2[-1],
        #    self.teacher_Features.backbone.block3[-1],
        #    self.teacher_Features.backbone.block4[-2],
        #]
        
        self._register_encoder_hooks(self.list_blocks)
        #self._register_encoder_hooks(self.list_blocks_T)
        
        self.decoder = DecoderUnet(1024)
                
    def forward(self, data, val_step:bool = True):
        """Run teacher + student, assemble gated skip features, decode, classify.

        Returns (class logits, (student DINO output, teacher DINO output),
        reconstructed image).
        """
        self.feature_maps = []
        
        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            data = [d.float() for d in data]
            # Teacher gets views data[1:3]; student gets all views -- presumably
            # global vs. local crops as in DINO; confirm against the datamodule.
            dataTeacher, dataStudent = data[1:3], data

        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)
        
        # NOTE(review): feature_maps is indexed as two banks of num_blocks maps
        # (i and i + num_blocks), yet only the student blocks were hooked in
        # __init__ -- presumably additional hooks are registered by the parent
        # class or the wrapper runs the backbone twice; confirm.
        _temp = []
        num_blocks = len(self.list_blocks)
        NumChunks = len(dataStudent)
        for i in range(num_blocks):
            if not val_step:
                # Keep only the chunks corresponding to views 1:3.
                S_GlogalViews = torch.cat(self.feature_maps[i + num_blocks].chunk(NumChunks)[1:3], dim=0)
            else:
                S_GlogalViews = self.feature_maps[i + num_blocks]
                
            # Gate the first bank's maps with the matching second-bank maps.
            _temp.append(self.feature_maps[i] * S_GlogalViews)
            
        self.feature_maps = _temp
                       
        if not val_step:
            # Reduce the hooked student map to the same two views.
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
        
        # last_conv_output_S / last_conv_output_T are set by hooks registered
        # in the parent class -- TODO confirm.
        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
                
        self.feature_maps.append(cat_features)
        reconstructed_image = self.decoder(self.feature_maps)
        
        # Global-average-pool the fused map for the classification head.
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)
                        
        return pred, (Student, TeacherDino), reconstructed_image

class RARP_NVB_DINO_MultiTask_MultiLabel(RARP_NVB_DINO_MultiTask):
    """Multi-label variant of the DINO multi-task model.

    Replaces the parent's classifier and metrics with a ``Num_Lables``-output
    head trained with ``BCEWithLogitsLoss`` and evaluated with multilabel
    accuracy / F1 metrics.
    """

    def __init__(
        self, 
        TypeLoss=TypeLossFunction.BCEWithLogits, 
        momentum_teacher = 0.9995, 
        lr = 0.0001, 
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None, 
        SoftAdptAlgo = 0, 
        SoftAdptBeta = 0.1,
        Num_Lables = 2
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)

        # Independent per-label sigmoid outputs -> BCE on the raw logits.
        self.lossFN = torch.nn.BCEWithLogitsLoss()

        # 1024-d fused features -> Num_Lables logits.
        head_layers = [
            torch.nn.Linear(1024, 128),
            torch.nn.Dropout(0.4),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, Num_Lables),
        ]
        self.clasiffier = torch.nn.Sequential(*head_layers)

        # One accuracy metric per split, plus F1 on the test split.
        self.train_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.val_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.test_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.f1ScoreTest = torchmetrics.F1Score('multilabel', num_labels=Num_Lables)

class RARP_NVB_DINO_MultiTask_v2(RARP_NVB_DINO_MultiTask):
    """Four-class variant of the DINO multi-task model.

    Swaps in 4-way multiclass metrics, CrossEntropyLoss, a loss-weighted
    SoftAdapt combiner, and a deeper classification head over the 1024-d
    fused features.
    """

    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        momentum_teacher = 0.9995, 
        lr = 0.0001, 
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean)
            
        self.train_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.val_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.test_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.f1ScoreTest = torchmetrics.F1Score('multiclass', num_classes=4)
        
        self.lossFN = torch.nn.CrossEntropyLoss()
        
        self.softAdapt =  LossWeightedSoftAdapt(0.1) #NormalizedSoftAdapt(0.1)
        
        # 1024-d fused features -> 4 class logits.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 512),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),
            
            torch.nn.Linear(512, 256),
            torch.nn.SiLU(True),
            
            torch.nn.Linear(256, 128),
            torch.nn.SiLU(True),
            
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            
            torch.nn.Linear(8, 4)
        )
        
    def _shared_step(self, batch, val_step:bool = False):
        """Combined objective; returns (loss, labels, probs, (weighted terms, recon))."""
        img, label = batch
        
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        
        predicted_labels = torch.softmax(prediction,dim=1)  # probabilities for the metrics only
        
        # During training the batch holds several views; tile the targets to match.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label for _ in range(len(TeacherF))], dim=0) if not val_step else label
        
        # DINO self-distillation loss (only meaningful with multi-view batches).
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        # Classifier: CrossEntropyLoss expects raw logits, so feed `prediction`.
        # (The previous version passed the softmaxed probabilities -- a double
        # softmax that flattens and mis-scales the loss.)
        loss_HL = self.lossFN(prediction, label)
        # Reconstruction loss against the tiled clean view.
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            # Optional L1/L2 regularization on the classification head only.
            if self.Lambda_L1 is not None:
                loss_l1 = 0
                for params in self.clasiffier.parameters():
                    loss_l1 += torch.sum(torch.abs(params))
                loss_HL += self.Lambda_L1 * loss_l1
                
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            
            # History consumed by the SoftAdapt weight update elsewhere.
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
              
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)

class RARP_NVB_RN50_VAN_V2 (RARP_NVB_Model):
    # Define a hook function to capture the output
    def _hook_fn(self, module, input, output):
        self.last_conv_output = output
        
    def __init__(
        self, 
        PseudoEstimator:str = None, 
        threshold:float = 0.5, 
        InitWeight=None, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        PseudoLables:bool=True,
        HParameter = {},
        std: float = None,
        mean: float = None,
        **kwargs
    ) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        
        self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
        self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
        
        #self.RARP_RestNet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50') 
        self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
        self.RARP_RestNet50.model.fc = torch.nn.Identity()  
        for parms in self.RARP_RestNet50.model.parameters():
            parms.requires_grad = False      
            
        self.RARP_RestNet50 = torch.nn.Sequential(*list(self.RARP_RestNet50.model.children())[:-2])
          
        self.decoder = Decoder()
        
        self.FeaturesLoss = FeatureAlignmentLoss()
        self.ReconstructionLoss = torch.nn.MSELoss()
    
        match (TypeError):
            case TypeLossFunction.ContrastiveLoss:
                self.lossFN = ContrastiveLoss()
            case _:
                pass            
        
        self.model = van.van_b2(pretrained = True, num_classes = 1)  
              
        self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04) 
        self.lr = getNVL(HParameter, "lr", 1.0E-4)
                
        print(f"lr={self.lr}, L1={self.Lambda_L1}")
        
        # Initialize a variable to store the output
        self.last_conv_output = None
        self.model.block4[-1].mlp.dwconv.register_forward_hook(self._hook_fn)
        

    def forward(self, data):
        dataTeacher, dataStudent, _ = data
        dataStudent = dataStudent.float()
        dataTeacher = dataTeacher.float()
        
        feature_Teacher = self.RARP_RestNet50 (dataTeacher)
        #feature_Student = self.model.forward_features(dataStudent)
        pred = self.model(dataStudent)  
        feature_Student = self.last_conv_output
        
        #feature_Teacher = torch.nn.functional.adaptive_avg_pool1d(feature_Teacher, feature_Student.size(-1)) #Re-size output vector to macth VAN 
        
        cat_features = (feature_Student + feature_Teacher) / 2
                        
        reconstructed_image = self.decoder(cat_features) 
        
        feature_Student = torch.nn.functional.adaptive_avg_pool2d(feature_Student, (1, 1)).flatten(1)
        feature_Teacher = torch.nn.functional.adaptive_avg_pool2d(feature_Teacher, (1, 1)).flatten(1)
                
        return pred, (feature_Student, feature_Teacher), reconstructed_image
    
    def _shared_step(self, batch):
        img, label = batch
        label = label.float()
        orignalImg = img[2].float()
                
        prediction, features, new_image = self(img)
        
        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)
        
        #cosine similarity
        loss_cosine = self.FeaturesLoss(features[0], features[1])
        #Clasificator
        loss_hl = self.lossFN(prediction, label)
        #Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()
        
        loss = loss_cosine + loss_hl + loss_img
        
        if self.Lambda_L1 is not None:
            loss_l1 = 0
            for params in self.model.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
            
        return loss, label, predicted_labels, (loss_cosine, loss_hl, loss_img, new_image)
    
    def _denormalize(self, tensor:torch.Tensor):
        # Move mean and std to the same device as the input tensor
        mean = self.mean_IMG.to(tensor.device)
        std = self.std_IMG.to(tensor.device)
        return tensor * std + mean
    
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("train_loss_cosSim", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)
        
        if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("val_loss_cosSim", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)
        
    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
        
        if self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        
        return optimizer
                
class RARP_NVB_ResNet50_VAN(RARP_NVB_Model):
    def __init__(
        self, 
        PseudoEstimator:str = None, 
        threshold:float = 0.5, 
        InitWeight=None, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        PseudoLables:bool=True,
        HParameter = {},
        **kwargs
    ) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        
        self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
        self.threshold = threshold
        self.PseudoLables = PseudoLables
        
        match (TypeError):
            case TypeLossFunction.ContrastiveLoss:
                self.lossFN = ContrastiveLoss()
            case _:
                pass            
            
        for parms in self.RARP_RestNet50.model.parameters():
            parms.requires_grad = False
            
        self.model = van.van_b2(pretrained = True, num_classes = 1)        
        self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04) 
        self.lr = getNVL(HParameter, "lr", 1.0E-4)
        self.W_Apha = getNVL(HParameter, "Alpha", 0.60)
        self.W_Beta = 1 - getNVL(HParameter, "Alpha", 0.60)
        self.thao_KD = getNVL(HParameter, "Thao", 5)
        
        print(f"A={self.W_Apha}; B={self.W_Beta}, T={self.thao_KD}, lr={self.lr}, L1={self.Lambda_L1}")
        
        #self.Lambda_L1 = None
        #self.scheduler = True
        #self.lr = 1.74E-2
        
    def forward(self, data):
        #dataTeacher, dataStudent = data
        if isinstance(data, tuple):
            dataTeacher, dataStudent = data
        elif isinstance(data, torch.Tensor):
            dataTeacher, dataStudent = (data, data)
        RN50Pred = torch.sigmoid(self.RARP_RestNet50(dataTeacher).flatten())
        PseudoLabels = (RN50Pred > self.threshold) * 1.0
        
        dataStudent = dataStudent.float()
        pred = self.model(dataStudent)
                
        return pred, PseudoLabels, RN50Pred
        
    def _shared_step(self, batch):
        img, label = batch
        if self.InitWeight is not None:
            self.lossFN.weight = self.InitWeight[label]
        
        label = label.float()
        prediction, PseudoLabels, predictionRN50 = self(img) #.flatten()
        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)
         #
        # L[B] => BCEWithLogits
        # L[C] => ContrastiveLoss
        #Loss L[1] = L[B_HL] + L[B_PL]
        #Loss L[2] = L[C_HL]
        #Loss L[3] = L[C_HL] + L[C_PL]
        #Loss L[4] = L[C_y_hat]
        #Loss L[5] = L[KLD_PL] + L[B_PL]
        #Loss L[6] = L[MSE_PL] + L[B_PL]
        #Loss L[7] = L[B_HL] + L[BCE_PL]
        #Loss L[8] = FocalLoss*L[JSD]
        #Nuveo
        #thao_KD = 5.0
        #w_alpha, w_beta = (0.60, 0.40)
        
        #L[7]:
        softTeacher = torch.sigmoid(predictionRN50/self.thao_KD)
        softStudent = torch.sigmoid(prediction/self.thao_KD)
        
        loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
        loss_hl = self.lossFN(prediction, label)
        
        loss = self.W_Apha * loss_hl + self.W_Beta * loss_sl #/ (self.thao_KD ** 2)
        
        #loss = w_alpha * self.lossFN(prediction, label) + w_beta * self.lossFN(prediction, PseudoLabels) #L[1]
        #loss = w_alpha * (torch.nn.KLDivLoss()(softStudent, softTeacher) * (thao_KD ** 2)) + w_beta * self.lossFN(prediction, PseudoLabels) #L[5]
        #loss = w_alpha * torch.nn.functional.mse_loss(softStudent, softTeacher) + w_beta * self.lossFN(prediction, PseudoLabels) #L[6]
        #loss = self.lossFN(predictionRN50, predicted_labels, label) #L[2]
        #loss = self.lossFN(predictionRN50, predicted_labels, label) + self.lossFN(predictionRN50, predicted_labels, PseudoLabels) #L[3]
        #y_hat = (PseudoLabels != label) * 1
        #loss = self.lossFN(predictionRN50, predicted_labels, y_hat) #L[4]
        
        if self.Lambda_L1 is not None:
            loss_l1 = 0
            for params in self.model.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        
        return loss, (PseudoLabels if self.PseudoLables else label), predicted_labels, (loss_hl, loss_sl)
    
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_lh", losses[0], on_step=False, on_epoch=True)
        self.log("train_ld", losses[1], on_step=False, on_epoch=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_lh", losses[0], on_step=False, on_epoch=True)
        self.log("val_ld", losses[1], on_step=False, on_epoch=True)
        
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels, _ = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            scheduler = CosineAnnealingLR(
                optimizer, 
                max_epochs=150,
                warmup_epochs=8,
                warmup_start_lr=0.001,
                eta_min=0.0001
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        return optimizer

class RARP_NVB_SSL_RestNet50_Deep(RARP_NVB_ResNet50_VAN):
    """Semi-supervised ResNet-50 student with a frozen "deep" pseudo-label teacher.

    The teacher (``self.RARP_RestNet50``) is restored from a checkpoint when
    ``PseudoEstimator`` is a path, otherwise created fresh; its backbone is
    fully frozen so only the student (``self.model``) is trained.
    """

    def __init__(
        self,
        # FIX: hint was `str = None`, which mistypes the None default.
        PseudoEstimator: Union[str, None] = None,
        threshold: float = 0.5,
        InitWeight=None,
        TypeLoss=TypeLossFunction.CrossEntropy,
        PseudoLables: bool = True,
        **kwargs
    ) -> None:
        # The parent gets no estimator path: the teacher is built below instead.
        super().__init__(None, threshold, InitWeight, TypeLoss, PseudoLables, **kwargs)

        # Teacher network: load from checkpoint if a path was provided.
        if PseudoEstimator is not None:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep.load_from_checkpoint(
                PseudoEstimator, strict=False
            )
        else:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep()

        # Student network: ImageNet-pretrained ResNet-50 with a small MLP head
        # ending in a single logit.
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        #self.model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        backbone_features = 2048  # resnet50 final-stage feature width
        self.model.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=backbone_features, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1)
        )

        # Freeze the teacher backbone so only the student receives gradients.
        for param in self.RARP_RestNet50.model.parameters():
            param.requires_grad = False

class RARP_NVB_MobileNetV2(RARP_NVB_Model):
    """MobileNetV2 backbone adapted to emit a single logit for binary NVB prediction."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # ImageNet-pretrained backbone; replace the final classifier layer
        # with a one-output linear head.
        backbone = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
        backbone.classifier[1] = torch.nn.Linear(
            in_features=backbone.classifier[1].in_features, out_features=1
        )
        self.model = backbone

class RARP_NVB_EfficientNetV2(RARP_NVB_Model):
    """EfficientNetV2-S backbone adapted to emit a single logit for binary NVB prediction."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # ImageNet-pretrained backbone; replace the final classifier layer
        # with a one-output linear head.
        backbone = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        backbone.classifier[1] = torch.nn.Linear(
            in_features=backbone.classifier[1].in_features, out_features=1
        )
        self.model = backbone
        
class RARP_NVB_Vit_b_16(RARP_NVB_Model):
    """ViT-B/16 backbone adapted to emit a single logit for binary NVB prediction."""

    # FIX (consistency): accept and forward **kwargs like the sibling
    # RARP_NVB_Model subclasses (MobileNetV2, EfficientNetV2). Backward
    # compatible — existing positional/keyword calls are unaffected.
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # ImageNet-pretrained ViT; swap the classification head for one output.
        self.model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        in_features = self.model.heads[0].in_features
        self.model.heads[0] = torch.nn.Linear(in_features=in_features, out_features=1)
        
class RARP_NVB_DenseNet169(RARP_NVB_Model):
    """DenseNet-169 backbone adapted to emit a single logit for binary NVB prediction."""

    # FIX (consistency): accept and forward **kwargs like the sibling
    # RARP_NVB_Model subclasses (MobileNetV2, EfficientNetV2). Backward
    # compatible — existing positional/keyword calls are unaffected.
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # ImageNet-pretrained DenseNet; swap the classifier for one output.
        self.model = torchvision.models.densenet169(weights=torchvision.models.DenseNet169_Weights.DEFAULT)
        in_features = self.model.classifier.in_features
        self.model.classifier = torch.nn.Linear(in_features=in_features, out_features=1)

class RARP_NVB_RestNet50_old(L.LightningModule):
    """Legacy ResNet-50 binary classifier trained with BCE-with-logits.

    Kept for backward compatibility: the optimizer is created in ``__init__``
    and accuracy metrics are computed/logged manually at epoch end.
    """

    def __init__(self, InitialWeigth = 1) -> None:
        super().__init__()

        # ImageNet-pretrained ResNet-50 with a single-logit head.
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        tempFC_ft = self.model.fc.in_features 
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4) 
        # pos_weight re-balances the positive class in the BCE loss.
        self.lossFN = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([InitialWeigth]))

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')

    def forward(self, data):
        """Return raw logits of shape (batch, 1) for a float-cast input batch."""
        data = data.float()
        pred = self.model(data)
        return pred

    def training_step(self, batch, batch_idx):
        img, label = batch
        label = label.float()
        prediction = self(img)[:,0]
        loss = self.lossFN(prediction, label)

        self.log("Train Loss", loss)
        # forward-call on the metric both updates its state and returns
        # the batch-level accuracy.
        self.log("Step Train Acc", self.train_acc(torch.sigmoid(prediction), label.int()))

        return loss

    def on_train_epoch_end(self):
        self.log("Train Acc", self.train_acc.compute())
        # FIX: reset after a manual compute(), otherwise the metric state
        # accumulates across epochs and the logged accuracy is a running
        # average over the whole run instead of per-epoch.
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        img, label = batch
        label = label.float()
        prediction = self(img)[:,0]
        loss = self.lossFN(prediction, label)

        self.log("Val Loss", loss)
        self.log("Step Val Acc", self.val_acc(torch.sigmoid(prediction), label.int()))

        return loss

    def on_validation_epoch_end(self):
        self.log("Val Acc", self.val_acc.compute())
        # FIX: reset for the same per-epoch reason as the training metric.
        self.val_acc.reset()

    def configure_optimizers(self):
        return [self.optimizer]