# RARP / Models.py
# NOTE: the lines below were web-UI paste residue ("Newer"/"Older"/author banner)
# and were not valid Python; kept here as a comment for provenance:
#   @delAguila on 10 Mar 2025 — 107 KB — "25-03-10 T-M best A0"
import math
from typing import Any, Union
import torch
import torchvision
import torchmetrics
import torchmetrics.classification
import lightning as L
from enum import Enum
import timm
import van
import numpy as np
from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt
from noah import NOAH


def js_divergence_sigmoid(p_logits, q_logits):
    """Element-wise Jensen-Shannon divergence between the Bernoulli
    distributions defined by two logit tensors.

    JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = (p + q) / 2.
    Each KL term is computed as cross-entropy minus entropy:
    KL(p || m) = H(p, m) - H(p), where H(p, m) = BCE(input=m, target=p).

    FIX: the previous version returned 0.5 * (H(p, m) + H(q, m)) — it omitted
    the entropy terms, so identical inputs did not yield zero divergence.

    Args:
        p_logits, q_logits: broadcastable tensors of raw (pre-sigmoid) logits.

    Returns:
        Tensor of element-wise JS divergences, in [0, ln 2].
    """
    p_probs = torch.sigmoid(p_logits)
    q_probs = torch.sigmoid(q_logits)

    m = 0.5 * (p_probs + q_probs)

    bce = torch.nn.functional.binary_cross_entropy
    # Cross-entropies H(p, m), H(q, m) and entropies H(p), H(q).
    h_pm = bce(m, p_probs, reduction='none')
    h_qm = bce(m, q_probs, reduction='none')
    h_p = bce(p_probs, p_probs, reduction='none')
    h_q = bce(q_probs, q_probs, reduction='none')

    js_div = 0.5 * (h_pm - h_p) + 0.5 * (h_qm - h_q)

    return js_div

def getNVL(obj: dict, key, default):
    """Return obj[key] unless the key is absent or maps to None, else default.

    SQL NVL-style helper: falsy-but-not-None values (0, "", False) are
    returned as-is. FIX: single dict lookup instead of two.
    """
    value = obj.get(key)
    return default if value is None else value

class Decoder(torch.nn.Module):
    """Convolutional decoder that upsamples a feature map by 2**(num_blocks + 1).

    Structure: a stem 3x3 conv at the input resolution, then one x2 up-stage
    per entry of `hidden_channels`, then a final x2 up-stage down to
    `output_channels`. Every conv is followed by BatchNorm + activation.
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4,
                 hidden_channels=[1024, 512, 256, 64], *args, **kwargs):
        """
        Args:
            input_channels: channels of the incoming feature map.
            output_channels: channels of the produced output map.
            num_blocks: number of hidden up-stages; must equal len(hidden_channels).
            hidden_channels: output channels of each hidden up-stage.
        """
        super().__init__(*args, **kwargs)

        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."

        self.Activation = torch.nn.ReLU  # torch.nn.GELU
        self.input_channels = input_channels

        # Stem: refine the input features at the original resolution.
        blocks = [
            torch.nn.Conv2d(input_channels, input_channels, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(input_channels),
            self.Activation(),
        ]

        in_ch = input_channels
        for out_ch in hidden_channels:  # FIX: plain loop — the enumerate index was unused
            blocks += self._up_stage(in_ch, out_ch)
            in_ch = out_ch

        # Final up-stage projects down to the requested output channels.
        blocks += self._up_stage(in_ch, output_channels)

        self.decoder = torch.nn.Sequential(*blocks)

    def _up_stage(self, in_ch, out_ch):
        """One x2 upsampling stage: transpose conv + refining 3x3 conv (BN + act each)."""
        return [
            torch.nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_ch),
            self.Activation(),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(out_ch),
            self.Activation(),
        ]

    def forward(self, x):
        """Upsample the feature map x through the decoder stack."""
        return self.decoder(x)
    
class DynamicDecoder(torch.nn.Module):
    """Stack of ConvTranspose2d x2 upsampling stages.

    Each hidden stage is transpose-conv + BatchNorm + ReLU; a final
    transpose-conv maps to `output_channels` (no output activation —
    callers apply their own, e.g. a sigmoid).
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=[1024, 512, 256, 64]):
        super(DynamicDecoder, self).__init__()

        # The channel plan must provide exactly one width per block.
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."

        channels = [input_channels] + list(hidden_channels)
        modules = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            modules += [
                torch.nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                torch.nn.BatchNorm2d(c_out),
                torch.nn.ReLU(inplace=True),
            ]

        # Last upsampling layer produces the output map.
        modules.append(torch.nn.ConvTranspose2d(channels[-1], output_channels, kernel_size=3, stride=2, padding=1, output_padding=1))

        self.decoder = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Run the decoder stack."""
        return self.decoder(x)
    
class DecoderUnet(torch.nn.Module):
    """U-Net style decoder head that fuses five encoder feature maps.

    forward() takes a 5-tuple (enc0, enc1, enc2, enc3, bottleneck); every
    skip connection is first upsampled x2 with its own transpose conv so its
    spatial size matches the decoder path before concatenation.
    """

    def __init__(self, input_channels=2048, output_channels=3):
        super().__init__()

        self.dropout = torch.nn.Dropout2d(0.4)

        # Decoder path: transpose-conv upsample, concat skip, double-conv fuse.
        self.upConv_0 = torch.nn.ConvTranspose2d(input_channels, 512, kernel_size=2, stride=2)
        self.decoder_0 = self._conv_block(1024, 512)

        self.upConv_1 = torch.nn.ConvTranspose2d(512, 320, kernel_size=2, stride=2)
        self.decoder_1 = self._conv_block(640, 320)

        self.upConv_2 = torch.nn.ConvTranspose2d(320, 128, kernel_size=2, stride=2)
        self.decoder_2 = self._conv_block(256, 128)

        self.upConv_3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder_3 = self._conv_block(128, 64)

        self.upConv_4 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder_4 = self._conv_block(32, 16)

        # x2 upsamplers that bring each encoder skip to the decoder's resolution.
        self.enc3_upsample = torch.nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.enc2_upsample = torch.nn.ConvTranspose2d(320, 320, kernel_size=2, stride=2)
        self.enc1_upsample = torch.nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.enc0_upsample = torch.nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)

        self.last_conv = torch.nn.Conv2d(16, output_channels, kernel_size=1)

    def _conv_block(self, in_ch, out_ch):
        """Two 3x3 conv + SiLU layers."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        )

    def forward(self, x):
        enc0, enc1, enc2, enc3, bottleneck = x

        # (upsampler, skip upsampler, fuse block, skip tensor, apply dropout)
        stages = [
            (self.upConv_0, self.enc3_upsample, self.decoder_0, enc3, True),
            (self.upConv_1, self.enc2_upsample, self.decoder_1, enc2, True),
            (self.upConv_2, self.enc1_upsample, self.decoder_2, enc2 if False else enc2, True),
            (self.upConv_3, self.enc0_upsample, self.decoder_3, enc0, False),
        ]
        # Fix the third stage's skip to enc1 (kept explicit for clarity).
        stages[2] = (self.upConv_2, self.enc1_upsample, self.decoder_2, enc1, True)

        out = bottleneck
        for up, skip_up, fuse, skip, use_dropout in stages:
            out = fuse(torch.cat((up(out), skip_up(skip)), dim=1))
            if use_dropout:
                out = self.dropout(out)

        out = self.decoder_4(self.upConv_4(out))
        return self.last_conv(out)

class ModelsList(Enum):
    """Backbone architectures selectable for the classification experiments.

    NOTE(review): "RestNet50" / "RestNet50_Droput" look like typos for
    "ResNet50" / "ResNet50_Dropout"; member names are left unchanged because
    callers may reference them by name.
    """
    RestNet50 = 1
    DenseNet169 = 2
    Efficientnet_b0 = 3
    MobileNetV2 = 4
    Inception3 = 5
    ResNeXt_50_32x4d = 6
    RestNet50_Droput = 7

class TypeLossFunction(Enum):
    """Loss-function selector passed to the Lightning model wrappers below
    (only CrossEntropy vs. BCEWithLogits is actually branched on there)."""
    CrossEntropy = 0
    BCEWithLogits = 1
    HingeLoss = 2
    FocalLoss = 3
    ContrastiveLoss = 4
    
    
class FeatureAlignmentLoss(torch.nn.Module):
    """Cosine-alignment loss between two feature tensors.

    Returns mean(1 - cos_sim) over the last dimension: 0 when student and
    teacher features point in the same direction, up to 2 when opposite.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, F_student, F_teacher):
        # L2-normalize along the feature axis; the dot product of unit
        # vectors is then the cosine similarity.
        student_dir = torch.nn.functional.normalize(F_student, p=2, dim=-1)
        teacher_dir = torch.nn.functional.normalize(F_teacher, p=2, dim=-1)

        similarity = (student_dir * teacher_dir).sum(dim=-1)

        return (1 - similarity).mean()
    
class CosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.00001,
        eta_min: float = 0.00001,
        last_epoch: int = -1,
    ):
        """
        Cosine-annealing learning-rate schedule with an initial linear warmup.
        (Docstring translated from Japanese.)

        The learning rate follows a cosine curve up to max_epochs; from epoch
        0 to warmup_epochs a linear warmup is applied instead.
        https://pytorch-lightning-bolts.readthedocs.io/en/stable/schedulers/warmup_cosine_annealing.html

        Args:
            optimizer (torch.optim.Optimizer):
                the wrapped optimizer instance
            warmup_epochs (int):
                number of epochs of linear warmup
            max_epochs (int):
                number of training epochs; end point of the cosine curve
            warmup_start_lr (float):
                learning rate at warmup epoch 0
            eta_min (float):
                lower bound of the cosine curve
            last_epoch (int):
                phase offset of the cosine curve
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch, verbose=False)
        return None

    def get_lr(self):
        # Epoch 0: start of the linear warmup.
        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        # Warmup phase: step each group linearly toward its base lr.
        if self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Warmup just finished: start the cosine phase exactly at the base lrs.
        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        # Cycle-restart point of the cosine schedule (mirrors the closed form
        # used by torch's built-in CosineAnnealingLR).
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        # General step: advance each group's lr by the ratio of consecutive
        # cosine factors (recursive form of the annealing curve).
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)))
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]    

class ContrastiveLoss(torch.nn.Module):
    """Hinge-style contrastive loss over embedding pairs.

    label == 0: similar pair  -> loss is the squared euclidean distance.
    label == 1: dissimilar pair -> loss is max(margin - distance, 0)^2.
    """

    def __init__(self, margin=1.0, distance: int = 0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        # Stored distance-metric selector; only euclidean is implemented here.
        self.TypeDistance = distance

    def forward(self, output1, output2, label):
        # Euclidean distance between the two embeddings.
        dist = torch.nn.functional.pairwise_distance(output1, output2)

        similar_term = (1 - label) * dist.pow(2)
        dissimilar_term = label * torch.clamp(self.margin - dist, min=0.0).pow(2)

        return (similar_term + dissimilar_term).mean()

class GradCAM:
    """Minimal Grad-CAM helper for a model with a single binary logit output.

    Captures the target layer's activations (forward hook) and gradients
    (backward hook), then weights activation channels by their pooled
    gradients to form a class-activation map.
    """

    def __init__(self, model, target_layer):
        # model: network producing one sigmoid-compatible logit per sample.
        # target_layer: module whose activations/gradients are recorded.
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook()
        
    def hook(self):
        """Attach forward/backward hooks that capture activations and gradients."""
        def forward_hook(module, Input, Output):
            self.activations = Output
        
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
            
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch;
        # register_full_backward_hook is the recommended replacement — confirm
        # against the torch version in use.
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_backward_hook(backward_hook)
        
    def generate_cam(self, data:torch.Tensor, target_class):
        """Return a min-max-normalized CAM (numpy array) for `data`.

        target_class is used as the BCE target against the model's logit.
        NOTE(review): assumes data/model are on CPU (`.numpy()` below) and a
        non-constant CAM — a flat map would divide by zero in the
        normalization; confirm callers guarantee both.
        """
        self.model.zero_grad()
        data.requires_grad= True
        output = self.model(data).flatten()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(output, target_class)
        loss.backward()
        
        # Channel weights: gradients averaged over batch and spatial dims.
        pooled_Grad = torch.mean(self.gradients, dim=[0, 2, 3])
        # NOTE(review): this mutates the captured activation tensor in place;
        # safe only because backward() has already run on this graph.
        for i in range (pooled_Grad.size(0)):
            self.activations[:, i, :, :] *= pooled_Grad[i]
            
        cam = torch.mean(self.activations, dim=1).squeeze()
        cam = np.maximum(cam.detach().numpy(), 0)
        cam = (cam - cam.min()) / (cam.max() - cam.min())
        return cam

class UNet_RN18(torch.nn.Module):
    """U-Net with an ImageNet-pretrained ResNet-18 encoder.

    Skip features are captured with forward hooks on conv1 and layer1..layer4
    (conv1's output is taken before bn/relu/maxpool). Because the encoder stem
    halves the resolution twice, an extra up-stage is appended after the four
    skip-fusing stages.
    """

    def _conv_block(self, in_ch, out_ch):
        """Two 3x3 conv + SiLU layers."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        )

    def _hook_fn(self, module, input, output):
        # Collect each tapped encoder stage's output, in registration order.
        self.feature_maps.append(output)

    def _register_encoder_hooks(self):
        """Install the forward hooks on every tapped encoder block."""
        for layer in self.list_blocks:
            self.hooks.append(layer.register_forward_hook(self._hook_fn))

    def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.hooks = []
        self.feature_maps = []

        # Pretrained encoder; the fc head is replaced by Identity since only
        # intermediate features are consumed.
        self.encoder_base = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        self.encoder_base.fc = torch.nn.Identity()

        # Encoder stages tapped for skip connections.
        self.list_blocks = [
            self.encoder_base.conv1,
            self.encoder_base.layer1,
            self.encoder_base.layer2,
            self.encoder_base.layer3,
            self.encoder_base.layer4,
        ]
        self._register_encoder_hooks()

        self.dropout = torch.nn.Dropout2d(0.4)

        # x2 upsamplers along the decoder path.
        self.upConv_0 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upConv_3 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.upConv_extra = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)

        # Fuse blocks; each consumes [upsampled + skip] channels.
        self.decoder_0 = self._conv_block(512, 256)
        self.decoder_1 = self._conv_block(256, 128)
        self.decoder_2 = self._conv_block(128, 64)
        self.decoder_3 = self._conv_block(32 + 64, 32)
        self.decoder_extra = self._conv_block(16, 16)

        self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1)

    def forward(self, x):
        self.feature_maps = []

        _ = self.encoder_base(x)  # populates self.feature_maps via the hooks

        skip_0, skip_1, skip_2, skip_3, out = self.feature_maps

        # (upsampler, fuse block, skip tensor, apply dropout afterwards)
        stages = [
            (self.upConv_0, self.decoder_0, skip_3, True),
            (self.upConv_1, self.decoder_1, skip_2, True),
            (self.upConv_2, self.decoder_2, skip_1, True),
            (self.upConv_3, self.decoder_3, skip_0, False),
        ]
        for up, fuse, skip, use_dropout in stages:
            out = fuse(torch.cat((up(out), skip), dim=1))
            if use_dropout:
                out = self.dropout(out)

        out = self.decoder_extra(self.upConv_extra(out))
        return self.last_conv(out)
    
class RARP_NVB_ROI_Mask_Unet(L.LightningModule):
    """Lightning module for binary ROI mask segmentation with UNet_RN18.

    Trains with BCE-with-logits, tracks binary IoU for train/val, optionally
    L1-regularizes the decoder path, and alternates freezing/unfreezing the
    encoder every other epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.model = UNet_RN18(in_channels=3, out_channels=1)

        self.lr = 1E-4
        self.Lambda_L1 = None  # set to a float to enable decoder-only L1 regularization
        self.lossFN = torch.nn.BCEWithLogitsLoss()

        self.train_IoU = torchmetrics.classification.BinaryJaccardIndex()
        self.val_IoU = torchmetrics.classification.BinaryJaccardIndex()

    def forward(self, data):
        """Return raw (pre-sigmoid) mask logits for a float-cast batch."""
        data = data.float()
        pred = self.model(data)
        return pred

    def _shared_step(self, batch, val_step: bool = True):
        """Compute (loss, target mask, sigmoid probabilities) for one batch.

        The L1 penalty on decoder/upConv parameters is added only when
        val_step is False (i.e. during training).
        """
        img, mask = batch

        mask = mask.float()
        mask = mask.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W) to match the model output
        prediction = self(img)

        loss = self.lossFN(prediction, mask)

        predicted_labels = torch.sigmoid(prediction)

        if not val_step:
            if self.Lambda_L1 is not None:
                # Penalize only the decoder path; the encoder is handled by the
                # alternating freeze in on_train_epoch_start.
                loss_l1 = 0
                for name, params in self.model.named_parameters():
                    if "decoder" in name or "upConv" in name:
                        loss_l1 += torch.norm(params, p=1)

                loss += self.Lambda_L1 * loss_l1

        return loss, mask, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch, False)

        self.train_IoU.update(predicted_labels, true_labels)

        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc_IoU", self.train_IoU, on_epoch=True, on_step=False)

        return loss

    def on_after_backward(self):
        """Log the global L2 gradient norm and flag suspected vanishing gradients."""
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5

        self.log("grad_norm", total_norm)

        if total_norm < 1e-8:
            # BUGFIX: LightningModule.log only accepts numeric/metric values;
            # logging the string "Vanishing gradient suspected!" raised at
            # runtime. Log a numeric flag instead.
            self.log("grad_warning", 1.0)

    def on_train_epoch_start(self):
        # Alternate encoder training: trainable on even epochs, frozen on odd.
        for parms in self.model.encoder_base.parameters():
            parms.requires_grad = (self.current_epoch % 2 == 0)

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.val_IoU.update(predicted_labels, true_labels)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_acc_IoU", self.val_IoU, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Adam over the U-Net parameters (encoder + decoder)."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        return [optimizer]

class RARP_NVB_ResNet50_CAM(L.LightningModule):
    """ResNet-50 binary classifier that also returns its last conv feature map
    (useful for CAM-style visualization)."""

    def __init__(self) -> None:
        super().__init__()

        # Replace the 1000-way ImageNet head with a single binary logit.
        self.model = torchvision.models.resnet50()
        in_feats = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=in_feats, out_features=1)

        # Every layer except avgpool + fc: yields the final conv feature map.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        """Return (logit, conv feature map) for a batch of images."""
        featureMap = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1))
        flat = torch.flatten(pooled, 1)

        return self.model.fc(flat), featureMap

class RARP_NVB_VAN_CAM(L.LightningModule): #TODO
    """VAN-B2 based binary classifier with Grad-CAM (work in progress).

    NOTE(review): forward() builds a Grad-CAM heatmap and then feeds it back
    into the classifier; GradCAM.generate_cam returns a numpy array, so
    `self.model(featureMap)` looks broken as written — presumably unfinished
    (author's #TODO). Confirm intent before use.
    """
    def __init__(self) -> None:
        super().__init__()
        
        # Pretrained VAN-B2 backbone with the head replaced by a single logit.
        self.model = van.van_b2(pretrained = True)
        tempFC_ft = self.model.head.in_features
        self.model.head = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        
        # All children except the last two — assumed to drop the norm + head;
        # TODO confirm against the van package's module ordering.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])
        
        
    def forward(self, data, label:torch.Tensor):
        
        # Build a CAM for the last block and run the classifier on it.
        cams = GradCAM(self.model, self.model.block4[-1])
        featureMap = cams.generate_cam(data, label)
              
        pred = self.model(featureMap)
        
        
        return pred, featureMap
    
class RARP_NVB_ResNet18_CAM(L.LightningModule):
    """ResNet-18 binary classifier that also returns its last conv feature map
    (useful for CAM-style visualization)."""

    def __init__(self) -> None:
        super().__init__()

        # Replace the 1000-way ImageNet head with a single binary logit.
        self.model = torchvision.models.resnet18()
        tempFC_ft = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)

        # Every layer except avgpool + fc: yields the final conv feature map.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        """Return (logit, conv feature map) for a batch of images."""
        featureMap = self.feature_map(data)
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1))
        # BUGFIX: flatten from dim 1 to keep the batch dimension. The previous
        # torch.flatten(Cont_Net) collapsed the whole tensor to 1-D and broke
        # the fc layer for batch sizes > 1 (the ResNet-50 variant in this file
        # already flattens from dim 1).
        Cont_Net = torch.flatten(Cont_Net, 1)

        pred = self.model.fc(Cont_Net)

        return pred, featureMap
    
class RARP_NVB_MobileNetV2_CAM(L.LightningModule):
    """MobileNetV2 binary classifier that also returns its final conv feature map."""

    def __init__(self) -> None:
        super().__init__()

        # Replace the classifier's Linear with a single binary logit.
        self.model = torchvision.models.mobilenet_v2()
        in_feats = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=in_feats, out_features=1)

        # MobileNetV2 keeps its conv trunk under `.features`.
        self.feature_map = self.model.features

    def forward(self, data):
        """Return (logit, conv feature map) for a batch of images."""
        featureMap = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(featureMap, (1, 1))
        flat = torch.flatten(pooled, 1)

        return self.model.classifier(flat), featureMap
    
class RARP_NVB_EfficientNetV2_CAM(L.LightningModule):
    """EfficientNetV2-S (ImageNet-pretrained) binary classifier that also
    returns its final conv feature map."""

    def __init__(self) -> None:
        super().__init__()

        # Pretrained backbone; classifier Linear swapped for a single logit.
        self.model = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        in_feats = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=in_feats, out_features=1)

        # EfficientNet keeps its conv trunk under `.features`.
        self.feature_map = self.model.features

    def forward(self, data):
        """Return (logit, conv feature map) for a batch of images."""
        featureMap = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(featureMap, (1, 1))
        flat = torch.flatten(pooled, 1)

        return self.model.classifier(flat), featureMap

class RARP_NVB_Model_BCEWithLogitsLoss(L.LightningModule):
    """Generic Lightning wrapper for binary classifiers trained with
    BCE-with-logits. Callers are expected to assign `self.model` before use.
    """

    def __init__(self, x=None) -> None:
        super().__init__()

        # Backbone is injected by the caller after construction.
        self.model = None

        self.lossFN = torch.nn.BCEWithLogitsLoss()  # pos_weight=torch.tensor([2.73])

        # One metric object per stage so their internal state never mixes.
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        """Run the injected backbone on a float-cast batch (raw logits out)."""
        return self.model(data.float())

    def _shared_step(self, batch):
        """Return (loss, float labels, sigmoid probabilities) for one batch."""
        img, label = batch
        label = label.float()
        logits = self.forward(img).flatten()
        loss = self.lossFN(logits, label)

        return loss, label, torch.sigmoid(logits)

    def training_step(self, batch, batch_idx):
        loss, target, probs = self._shared_step(batch)

        self.log("train_loss", loss)
        self.train_acc.update(probs, target)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, target, probs = self._shared_step(batch)
        self.log("val_loss", loss)
        self.val_acc(probs, target)
        self.f1Score(probs, target)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, target, probs = self._shared_step(batch)
        self.test_acc(probs, target)
        self.f1ScoreTest(probs, target)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Plain Adam over the injected backbone's parameters."""
        return [torch.optim.Adam(self.model.parameters(), lr=1e-4)]

class RARP_NVB_FOCAL_loss(torch.nn.Module):
    """Thin nn.Module wrapper around torchvision's sigmoid focal loss."""

    def __init__(self, alpha: float = 0.25, gamma: float = 2, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha          # weighting factor for the positive class
        self.gamma = gamma          # focusing exponent
        self.reduction = reduction  # 'none' | 'mean' | 'sum'

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Delegate to torchvision's sigmoid_focal_loss with the stored settings."""
        return torchvision.ops.focal_loss.sigmoid_focal_loss(
            input, target, self.alpha, self.gamma, self.reduction
        )

class FocalLoss(torch.nn.Module):
    """Multi-class focal loss (Lin et al., 2017) computed from raw logits.

    loss = -alpha * (1 - p)^gamma * one_hot(target) * log(p), averaged or
    summed over all (sample, class) entries depending on `reduction`.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha          # class-balance factor
        self.gamma = gamma          # focusing exponent; 0 recovers weighted CE
        self.reduction = reduction  # 'none' | 'mean' | 'sum'

    def forward(self, inputs, targets):
        num_classes = inputs.size(1)
        target_one_hot = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()

        # Probabilities and log-probabilities over the class axis.
        probs = torch.nn.functional.softmax(inputs, dim=1)
        log_probs = torch.nn.functional.log_softmax(inputs, dim=1)

        # Down-weight well-classified examples via the modulating factor.
        modulating = (1 - probs) ** self.gamma
        per_entry = -self.alpha * modulating * target_one_hot * log_probs

        if self.reduction == 'mean':
            return per_entry.mean()
        if self.reduction == 'sum':
            return per_entry.sum()
        return per_entry
#TODO
class RARP_NVB_MultiClassModel(L.LightningModule):
    """Lightning wrapper for multi-class classification with focal loss.

    Adds an optional L1 penalty (training only, gated by `val_loop`) and an
    optional cosine-annealing LR schedule.
    """

    def __init__(self, 
                 InitWeight = torch.tensor([1,1]),  
                 schedulerLR:bool=False, 
                 lr:float = 1e-4,
                 Model:torch.nn.Module = None,
                 Num_Classes:int = 2,
                 L1:float = 1.31E-04,
                 L2:float = 0
        ) -> None:
        super().__init__()

        self.model = Model
        self.lossFN = FocalLoss()  # torch.nn.CrossEntropyLoss()
        self.InitWeight = InitWeight   # kept for signature parity with the binary wrapper
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = L1            # L1 coefficient; None disables the penalty
        self.Lambda_L2 = L2            # forwarded to Adam as weight_decay
        self.num_classes = Num_Classes

        self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.test_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.f1ScoreTest = torchmetrics.F1Score("multiclass", num_classes=Num_Classes)

        # Toggled by the step methods so _shared_step can skip the L1 term
        # outside of training.
        self.val_loop = False

    def forward(self, data):
        """Run the wrapped backbone on a float-cast batch (raw logits out)."""
        return self.model(data.float())

    def _shared_step(self, batch):
        """Return (loss, labels, softmax probabilities) for one batch."""
        img, label = batch
        logits = self(img)
        probs = torch.softmax(logits, dim=1)
        loss = self.lossFN(logits, label)

        if self.Lambda_L1 is not None and not self.val_loop:
            l1_penalty = sum(p.abs().sum() for p in self.model.parameters())
            loss = loss + self.Lambda_L1 * l1_penalty

        return loss, label, probs

    def training_step(self, batch, batch_idx):
        self.val_loop = False
        loss, target, probs = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(probs, target)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        self.val_loop = True
        loss, target, probs = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(probs, target)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        self.val_loop = True
        loss, target, probs = self._shared_step(batch)
        self.test_acc.update(probs, target)
        self.f1ScoreTest.update(probs, target)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Adam (+ optional cosine annealing over 6 epochs)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if not self.scheduler:
            return [optimizer]

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
            },
        }

class RARP_NVB_Model(L.LightningModule):
    """Generic Lightning wrapper for single-logit binary classification.

    Callers assign `self.model` after construction. Supports per-sample loss
    weighting via `InitWeight`, optional L1 regularization and an optional
    cosine-annealing LR schedule.
    """

    def __init__(self, 
                 InitWeight = torch.tensor([1,1]), 
                 typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy, 
                 schedulerLR:bool=True, 
                 lr:float = 1e-4
                ) -> None:
        super().__init__()

        # Backbone is injected by the caller after construction.
        self.model = None
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.InitWeight = InitWeight   # per-class loss weights, indexed by the label
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = None  # 1.31E-04 — set to a float to enable L1 regularization
        self.Lambda_L2 = 0     # forwarded to Adam as weight_decay

        print (f"LR= {self.lr}, L1= {self.Lambda_L1}")

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        #self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        """Run the injected backbone on a float-cast batch (raw logits out)."""
        return self.model(data.float())

    def _shared_step(self, batch):
        """Return (loss, float labels, sigmoid probabilities) for one batch."""
        img, label = batch
        # Optional per-sample loss weighting, indexed by the integer label.
        if self.InitWeight is not None:
            self.lossFN.weight = self.InitWeight[label]

        label = label.float()
        logits = self(img)[:, 0]  # first (only) logit column
        loss = self.lossFN(logits, label)

        if self.Lambda_L1 is not None:
            l1_penalty = sum(p.abs().sum() for p in self.parameters())
            loss = loss + self.Lambda_L1 * l1_penalty

        return loss, label, torch.sigmoid(logits)

    def training_step(self, batch, batch_idx):
        loss, target, probs = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(probs, target)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, target, probs = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(probs, target)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        # hp_metric drives TensorBoard's hyper-parameter comparison view.
        self.log("hp_metric", self.val_acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        loss, target, probs = self._shared_step(batch)
        self.test_acc.update(probs, target)
        self.f1ScoreTest.update(probs, target)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Adam (+ optional cosine annealing over 6 epochs)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if not self.scheduler:
            return optimizer

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
            },
        }
    
class RARP_Ensemble(L.LightningModule):
    """Stacking ensemble: each base model emits one logit per sample and a
    small MLP head fuses them into a single sigmoid probability.

    Args:
        ListModels: base models, each mapping an image batch to (B, 1) logits.
        InitWeight: optional per-class loss weights indexed by the labels.
        typeLossFN: smoothed cross-entropy or summed BCE.
        schedulerLR: enable a cosine-annealing LR schedule.
        lr: learning rate for the fusion head.
    """
    def __init__(self, ListModels,
                 InitWeight = torch.tensor([1,1]),
                 typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy,
                 schedulerLR:bool=False,
                 lr:float = 1e-4) -> None:
        super().__init__()

        # Register the base models as submodules so Lightning moves them to the
        # right device and includes them in checkpoints.  A plain Python list is
        # invisible to nn.Module, which left the members on CPU.
        self.ListModels = torch.nn.ModuleList(ListModels)
        # One input feature per base model (each contributes a single logit).
        input_p = len(self.ListModels)

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_p, out_features=128),
            torch.nn.SiLU(),
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid()
        )

        # BCELoss matches the sigmoid output head.
        self.lossFN = (torch.nn.CrossEntropyLoss(label_smoothing=0.5)
                       if typeLossFN == TypeLossFunction.CrossEntropy
                       else torch.nn.BCELoss(reduction="sum"))
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        """Stack the base-model logits into (B, n_models) and fuse them."""
        data = data.float()
        logits = torch.cat([m(data) for m in self.ListModels], dim=1)
        return self.classifier(logits)

    def _shared_step(self, batch):
        img, label = batch
        if self.InitWeight is not None:
            # Per-sample loss weighting, looked up by class label.
            self.lossFN.weight = self.InitWeight[label]

        label = label.float()
        prediction = self(img)[:, 0]
        loss = self.lossFN(prediction, label)

        return loss, label, prediction

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.f1Score.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Optimize the fusion head only.  Before the ModuleList fix the base
        models were not part of ``self.parameters()`` either, so scoping the
        optimizer to ``self.classifier`` preserves the original behavior.
        """
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=self.lr)
        if self.scheduler:
            # ``verbose`` is deprecated/removed in torch schedulers; not passed.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=1e-8)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                }
            }
        else:
            return [optimizer]
    
class RARP_NVB_ResNet18(RARP_NVB_Model):
    """Binary NVB classifier: ImageNet-pretrained ResNet-18 with a 1-logit head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        # Swap the 1000-class ImageNet head for a single-logit output.
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=1)
        self.model = backbone

class RARP_NVB_DaVit(RARP_NVB_Model):
    """Binary NVB classifier built on a pretrained DaViT-small from timm."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # num_classes=1 makes timm attach a single-logit classification head.
        self.model = timm.create_model("davit_small.msft_in1k", pretrained=True, num_classes=1)

class RARP_NVB_EfficientNetV2_Deep(RARP_NVB_Model):
    """EfficientNetV2-S with a deeper MLP head (128 -> 8 -> 1 logit)."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        net = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        head_in = net.classifier[1].in_features
        # Keep the stock dropout (classifier[0]) and rebuild the rest of the
        # head as a small MLP ending in one logit; module indices match the
        # original append-based construction.
        net.classifier = torch.nn.Sequential(
            net.classifier[0],
            torch.nn.Linear(in_features=head_in, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        )
        self.model = net

class RARP_NVB_ResNet50_Deep(RARP_NVB_Model):
    """ResNet-50 backbone with an MLP head (dropout -> 128 -> 8 -> 1 logit)."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        head_in = backbone.fc.in_features
        backbone.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=head_in, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        )
        self.model = backbone

    def forward(self, img):
        """Cast to float32 and run the full network; returns raw logits."""
        return self.model(img.float())
    
class RARP_NVB_ResNet50_Deep_OPTuna(RARP_NVB_Model):
    """ResNet-50 with an Optuna-configurable MLP head.

    Args:
        output_dims: widths of the hidden layers of the head (tuple default
            avoids the mutable-default-argument pitfall).
        dropuot: historical (misspelled) dropout argument, kept for backward
            compatibility with existing call sites / Optuna studies.
        dropout: preferred spelling; when given it overrides ``dropuot``.
    """
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy,
                 output_dims=(128, 8), dropuot=0.2, dropout=None, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # New, correctly-spelled keyword wins when both are supplied.
        drop_p = dropuot if dropout is None else dropout

        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)

        layers = []
        in_dim = self.model.fc.in_features
        for out_dim in output_dims:
            layers.append(torch.nn.Linear(in_dim, out_dim))
            layers.append(torch.nn.SiLU(True))
            layers.append(torch.nn.Dropout(drop_p))
            in_dim = out_dim

        layers.append(torch.nn.Linear(in_dim, 1))  # single binary logit

        self.model.fc = torch.nn.Sequential(*layers)

    def forward(self, img):
        """Cast to float32 and return raw logits."""
        return self.model(img.float())

class RARP_NVB_ResNet50_V2(RARP_NVB_Model):
    """ResNet-50 trunk reduced to 8 features, fused with extra tabular inputs
    in a final linear layer that produces one logit.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        head_in = self.model.fc.in_features
        self.model.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=head_in, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
        )
        # Fusion layer: 8 image features + InputNeurons extra features -> 1 logit.
        self.model.fc2 = torch.nn.Linear(8 + InputNeurons, 1)

    def forward(self, data):
        img, extra = data
        image_feats = self.model(img.float())
        image_feats = torch.nn.functional.silu(image_feats, True)
        fused = torch.concat((image_feats, extra), dim=1)
        return self.model.fc2(fused)
    
class RARP_NVB_ResNet50_V3(RARP_NVB_Model):
    """Two-branch model: 128-d ResNet-50 image features and a 128-d linear
    embedding of the extra inputs are concatenated and classified (1 logit).
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        self.model.fc = torch.nn.Linear(in_features=self.model.fc.in_features, out_features=128)
        # Projects the tabular inputs to the same width as the image branch.
        self.model.extraFC = torch.nn.Linear(InputNeurons, 128)
        self.model.fc2 = torch.nn.Linear(256, 1)

    def forward(self, data):
        img, extra = data

        img_branch = torch.nn.functional.silu(self.model(img.float()), True)
        extra_branch = torch.nn.functional.silu(self.model.extraFC(extra), True)
        fused = torch.concat((img_branch, extra_branch), dim=1)

        return self.model.fc2(fused)

class RARP_NVB_ResNet50_V1(RARP_NVB_Model):
    """ResNet-50 variant that concatenates extra (tabular) features to the
    pooled convolutional features right before the final linear layer.

    NOTE(review): ``self.Decoder`` is really the convolutional *encoder*
    (all ResNet children except avgpool and fc) and it shares parameters
    with ``self.model`` — both attributes point at the same weights, so the
    state dict stores them twice under different prefixes.
    """
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # The replacement head consumes pooled CNN features plus the extras.
        tempFC_ft = self.model.fc.in_features + InputNeurons
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
                
        # Everything except the final avgpool/fc: a feature-map extractor.
        self.Decoder = torch.nn.Sequential(*list(self.model.children())[:-2])
        
    def forward(self, data):
        img, extra = data
        img = img.float()
        featureMap = self.Decoder(img)
        # Global average pooling down to a (B, C) feature vector.
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1)) 
        Cont_Net = torch.flatten(Cont_Net, 1)
        
        # Late fusion with the tabular features.
        Cont_Net = torch.concat((Cont_Net, extra), dim=1)
        
        pred = self.model.fc(Cont_Net)
                
        return pred
        
class RARP_NVB_VAN(RARP_NVB_Model):
    """Binary NVB classifier on a pretrained VAN-B2 backbone."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = van.van_b2(pretrained = True)
        # Replace the classification head with a single-logit layer.
        backbone.head = torch.nn.Linear(in_features=backbone.head.in_features, out_features=1)
        self.model = backbone
        
class RARP_NVB_ResNet50(RARP_NVB_Model):
    """Binary NVB classifier: ImageNet-pretrained ResNet-50 with a 1-logit head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # Swap the 1000-class ImageNet head for a single-logit output.
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=1)
        self.model = backbone

class RARP_NVB_MLP(torch.nn.Module):
    """DINO-style projection head: an MLP down to ``bottleneck`` features,
    an L2 normalization, then a weight-normalized linear output layer.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=512, bottleneck=256, n_layers=3, norm_last_layer=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.activationFN = torch.nn.GELU()

        if n_layers == 1:
            # Degenerate case: a single projection straight to the bottleneck.
            self.mlp = torch.nn.Linear(in_dim, bottleneck)
        else:
            stack = [
                torch.nn.Linear(in_dim, hidden_dim),
                self.activationFN,
                torch.nn.Dropout(0.30),
            ]
            for _ in range(n_layers - 2):
                stack += [torch.nn.Linear(hidden_dim, hidden_dim), self.activationFN]
            stack.append(torch.nn.Linear(hidden_dim, bottleneck))
            self.mlp = torch.nn.Sequential(*stack)

        # Initialize the MLP *before* creating the weight-normalized output
        # layer, which keeps its own default Linear initialization.
        self.apply(self._init_weights)
        self.last_layer = torch.nn.utils.weight_norm(
            torch.nn.Linear(bottleneck, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            # Keep the norm of the output weights fixed at 1.
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Normal(0, 0.02) init for linear weights; zero biases."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        projected = self.mlp(x)
        projected = torch.nn.functional.normalize(projected, dim=-1, p=2)
        return self.last_layer(projected)

class RARP_NVB_DINO_Wrapper(torch.nn.Module):
    """Backbone + projection head that accepts either one tensor or a list of
    multi-crop tensors; crops are pushed through the network as one batch and
    the resulting logits are returned re-split per crop.
    """

    def __init__(self, backbone:torch.nn.Module, new_head:torch.nn.Module, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.backbone = backbone
        self.head = new_head

    def forward(self, x):
        if isinstance(x, list):
            batched, n_crops = torch.cat(x, dim=0), len(x)
        else:
            batched, n_crops = x, 1

        logits = self.head(self.backbone(batched))
        # One chunk per crop (a single chunk when the input was a tensor).
        return logits.chunk(n_crops)
    
class RARP_NVB_DINO_Loss(torch.nn.Module):
    """DINO loss: cross-entropy between temperature-sharpened teacher and
    student multi-crop outputs, with an EMA "center" subtracted from the
    teacher logits to prevent collapse.
    """

    def __init__(self, out_dim:int, teacher_Thao:float = 0.04, student_Thao:float = 0.1, center_momentum:float = 0.9, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.S_Thao = student_Thao
        self.T_Thao = teacher_Thao
        self.C_Momentum = center_momentum
        # Running estimate of the mean teacher output.
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, s_Output, t_Output):
        # Temperature-scaled log-probs (student) / centered probs (teacher).
        student_logp = [
            torch.nn.functional.log_softmax(s / self.S_Thao, dim=-1) for s in s_Output
        ]
        teacher_p = [
            torch.nn.functional.softmax((t - self.center) / self.T_Thao, dim=-1).detach()
            for t in t_Output
        ]

        total_loss, n_loss_terms = 0, 0
        for t_ix, t in enumerate(teacher_p):
            for s_ix, s in enumerate(student_logp):
                # Skip same-view pairs unless the teacher saw only one view.
                if t_ix == s_ix and len(teacher_p) > 1:
                    continue
                total_loss += torch.sum(-t * s, dim=-1).mean()
                n_loss_terms += 1

        total_loss /= n_loss_terms
        self.update_center(t_Output)

        return total_loss

    @torch.no_grad()
    def update_center(self, t_output):
        """EMA update of the teacher-output center from the current batch."""
        batch_mean = torch.cat(t_output).mean(dim=0, keepdim=True)
        self.center = self.center * self.C_Momentum + batch_mean * (1 - self.C_Momentum)
        
class RARP_NVB_DINO_RestNet50_Deep(L.LightningModule):
    """Semi-supervised DINO + knowledge-distillation model on ResNet-50.

    Three networks cooperate:
      * ``teacher_Labels``: a pre-trained classifier (frozen) that provides
        pseudo-labels by thresholding its sigmoid outputs.
      * ``teacher_Features``: a frozen ResNet-50 + DINO head, updated only by
        EMA from the student in ``on_train_batch_end``.
      * ``student``: a trainable ResNet-50 + DINO head whose projected
        features also feed a small binary classifier (``clasiffier``).
    """
    def __init__(
        self, 
        PseudoEstimator: str = None, 
        threshold: float = 0.5, 
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
    ) -> None:
        """
        Args:
            PseudoEstimator: checkpoint path for the pseudo-label teacher; a
                freshly constructed RARP_NVB_ResNet50_Deep is used when None.
            threshold: sigmoid cut-off that converts teacher outputs into
                binary pseudo-labels.
            TypeLoss: classification loss flavour (smoothed CE or BCEWithLogits).
            momentum_teacher: EMA momentum for the feature teacher and the
                DINO centering buffer.
            lr: learning rate (only the student is optimized).
            L1: optional L1-penalty weight on student parameters.
            L2: weight decay passed to AdamW.
        """
        super().__init__()
    
        self.lr =  lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.threshold = threshold
        self.momentum_teacher = momentum_teacher
        # DINO projection width / ResNet-50 pooled-feature width.
        self.out_dim = 512
        self.in_dim = 2048

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        
        # Pseudo-label teacher; strict=False tolerates missing/extra state keys.
        self.teacher_Labels = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False) if PseudoEstimator is not None else RARP_NVB_ResNet50_Deep()
        self.student = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT) #torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        self.teacher_Features = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        
        # Strip the ImageNet heads: both backbones now emit raw pooled features.
        self.student.fc = torch.nn.Identity()
        self.teacher_Features.fc = torch.nn.Identity()
        
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        
        # Both teachers are frozen; only the student receives gradients.
        for parms in self.teacher_Labels.model.parameters():
            parms.requires_grad = False
            
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, 0.04, 0.1, momentum_teacher)
        self.lossFN_KD = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        
        #self.lossFH_KLDiv = torch.nn.KLDivLoss(reduction="batchmean")
        
        # Binary head on top of the student's DINO projection.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(self.out_dim, 128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1)
        )
        
    def forward(self, data, val_step:bool = True):
        """Run the label teacher, feature teacher and student.

        Args:
            data: a single tensor (validation) or a list of crop tensors
                (training) where data[0] feeds the label teacher and
                data[1:3] — presumably the two global views — feed the
                feature teacher (TODO confirm crop layout with the datamodule).
            val_step: selects the single-view vs multi-crop path.

        Returns:
            ((classifier logits, pseudo-labels, teacher logits),
             (teacher DINO chunks, student DINO chunks))
        """
        if val_step:
            data = data.float()
            dataClassificator, dataTeacher, dataStudent = data, data, data
        else:
            data = [d.float() for d in data]
            dataClassificator, dataTeacher, dataStudent = data[0], data[1:3], data

        TeacherDino = self.teacher_Features(dataTeacher)
        TeacherLabels = self.teacher_Labels(dataClassificator)
        Student = self.student(dataStudent)
        
        # here every student output (all crops) would be averaged instead
        #if isinstance(dataStudent, list):
        #    #index = np.random.randint(0, len(dataStudent))
        #    temp = self.student(dataStudent)
        #    CatS_Classifier = torch.cat(temp, dim=0)
        #    meanS_Classifier = torch.zeros(self.in_dim)
        #    for dataS in temp:
        #        meanS_Classifier += dataS
        #    #S_Classifier = self.student(dataStudent[index])
        #    S_Classifier = meanS_Classifier / len(dataStudent)
        #else:
        #    S_Classifier = Student
            
        # Re-batch the per-crop student chunks for the classifier head.
        Cont_Net = torch.cat(Student, dim=0)
                
        pred = self.clasiffier(Cont_Net)
        
        if not val_step:
            # Replicate teacher logits once per student crop so shapes match
            # the concatenated student batch.
            # NOTE(review): this re-runs the teacher forward len(dataStudent)
            # times on identical input; the earlier result could be reused.
            TeacherLabels = [self.teacher_Labels(dataClassificator) for _ in range(len(dataStudent))]
            TeacherLabels = torch.cat(TeacherLabels, dim=0)
        
        TeacherLabelsPred = torch.sigmoid(TeacherLabels.flatten())
        # Hard 0/1 pseudo-labels from the thresholded teacher probabilities.
        PseudoLabels = (TeacherLabelsPred > self.threshold) * 1.0
                        
        return (pred.flatten(), PseudoLabels, TeacherLabels.flatten()), (TeacherDino, Student)
        
    def _shared_step(self, batch, val_step:bool = False):
        """Compute the combined pseudo-label / ground-truth / DINO loss.

        Returns:
            (loss, ground-truth labels, sigmoid predictions)
        """
        img, label = batch
        
        if not val_step:
            # One label copy per student crop (matches the concatenated batch).
            label = torch.cat([label for _ in range(len(img))], dim=0)
        
        label = label.float()
        KD_Prediction, DINO_Loss = self(img, val_step)
        TeacherF, StudentF = DINO_Loss
        prediction, PseudoLabels, teacherOutputs = KD_Prediction
        
        predicted_labels = torch.sigmoid(prediction)
        
        ## version 1: weighted sum of pseudo-label loss and ground-truth loss
        W_Alpha, W_Beta = (1, 0.5)#(1, 0.5)
        loss = W_Alpha * self.lossFN_KD(prediction, PseudoLabels) + W_Beta * self.lossFN_KD(prediction, label)
        
        # version 2 (kept for reference): soft-label distillation variant
        #thao_KD = 1#5.0
        #W_Alpha, W_Beta = (0.6, 0.4)
        
        #softTeacher = torch.sigmoid(teacherOutputs/thao_KD)
        #softStudent = torch.sigmoid(prediction/thao_KD)
        
        #loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
        #loss_hl = self.lossFN_KD(prediction, label)
        
        #loss = W_Alpha * loss_hl + W_Beta * loss_sl
        
        #loss = W_Alpha * self.lossFN_KD(prediction, label) + W_Beta * (self.lossFH_KLDiv(softStudent, softTeacher) * (thao_KD ** 2))
        # DINO self-distillation term only contributes during training.
        loss += (self.lossFN_DINO(StudentF, TeacherF) if not val_step else 0)
        
        if not val_step:
            # Histograms of the first teacher / second student DINO chunks.
            self.logger.experiment.add_histogram ("Teacher", TeacherF[0])
            self.logger.experiment.add_histogram ("Student", StudentF[1])
                
        if self.Lambda_L1 is not None:
            # Optional L1 regularization over the student parameters only.
            loss_l1 = 0
            for params in self.student.parameters(): # here
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        
        #return loss, PseudoLabels, predicted_labels
        return loss, label, predicted_labels
        
    def training_step(self, batch, batch_idx):
        """Multi-crop training step; logs loss and running accuracy."""
        loss, true_labels, predicted_labels = self._shared_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss
    
    def on_train_batch_end(self, outputs, batch, batch_idx):
        """EMA-update the feature teacher from the student after every
        optimizer step and log the current DINO center.
        """
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
            
            self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center)
            
                
    def validation_step(self, batch, batch_idx):
        """Single-view validation step; logs loss and accuracy."""
        loss, true_labels, predicted_labels = self._shared_step(batch, True)
        
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        
    def test_step(self, batch, batch_idx):
        """Single-view test step; accumulates accuracy and F1."""
        _, true_labels, predicted_labels = self._shared_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    
    def configure_optimizers(self):
        """AdamW over the *student* parameters only (teachers stay frozen)."""
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr, weight_decay=self.Lambda_L2) 
        
        return [optimizer]

class RARP_NVB_DINO_VAN(RARP_NVB_DINO_RestNet50_Deep):
    """DINO semi-supervised model with VAN-B2 backbones for the student and
    the EMA feature teacher.
    """

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher: float = 0.9995,
        lr: float = 0.0001,
        L1: float = None,
        L2: float = 0
    ) -> None:
        # The parent builds ResNet-50 student/teacher; both are replaced below.
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)

        # VAN-B2 embedding width.
        self.in_dim = 512

        student_backbone = van.van_b2(pretrained = True, num_classes = -1)
        teacher_backbone = van.van_b2(pretrained = True, num_classes = -1)

        self.student = RARP_NVB_DINO_Wrapper(
            student_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )

        # The feature teacher is updated only via EMA, never by the optimizer.
        for param in self.teacher_Features.parameters():
            param.requires_grad = False

class RARP_NVB_DINO_ViT(RARP_NVB_DINO_RestNet50_Deep):
    """DINO semi-supervised model with ViT-B/16 backbones for the student and
    the EMA feature teacher.
    """

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher: float = 0.9995,
        lr: float = 0.0001,
        L1: float = None,
        L2: float = 0
    ) -> None:
        # The parent builds ResNet-50 student/teacher; both are replaced below.
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)

        # ViT-B/16 embedding width.
        self.in_dim = 768

        student_backbone = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        teacher_backbone = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)

        # Drop the classification heads; the wrappers attach DINO MLP heads.
        student_backbone.heads = torch.nn.Identity()
        teacher_backbone.heads = torch.nn.Identity()

        self.student = RARP_NVB_DINO_Wrapper(
            student_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )

        # The feature teacher is updated only via EMA, never by the optimizer.
        for param in self.teacher_Features.parameters():
            param.requires_grad = False

class RARP_NVB_Classification_Head(torch.nn.Module):
    """Configurable MLP classification head.

    With no hidden ``layer`` sizes this is a single linear map; otherwise each
    hidden layer is Linear -> activation -> Dropout(0.4), except that the
    dropout right before the output layer is reduced to 0.2.

    Args:
        in_features: input width.
        out_features: number of output logits.
        layer: hidden-layer widths (empty/None for a plain linear head).
        activation_fn: activation module shared by all hidden layers.
    """
    def __init__(self, in_features:int, out_features:int, layer:list=None, activation_fn:torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Fix of the mutable-default-argument pitfall (was ``layer=[]``);
        # passing None or omitting the argument behaves like the old default.
        layer = [] if layer is None else layer
        self.activation = activation_fn

        if len(layer) == 0:
            self.head = torch.nn.Linear(in_features, out_features)
        else:
            blocks = []
            next_input = in_features
            for width in layer:
                blocks.append(torch.nn.Linear(next_input, width))
                blocks.append(self.activation)
                blocks.append(torch.nn.Dropout(0.4))
                next_input = width

            # Lighter dropout just before the final projection.
            blocks[-1] = torch.nn.Dropout(0.2)
            blocks.append(torch.nn.Linear(next_input, out_features))

            self.head = torch.nn.Sequential(*blocks)

    def forward(self, x):
        return self.head(x)

class RARP_NVB_DINO_MultiTask(L.LightningModule):
    # Define a hook function to capture the output
    def _hook_fn_Student(self, module, input, output):
        # Forward hook: stash the student backbone's last conv-block output so
        # forward() can reuse the spatial feature map after the backbone pass.
        self.last_conv_output_S = output
        
    def _hook_fn_Teacher(self, module, input, output):
        # Forward hook: stash the teacher backbone's last conv-block output
        # (counterpart of _hook_fn_Student).
        self.last_conv_output_T = output
    
    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
        std: float = None,
        mean: float = None,
        SoftAdptAlgo:int = 0,
        SoftAdptBeta:float = 0.1,
        Teacher_T:float = 0.04,
        Student_T:float = 0.1,
        intermittent:bool = False
    ) -> None:
        """Multi-task DINO model: self-distillation + image reconstruction +
        binary classification, with SoftAdapt balancing the three losses.

        Args:
            TypeLoss: classification loss (smoothed CE or BCEWithLogits).
            momentum_teacher: EMA momentum for the teacher / DINO center.
            lr: learning rate.
            L1: optional L1-penalty weight (applied to the classifier head).
            L2: optional L2-penalty weight (applied to the classifier head).
            std, mean: per-channel image stats used by ``_denormalize``;
                despite the float hints they are reshaped to (3, 1, 1), so
                3-element sequences are expected.
            SoftAdptAlgo: 1 -> NormalizedSoftAdapt, else LossWeightedSoftAdapt.
            SoftAdptBeta: SoftAdapt beta parameter.
            Teacher_T, Student_T: DINO softmax temperatures.
            intermittent: alternate backbone vs decoder/classifier training
                between epochs (see ``on_train_epoch_start``).
        """
        super().__init__()
        
        self.intermittent_train = intermittent
        
        # (3,1,1) broadcastable stats for de-normalizing reconstructed images.
        self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
        self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
    
        self.lr =  lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.Teacher_t = Teacher_T
        self.Studet_T = Student_T
        self.momentum_teacher = momentum_teacher
        self.out_dim = 1024
        self.in_dim = 512
        # SoftAdapt weights for (DINO, reconstruction, binary); start equal.
        self.weights = torch.tensor([1,1,1])
        
        self.softAdapt = NormalizedSoftAdapt(SoftAdptBeta) if SoftAdptAlgo == 1 else LossWeightedSoftAdapt(SoftAdptBeta)
        # Per-step loss values collected between SoftAdapt weight refreshes.
        self.loss_history = {
            'loss_DINO': [], 
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        
        self.student = van.van_b2(pretrained = True, num_classes = 0)
        self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0)
        
        # Reconstructs the input image from concatenated student+teacher maps.
        self.decoder = DynamicDecoder(input_channels=1024) 
               
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        
        # Teacher starts as an exact copy of the student and is EMA-only.
        self.teacher_Features.load_state_dict(self.student.state_dict())
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.ReconstructionLoss = torch.nn.MSELoss()
        
        # Populated by the forward hooks registered just below.
        self.last_conv_output_T = None
        self.last_conv_output_S = None
        
        # Capture the final conv-block feature maps of both backbones.
        self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher)
        self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
        
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 128),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.4),
            
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.2),
            
            torch.nn.Linear(8, 1)
        )
                
        print(f"lr={self.lr}, L1={self.Lambda_L1}")
        
    def _denormalize(self, tensor:torch.Tensor):
        """Undo per-channel normalization using the stored mean/std stats,
        moved to the input tensor's device first.
        """
        std = self.std_IMG.to(tensor.device)
        mean = self.mean_IMG.to(tensor.device)
        return tensor * std + mean
    
    def _calc_L1(self, params):
        """Return the L1 penalty (scaled by Lambda_L1) over the given parameters."""
        total = sum(torch.sum(torch.abs(p)) for p in params)
        return self.Lambda_L1 * total
    
    def _calc_weights(self, log_weights:bool = True):
        """Refresh the three SoftAdapt loss weights from the recorded loss
        histories, optionally log them, then reset the histories.

        The last entry is dropped when a history has even length —
        presumably to give SoftAdapt an odd number of points for its finite
        differences; TODO confirm against the softadapt documentation.
        """
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]),
            torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )

        if log_weights:
            self.log("W_loss_img", self.weights[1], on_epoch=True, on_step=False)
            self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[2], on_epoch=True, on_step=False)

        # Start a fresh window of loss samples for the next refresh.
        self.loss_history = {
            'loss_DINO': [], 
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }
        
    def forward(self, data, val_step:bool = True):
        """Multi-task forward: DINO features, image reconstruction and
        classification off the concatenated student/teacher feature maps.

        Relies on the forward hooks registered in ``__init__`` having filled
        ``last_conv_output_S`` / ``last_conv_output_T`` during the backbone
        passes just above their use.

        Args:
            data: a tensor in validation, or a list of crop tensors in
                training where data[1:3] are assumed to be the two global
                views fed to the teacher — TODO confirm with the datamodule.
            val_step: selects the single-view vs multi-crop path.

        Returns:
            (classifier logits, (student DINO chunks, teacher DINO chunks),
             reconstructed image batch)
        """
        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[1:3], data

        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)
                       
        if not val_step:
            # Keep only the student feature maps of the two global views so
            # they align with the teacher's maps for the concatenation below.
            NumChunks = len(dataStudent)
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
        
        # Channel-wise concat of the student and teacher conv feature maps.
        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
        
        reconstructed_image = self.decoder(cat_features)
        
        # Global average pooling -> feature vector for the classifier head.
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)
                        
        return pred, (Student, TeacherDino), reconstructed_image
        
    def _shared_step(self, batch, val_step:bool = False):
        """Compute the SoftAdapt-weighted multi-task loss for one batch.

        Returns:
            (total loss, labels, sigmoid predictions,
             (weighted DINO loss, weighted binary loss,
              weighted reconstruction loss, reconstructed image))
        """
        img, label = batch
        
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        
        # Flatten single-logit outputs regardless of which head type is used.
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, (NOAH, RARP_NVB_Classification_Head)):
            prediction = prediction.flatten()        

        predicted_labels = torch.sigmoid(prediction)
        
        # In training, originals/targets are replicated once per teacher chunk
        # to match the concatenated global-view batch.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        
        #DINO loss (training only; zero during validation)
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        #Classifier loss
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction loss
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            # Optional L1/L2 penalties apply to the classifier head only.
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())
                
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            
            # Record raw losses for the next SoftAdapt weight refresh.
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
              
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)
        
    def on_train_epoch_start(self):
        """Refresh the SoftAdapt loss weights and, when intermittent training
        is enabled, alternate which sub-networks are trainable per epoch."""
        even_epoch = self.current_epoch % 2 == 0

        # Re-estimate the per-loss weights every second epoch (never at epoch 0).
        if even_epoch and self.current_epoch != 0:
            self._calc_weights()

        if self.intermittent_train and self.current_epoch != 0:
            # Even epochs: train the student backbone only.
            # Odd epochs: train the decoder and the classification head only.
            for p in self.student.backbone.parameters():
                p.requires_grad = even_epoch

            for p in self.decoder.parameters():
                p.requires_grad = not even_epoch

            for p in self.clasiffier.parameters():
                p.requires_grad = not even_epoch
                
    
    def training_step(self, batch, batch_idx):
        """Training step: logs the combined loss, accuracy and the individual
        weighted loss terms, plus a periodic grid of reconstructed images."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        # losses = (weighted DINO, weighted GT, weighted reconstruction, image)
        self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("train_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)

        # Every 50 batches: de-normalise the reconstruction, reverse the
        # channel order (presumably BGR->RGB — confirm against the dataloader)
        # and write an image grid to TensorBoard. Requires mean/std configured.
        if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)

        return loss
    
    def on_train_batch_end(self, outputs, batch, batch_idx):
        """EMA update: blend the student's weights into the teacher after
        every training batch."""
        m = self.momentum_teacher
        with torch.no_grad():
            pairs = zip(self.student.parameters(), self.teacher_Features.parameters())
            for student_ps, teacher_ps in pairs:
                # teacher <- m * teacher + (1 - m) * student
                teacher_ps.data.mul_(m)
                teacher_ps.data.add_((1 - m) * student_ps.detach().data)
    
    def on_after_backward(self):
        """Log the global L2 norm of all gradients after each backward pass.

        Also emits a numeric ``grad_warning`` flag (1.0) when the norm is
        near zero, which suggests vanishing gradients.
        """
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5

        self.log("grad_norm", total_norm)

        # Bug fix: Lightning's `self.log` only accepts numeric/tensor values;
        # logging the string "Vanishing gradient suspected!" raised a
        # ValueError exactly when the warning should have fired. Log a numeric
        # flag instead (1.0 = vanishing gradient suspected).
        if total_norm < 1e-8:
            self.log("grad_warning", 1.0)
                
    def validation_step(self, batch, batch_idx):
        """Validation step: logs the combined loss, accuracy and the
        individual weighted loss terms."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, True)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

        # losses = (weighted DINO, weighted GT, weighted reconstruction, image)
        self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("val_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)
        
    def test_step(self, batch, batch_idx):
        """Test step: accumulates accuracy/F1 and logs a grid of
        reconstructed images when dataset mean/std are available."""
        _, true_labels, predicted_labels, losses = self._shared_step(batch, True)

        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

        # De-normalise the reconstruction and reverse the channel order
        # (presumably BGR->RGB — confirm against the dataloader) before logging.
        if self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
    
    def configure_optimizers(self):
        """Build a single AdamW optimizer over all parameters.

        NOTE(review): weight decay is intentionally not passed here; L2
        regularisation is applied manually in `_shared_step` via `Lambda_L2`.
        """
        return [torch.optim.AdamW(self.parameters(), lr=self.lr)]

#Ablation Models
"""T-S Multi-task model With out Recostruccion (V3R1_A1)

Returns:
    LightningModule
"""
class RARP_NVB_DINO_MultiTask_A1(RARP_NVB_DINO_MultiTask):
    """Ablation V3R1_A1: multi-task model without the reconstruction task.

    The decoder is replaced with an identity module and only the DINO and
    binary classification losses are combined (two SoftAdapt weights).
    """

    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)

        # Disable the reconstruction branch entirely.
        self.decoder = torch.nn.Identity()
        del self.ReconstructionLoss

        # No mean/std => the parent's steps never log reconstructed images.
        self.mean_IMG = None
        self.std_IMG = None

        # Two active losses: DINO and binary classification.
        self.weights = torch.tensor([1,1])
        self.loss_history = {
            'loss_DINO': [],
            'loss_Binary': [],
        }

    def _calc_weights(self, log_weights:bool = True):
        """Recompute the two SoftAdapt weights from the loss histories.

        The last entry is dropped when a history has even length (keeping an
        odd count — presumably what SoftAdapt expects; confirm against the
        softadapt package).
        """
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )

        if log_weights:
            # Reconstruction weight is logged as 0 so dashboards stay
            # comparable with the full model.
            self.log("W_loss_img", 0, on_epoch=True, on_step=False)
            self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False)

        # Reset histories for the next weighting window.
        self.loss_history = {
            'loss_DINO': [],
            'loss_Binary': [],
        }

    def _shared_step(self, batch, val_step:bool = False):
        """Shared step computing only the DINO + classification losses."""
        img, label = batch

        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features

        # Single-logit heads produce (N, 1); flatten to (N,).
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, NOAH):
            prediction = prediction.flatten()

        predicted_labels = torch.sigmoid(prediction)

        #orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        # Training labels are replicated once per teacher view.
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()

        #DINO Loss (skipped at validation)
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        #Clasificator
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction (disabled in this ablation)
        #loss_img = self.ReconstructionLoss(new_image, orignalImg)
        #loss_img = loss_img.float()

        if not val_step:
            # Optional L1/L2 regularisation of the classifier head only.
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())

            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg

            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())

        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_HL

        # Reconstruction slot is 0; tuple layout matches the parent class.
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[1] * loss_HL, 0, new_image)
        
"""T-S Multi-task model With out SoftAdadapt, Fix loss wegth 0.333_  (V3R1_A2)

Returns:
    LightningModule
"""
class RARP_NVB_DINO_MultiTask_A2(RARP_NVB_DINO_MultiTask):
    """Ablation V3R1_A2: the three task losses are combined with fixed,
    equal weights (1/3 each) instead of SoftAdapt-derived ones."""

    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)

        # SoftAdapt is unused in this ablation; drop it and pin equal weights.
        del self.softAdapt
        self.weights = [1 / 3] * 3

    def _calc_weights(self):
        """Keep the fixed equal weighting and clear the per-window history."""
        self.weights = [1 / 3] * 3
        self.loss_history = {name: [] for name in ('loss_DINO', 'loss_Reconstruction', 'loss_Binary')}

"""S Multi-task model No Dino base encoder VAN_b2, (V3R1_A3_1)

Returns:
    LightningModule
"""
class RARP_NVB_DINO_MultiTask_A3(RARP_NVB_DINO_MultiTask):
    """Ablation V3R1_A3_1: student-only model (no DINO teacher/loss) with a
    VAN-b2 backbone; trains on the reconstruction and binary-classification
    losses only. The teacher feature map is replaced by zeros so the fused
    feature tensor keeps the parent's layout."""

    def _hook_fn_Student(self, module, input, output):
        """Capture the student's last conv feature map and fabricate an
        all-zero stand-in for the (absent) teacher features."""
        self.last_conv_output_S = output
        self.last_conv_output_T = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
        if not self.val_phace:
            # Keep only the first 16 rows during training (two global views of
            # a fixed batch of 8 — TODO confirm against the dataloader).
            self.last_conv_output_T = self.last_conv_output_T[:16]

    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)

        # Student keeps a pretrained VAN-b2 backbone; the teacher is disabled.
        self.student = RARP_NVB_DINO_Wrapper(
            van.van_b2(pretrained = True, num_classes = 0),
            torch.nn.Identity()
        )

        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            torch.nn.Identity(),
            torch.nn.Identity()
        )

        self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)

        # Toggled by `_shared_step`; the hook needs to know the current phase.
        self.val_phace = True

        # Two active losses: reconstruction and binary classification.
        self.weights = torch.tensor([1,1])
        self.loss_history = {
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }

    def _calc_weights(self, log_weights:bool = True):
        """Recompute the two SoftAdapt weights from the loss histories.

        The last entry is dropped when a history has even length (keeping an
        odd count, as in the other variants).
        """
        # Bug fix: the odd-length branch previously read the non-existent
        # 'loss_DINO' key of this class's history and raised a KeyError; it
        # must fall back to the same 'loss_Reconstruction' list.
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )

        if log_weights:
            # DINO weight is logged as 0 so dashboards stay comparable.
            self.log("W_loss_img", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_DINO", 0, on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False)

        self.loss_history = {
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }

    def _shared_step(self, batch, val_step:bool = False):
        """Shared step with the weighted reconstruction + classification loss."""
        self.val_phace = val_step

        img, label = batch

        prediction, features, new_image = self(img, val_step)
        _, TeacherF = features

        # Single-logit heads produce (N, 1); flatten to (N,).
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, NOAH):
            prediction = prediction.flatten()

        predicted_labels = torch.sigmoid(prediction)

        # During training the target image/labels are replicated per view.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()

        #Clasificator
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            # Optional L1/L2 regularisation of the classifier head only.
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())

            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg

            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())

        loss = self.weights[0] * loss_img + self.weights[1] * loss_HL

        # Tuple layout matches the parent: (DINO, GT, reconstruction, image).
        return loss, label, predicted_labels, (0, self.weights[1] * loss_HL, self.weights[0] * loss_img, new_image)
        
"""S Multi-task model No Dino base encoder RN50, (V3R1_A3_2)

Returns:
    LightningModule
"""
class RARP_NVB_DINO_MultiTask_A3_RN50(RARP_NVB_DINO_MultiTask):
    """Ablation V3R1_A3_2: student-only model (no DINO teacher/loss) with a
    ResNet-50 backbone; trains on the reconstruction and binary-classification
    losses only. The zero "teacher" tensor is sliced to zero channels, so the
    fused feature tensor is exactly the student's 2048-channel map."""

    def _hook_fn_Student(self, module, input, output):
        """Capture the student's layer4 feature map; the teacher stand-in is
        reduced to zero channels so concatenation is a no-op."""
        self.last_conv_output_S = output
        self.last_conv_output_T = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
        if not self.val_phace:
            # Keep only the first 16 rows during training (two global views of
            # a fixed batch of 8 — TODO confirm against the dataloader).
            self.last_conv_output_T = self.last_conv_output_T[:16]

        # Zero channels: torch.cat with this tensor leaves features unchanged.
        self.last_conv_output_T = self.last_conv_output_T[:, :0, :, :]

    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)

        # Student uses a pretrained ResNet-50; the teacher is disabled.
        self.student = RARP_NVB_DINO_Wrapper(
            torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT),
            torch.nn.Identity()
        )

        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            torch.nn.Identity(),
            torch.nn.Identity()
        )

        self.student.backbone.layer4.register_forward_hook(self._hook_fn_Student)

        # Toggled by `_shared_step`; the hook needs to know the current phase.
        self.val_phace = True

        # Two active losses: reconstruction and binary classification.
        self.weights = torch.tensor([1,1])
        self.loss_history = {
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }

        # ResNet-50's layer4 emits 2048 channels (vs VAN's 1024).
        self.decoder = DynamicDecoder(input_channels=2048)

        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(2048, 128),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.4),

            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.2),

            torch.nn.Linear(8, 1)
        )

    def _calc_weights(self, log_weights:bool = True):
        """Recompute the two SoftAdapt weights from the loss histories.

        The last entry is dropped when a history has even length (keeping an
        odd count, as in the other variants).
        """
        # Bug fix: the odd-length branch previously read the non-existent
        # 'loss_DINO' key of this class's history and raised a KeyError; it
        # must fall back to the same 'loss_Reconstruction' list.
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )

        if log_weights:
            # DINO weight is logged as 0 so dashboards stay comparable.
            self.log("W_loss_img", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_DINO", 0, on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False)

        self.loss_history = {
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }

    def _shared_step(self, batch, val_step:bool = False):
        """Shared step with the weighted reconstruction + classification loss."""
        self.val_phace = val_step

        img, label = batch

        prediction, features, new_image = self(img, val_step)
        _, TeacherF = features

        # Single-logit heads produce (N, 1); flatten to (N,).
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, NOAH):
            prediction = prediction.flatten()

        predicted_labels = torch.sigmoid(prediction)

        # During training the target image/labels are replicated per view.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()

        #Clasificator
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            # Optional L1/L2 regularisation of the classifier head only.
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())

            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg

            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())

        loss = self.weights[0] * loss_img + self.weights[1] * loss_HL

        # Tuple layout matches the parent: (DINO, GT, reconstruction, image).
        return loss, label, predicted_labels, (0, self.weights[1] * loss_HL, self.weights[0] * loss_img, new_image)

"""T-S Multi-task model, classification head layer change, (V3R1_A4)

Returns:
    LightningModule
"""       
class RARP_NVB_DINO_MultiTask_A4(RARP_NVB_DINO_MultiTask):
    """Ablation V3R1_A4: swaps the default classification head for a
    single-linear-layer (L1) `RARP_NVB_Classification_Head`."""

    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)

        # Head-depth configurations explored: [128, 8] (L3, original),
        # [256, 128, 8] (L4), [8] (L2); this ablation uses the L1 head
        # (no hidden layers).
        hidden_sizes = []
        self.clasiffier = RARP_NVB_Classification_Head(1024, 1, hidden_sizes, torch.nn.SiLU(True))
        
#end Ablation Models
class RARP_NVB_DINO_MultiTask_Unet(RARP_NVB_DINO_MultiTask):
    """Multi-task variant with a U-Net style decoder that consumes skip
    connections captured from four student backbone stages via forward hooks."""

    def _encoder_hool_fn(self, module, input, output):
        # Collect intermediate encoder feature maps (skip connections).
        self.feature_maps.append(output)

    def _register_encoder_hooks(self):
        # Attach the capture hook to each selected backbone block.
        for layer in self.list_blocks:
            self.hooks.append(layer.register_forward_hook(self._encoder_hool_fn))

    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
        std: float = None,
        mean: float = None,
        SoftAdptAlgo:int = 0,
        SoftAdptBeta:float = 0.1
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)

        self.hooks = []
        self.feature_maps = []

        # Encoder stages whose outputs become U-Net skip connections
        # (shallow to deep; block4[-2] avoids the parent's block4[-1] hook).
        self.list_blocks = [
            self.student.backbone.block1[-1],
            self.student.backbone.block2[-1],
            self.student.backbone.block3[-1],
            self.student.backbone.block4[-2],
        ]

        self._register_encoder_hooks()

        self.decoder = DecoderUnet(1024)

    def forward(self, data, val_step:bool = True):
        """As the parent's forward, but feeds the hooked encoder feature maps
        (plus the fused student/teacher features) into the U-Net decoder.

        NOTE(review): depends on hook side effects filling `self.feature_maps`
        and `self.last_conv_output_S/T` during the student/teacher calls — the
        statement order here is load-bearing.
        """
        self.feature_maps = []

        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            data = [d.float() for d in data]
            # Teacher sees only the two global views; student sees all views.
            dataTeacher, dataStudent = data[1:3], data

        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)

        if not val_step:
            # Keep only the student chunks matching the teacher's global views.
            NumChunks = len(dataStudent)
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)

        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)

        if not val_step:
            # Reduce each captured skip connection to the same two global views.
            NumChunks = len(dataStudent)
            temp = []
            for f_maps in self.feature_maps:
                temp.append(torch.cat(f_maps.chunk(NumChunks)[1:3], dim=0))
            self.feature_maps = temp
            del temp

        # Deepest (fused) feature map goes last; the decoder gets the full list.
        self.feature_maps.append(cat_features)
        reconstructed_image = self.decoder(self.feature_maps)

        # Global-average-pool the fused features for the classification head.
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)

        return pred, (Student, TeacherDino), reconstructed_image

class RARP_NVB_DINO_MultiTask_MultiLabel(RARP_NVB_DINO_MultiTask):
    """Multi-label variant: the head emits `Num_Lables` logits and the loss
    and all metrics are switched to their multilabel counterparts."""

    def __init__(
        self, 
        TypeLoss=TypeLossFunction.BCEWithLogits, 
        momentum_teacher = 0.9995, 
        lr = 0.0001, 
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None, 
        SoftAdptAlgo = 0, 
        SoftAdptBeta = 0.1,
        Num_Lables = 2
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)

        self.lossFN = torch.nn.BCEWithLogitsLoss()

        # Same funnel as the parent head, but ending in Num_Lables logits.
        head = []
        for n_in, n_out, p_drop in ((1024, 128, 0.4), (128, 8, 0.2)):
            head.append(torch.nn.Linear(n_in, n_out))
            head.append(torch.nn.Dropout(p_drop))
            head.append(torch.nn.SiLU(True))
        head.append(torch.nn.Linear(8, Num_Lables))
        self.clasiffier = torch.nn.Sequential(*head)

        # Multilabel metrics replace the parent's binary ones.
        self.train_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.val_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.test_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.f1ScoreTest = torchmetrics.F1Score('multilabel', num_labels=Num_Lables)

class RARP_NVB_DINO_MultiTask_v2(RARP_NVB_DINO_MultiTask):
    """Four-class (multiclass) variant of the DINO multi-task model.

    Replaces the binary head/metrics with 4-class equivalents and uses
    `CrossEntropyLoss` for the classification task.
    """

    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        momentum_teacher = 0.9995, 
        lr = 0.0001, 
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean)

        # 4-class metrics replace the parent's binary ones.
        self.train_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.val_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.test_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.f1ScoreTest = torchmetrics.F1Score('multiclass', num_classes=4)

        self.lossFN = torch.nn.CrossEntropyLoss()

        self.softAdapt =  LossWeightedSoftAdapt(0.1) #NormalizedSoftAdapt(0.1)

        # Deeper classification head ending in 4 logits.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 512),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),

            torch.nn.Linear(512, 256),
            torch.nn.SiLU(True),

            torch.nn.Linear(256, 128),
            torch.nn.SiLU(True),

            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),

            torch.nn.Linear(8, 4)
        )

    def _shared_step(self, batch, val_step:bool = False):
        """Shared step with a multiclass classification loss.

        Returns:
            (total loss, labels, softmax probabilities,
            (weighted DINO, weighted GT, weighted reconstruction, image)).
        """
        img, label = batch

        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features

        # Probabilities are kept for the accuracy/F1 metrics only.
        predicted_labels = torch.softmax(prediction,dim=1)

        # During training the target image/labels are replicated per view.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label for _ in range(len(TeacherF))], dim=0) if not val_step else label

        #DINO Loss (skipped at validation)
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        #Clasificator
        # Bug fix: CrossEntropyLoss expects raw (unnormalized) logits; the
        # previous code passed the softmax output, which applies log-softmax
        # twice and shrinks/distorts the gradients.
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            # Consistent with the parent class: use the shared L1 helper
            # instead of a hand-rolled sum over parameters.
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())

            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg

            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())

        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL

        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)

class RARP_NVB_RN50_VAN_V2 (RARP_NVB_Model):
    # Define a hook function to capture the output
    def _hook_fn(self, module, input, output):
        self.last_conv_output = output
        
    def __init__(
        self, 
        PseudoEstimator:str = None, 
        threshold:float = 0.5, 
        InitWeight=None, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        PseudoLables:bool=True,
        HParameter = {},
        std: float = None,
        mean: float = None,
        **kwargs
    ) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        
        self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
        self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
        
        #self.RARP_RestNet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50') 
        self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
        self.RARP_RestNet50.model.fc = torch.nn.Identity()  
        for parms in self.RARP_RestNet50.model.parameters():
            parms.requires_grad = False      
            
        self.RARP_RestNet50 = torch.nn.Sequential(*list(self.RARP_RestNet50.model.children())[:-2])
          
        self.decoder = Decoder()
        
        self.FeaturesLoss = FeatureAlignmentLoss()
        self.ReconstructionLoss = torch.nn.MSELoss()
    
        match (TypeError):
            case TypeLossFunction.ContrastiveLoss:
                self.lossFN = ContrastiveLoss()
            case _:
                pass            
        
        self.model = van.van_b2(pretrained = True, num_classes = 1)  
              
        self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04) 
        self.lr = getNVL(HParameter, "lr", 1.0E-4)
                
        print(f"lr={self.lr}, L1={self.Lambda_L1}")
        
        # Initialize a variable to store the output
        self.last_conv_output = None
        self.model.block4[-1].mlp.dwconv.register_forward_hook(self._hook_fn)
        

    def forward(self, data):
        dataTeacher, dataStudent, _ = data
        dataStudent = dataStudent.float()
        dataTeacher = dataTeacher.float()
        
        feature_Teacher = self.RARP_RestNet50 (dataTeacher)
        #feature_Student = self.model.forward_features(dataStudent)
        pred = self.model(dataStudent)  
        feature_Student = self.last_conv_output
        
        #feature_Teacher = torch.nn.functional.adaptive_avg_pool1d(feature_Teacher, feature_Student.size(-1)) #Re-size output vector to macth VAN 
        
        cat_features = (feature_Student + feature_Teacher) / 2
                        
        reconstructed_image = self.decoder(cat_features) 
        
        feature_Student = torch.nn.functional.adaptive_avg_pool2d(feature_Student, (1, 1)).flatten(1)
        feature_Teacher = torch.nn.functional.adaptive_avg_pool2d(feature_Teacher, (1, 1)).flatten(1)
                
        return pred, (feature_Student, feature_Teacher), reconstructed_image
    
    def _shared_step(self, batch):
        img, label = batch
        label = label.float()
        orignalImg = img[2].float()
                
        prediction, features, new_image = self(img)
        
        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)
        
        #cosine similarity
        loss_cosine = self.FeaturesLoss(features[0], features[1])
        #Clasificator
        loss_hl = self.lossFN(prediction, label)
        #Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()
        
        loss = loss_cosine + loss_hl + loss_img
        
        if self.Lambda_L1 is not None:
            loss_l1 = 0
            for params in self.model.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
            
        return loss, label, predicted_labels, (loss_cosine, loss_hl, loss_img, new_image)
    
    def _denormalize(self, tensor:torch.Tensor):
        # Move mean and std to the same device as the input tensor
        mean = self.mean_IMG.to(tensor.device)
        std = self.std_IMG.to(tensor.device)
        return tensor * std + mean
    
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("train_loss_cosSim", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)
        
        if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("val_loss_cosSim", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)
        
    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
        
        if self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        
        return optimizer
                
class RARP_NVB_ResNet50_VAN(RARP_NVB_Model):
    def __init__(
        self, 
        PseudoEstimator:str = None, 
        threshold:float = 0.5, 
        InitWeight=None, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        PseudoLables:bool=True,
        HParameter = {},
        **kwargs
    ) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        
        self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
        self.threshold = threshold
        self.PseudoLables = PseudoLables
        
        match (TypeError):
            case TypeLossFunction.ContrastiveLoss:
                self.lossFN = ContrastiveLoss()
            case _:
                pass            
            
        for parms in self.RARP_RestNet50.model.parameters():
            parms.requires_grad = False
            
        self.model = van.van_b2(pretrained = True, num_classes = 1)        
        self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04) 
        self.lr = getNVL(HParameter, "lr", 1.0E-4)
        self.W_Apha = getNVL(HParameter, "Alpha", 0.60)
        self.W_Beta = 1 - getNVL(HParameter, "Alpha", 0.60)
        self.thao_KD = getNVL(HParameter, "Thao", 5)
        
        print(f"A={self.W_Apha}; B={self.W_Beta}, T={self.thao_KD}, lr={self.lr}, L1={self.Lambda_L1}")
        
        #self.Lambda_L1 = None
        #self.scheduler = True
        #self.lr = 1.74E-2
        
    def forward(self, data):
        #dataTeacher, dataStudent = data
        if isinstance(data, tuple):
            dataTeacher, dataStudent = data
        elif isinstance(data, torch.Tensor):
            dataTeacher, dataStudent = (data, data)
        RN50Pred = torch.sigmoid(self.RARP_RestNet50(dataTeacher).flatten())
        PseudoLabels = (RN50Pred > self.threshold) * 1.0
        
        dataStudent = dataStudent.float()
        pred = self.model(dataStudent)
                
        return pred, PseudoLabels, RN50Pred
        
    def _shared_step(self, batch):
        img, label = batch
        if self.InitWeight is not None:
            self.lossFN.weight = self.InitWeight[label]
        
        label = label.float()
        prediction, PseudoLabels, predictionRN50 = self(img) #.flatten()
        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)
         #
        # L[B] => BCEWithLogits
        # L[C] => ContrastiveLoss
        #Loss L[1] = L[B_HL] + L[B_PL]
        #Loss L[2] = L[C_HL]
        #Loss L[3] = L[C_HL] + L[C_PL]
        #Loss L[4] = L[C_y_hat]
        #Loss L[5] = L[KLD_PL] + L[B_PL]
        #Loss L[6] = L[MSE_PL] + L[B_PL]
        #Loss L[7] = L[B_HL] + L[BCE_PL]
        #Loss L[8] = FocalLoss*L[JSD]
        #Nuveo
        #thao_KD = 5.0
        #w_alpha, w_beta = (0.60, 0.40)
        
        #L[7]:
        softTeacher = torch.sigmoid(predictionRN50/self.thao_KD)
        softStudent = torch.sigmoid(prediction/self.thao_KD)
        
        loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
        loss_hl = self.lossFN(prediction, label)
        
        loss = self.W_Apha * loss_hl + self.W_Beta * loss_sl #/ (self.thao_KD ** 2)
        
        #loss = w_alpha * self.lossFN(prediction, label) + w_beta * self.lossFN(prediction, PseudoLabels) #L[1]
        #loss = w_alpha * (torch.nn.KLDivLoss()(softStudent, softTeacher) * (thao_KD ** 2)) + w_beta * self.lossFN(prediction, PseudoLabels) #L[5]
        #loss = w_alpha * torch.nn.functional.mse_loss(softStudent, softTeacher) + w_beta * self.lossFN(prediction, PseudoLabels) #L[6]
        #loss = self.lossFN(predictionRN50, predicted_labels, label) #L[2]
        #loss = self.lossFN(predictionRN50, predicted_labels, label) + self.lossFN(predictionRN50, predicted_labels, PseudoLabels) #L[3]
        #y_hat = (PseudoLabels != label) * 1
        #loss = self.lossFN(predictionRN50, predicted_labels, y_hat) #L[4]
        
        if self.Lambda_L1 is not None:
            loss_l1 = 0
            for params in self.model.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        
        return loss, (PseudoLabels if self.PseudoLables else label), predicted_labels, (loss_hl, loss_sl)
    
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_lh", losses[0], on_step=False, on_epoch=True)
        self.log("train_ld", losses[1], on_step=False, on_epoch=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_lh", losses[0], on_step=False, on_epoch=True)
        self.log("val_ld", losses[1], on_step=False, on_epoch=True)
        
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels, _ = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            scheduler = CosineAnnealingLR(
                optimizer, 
                max_epochs=150,
                warmup_epochs=8,
                warmup_start_lr=0.001,
                eta_min=0.0001
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        return optimizer

class RARP_NVB_SSL_RestNet50_Deep(RARP_NVB_ResNet50_VAN):
    """Distillation setup like the parent, but with a "deep" ResNet50 teacher
    and a ResNet50 student whose classifier is a small MLP head."""
    def __init__(
        self, 
        PseudoEstimator: str = None, 
        threshold: float = 0.5, 
        InitWeight=None, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        PseudoLables: bool = True,
        **kwargs
    ) -> None:
        # The parent builds a default teacher/student; both are replaced below,
        # so pass None as the estimator path to skip the parent's checkpoint load.
        super().__init__(None, threshold, InitWeight, TypeLoss, PseudoLables, **kwargs)

        # Teacher: deep variant, optionally restored from a checkpoint.
        if PseudoEstimator is not None:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False)
        else:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep()
        for frozen in self.RARP_RestNet50.model.parameters():
            frozen.requires_grad = False

        # Student: ImageNet-pretrained ResNet50 with a 2048 -> 128 -> 8 -> 1 MLP head.
        #self.model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        head_layers = [
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=2048, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        ]
        self.model.fc = torch.nn.Sequential(*head_layers)

class RARP_NVB_MobileNetV2(RARP_NVB_Model):
    """MobileNetV2 backbone with its classifier swapped for a single-logit head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
        # Replace the final classifier layer with a 1-output linear head.
        backbone.classifier[1] = torch.nn.Linear(
            in_features=backbone.classifier[1].in_features, out_features=1
        )
        self.model = backbone

class RARP_NVB_EfficientNetV2(RARP_NVB_Model):
    """EfficientNetV2-S backbone with its classifier swapped for a single-logit head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        # Replace the final classifier layer with a 1-output linear head.
        backbone.classifier[1] = torch.nn.Linear(
            in_features=backbone.classifier[1].in_features, out_features=1
        )
        self.model = backbone
        
class RARP_NVB_Vit_b_16(RARP_NVB_Model):
    """ViT-B/16 backbone with its classification head replaced by a 1-logit layer."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        # CONSISTENCY FIX: accept and forward **kwargs like the sibling backbone
        # wrappers (MobileNetV2 / EfficientNetV2) so base-class options can be
        # passed through; backward-compatible for existing callers.
        super().__init__(InitWeight, TypeLoss, **kwargs)
        
        self.model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        tempFC_ft = self.model.heads[0].in_features 
        self.model.heads[0] = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        
class RARP_NVB_DenseNet169(RARP_NVB_Model):
    """DenseNet-169 backbone with its classifier replaced by a 1-logit layer."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        # CONSISTENCY FIX: accept and forward **kwargs like the sibling backbone
        # wrappers (MobileNetV2 / EfficientNetV2) so base-class options can be
        # passed through; backward-compatible for existing callers.
        super().__init__(InitWeight, TypeLoss, **kwargs)
        
        self.model = torchvision.models.densenet169(weights=torchvision.models.DenseNet169_Weights.DEFAULT)
        inFeatures = self.model.classifier.in_features
        self.model.classifier = torch.nn.Linear(in_features=inFeatures, out_features=1)

class RARP_NVB_RestNet50_old(L.LightningModule):
    """Legacy binary classifier: ImageNet-pretrained ResNet50 with a single-logit
    head, trained with BCE-with-logits (positive class weighted by InitialWeigth)."""
    def __init__(self, InitialWeigth = 1) -> None:
        super().__init__()

        net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        net.fc = torch.nn.Linear(in_features=net.fc.in_features, out_features=1)
        self.model = net

        # Optimizer is built eagerly and handed back in configure_optimizers.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        self.lossFN = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([InitialWeigth]))

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')

    def forward(self, data):
        return self.model(data.float())

    def _run_batch(self, batch):
        # Shared forward + loss for train/val; returns (loss, logits, float labels).
        img, label = batch
        label = label.float()
        logits = self(img)[:, 0]
        return self.lossFN(logits, label), logits, label

    def training_step(self, batch, batch_idx):
        loss, logits, label = self._run_batch(batch)

        self.log("Train Loss", loss)
        self.log("Step Train Acc", self.train_acc(torch.sigmoid(logits), label.int()))

        return loss

    def on_train_epoch_end(self):
        self.log("Train Acc", self.train_acc.compute())

    def validation_step(self, batch, batch_idx):
        loss, logits, label = self._run_batch(batch)

        self.log("Val Loss", loss)
        self.log("Step Val Acc", self.val_acc(torch.sigmoid(logits), label.int()))

        return loss

    def on_validation_epoch_end(self):
        self.log("Val Acc", self.val_acc.compute())

    def configure_optimizers(self):
        return [self.optimizer]