# RARP_server / Models.py
import warnings
warnings.simplefilter("ignore")

import math
from typing import Any, Union
import torch
import torch.utils.checkpoint as torch_ckp
import torchvision
import torchmetrics
import torchmetrics.classification
import lightning as L
from enum import Enum
import timm
import van
import numpy as np
from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt
from noah import NOAH
import piq
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import pandas as pd
from pathlib import Path
from I3D_RestNet50 import I3DResNet50


def js_divergence_sigmoid(p_logits, q_logits):
    """Element-wise Jensen-Shannon-style divergence between two sets of logits.

    Both inputs are squashed with a sigmoid, averaged into a mixture ``m`` and
    each probability is compared against the mixture with binary cross-entropy
    (``m`` is the BCE input, the individual probabilities are the targets).

    NOTE(review): the two terms are cross-entropies H(p, m) and H(q, m) rather
    than KL(p || m) and KL(q || m); the entropies of p and q are not
    subtracted, so the result is non-zero even for identical inputs. Confirm
    this is the intended quantity.

    Args:
        p_logits: Tensor of raw (pre-sigmoid) logits.
        q_logits: Tensor of raw logits with the same shape as ``p_logits``.

    Returns:
        Unreduced tensor (``reduction='none'``) with one value per element.
    """
    bce = torch.nn.functional.binary_cross_entropy

    p_probs = torch.sigmoid(p_logits)
    q_probs = torch.sigmoid(q_logits)
    mixture = 0.5 * (p_probs + q_probs)

    p_term = bce(mixture, p_probs, reduction='none')
    q_term = bce(mixture, q_probs, reduction='none')
    return 0.5 * (p_term + q_term)

def getNVL(obj:dict, key, default):
    """Return ``obj[key]`` unless it is missing or ``None``, else ``default``.

    Falsy values other than ``None`` (``0``, ``""``, ``[]``) are returned as-is.
    """
    value = obj.get(key)
    if value is None:
        return default
    return value

class Decoder(torch.nn.Module):
    """Convolutional decoder that doubles the spatial size once per stage.

    The network starts with a 3x3 conv / BatchNorm / ReLU stem that keeps the
    channel count, then applies one stage per entry of ``hidden_channels`` and
    one final stage to ``output_channels``. A stage is a stride-2 transposed
    convolution followed by a 3x3 convolution, each with BatchNorm and ReLU.

    Args:
        input_channels: Channels of the incoming feature map.
        output_channels: Channels of the produced map.
        num_blocks: Expected number of hidden stages; must equal
            ``len(hidden_channels)``.
        hidden_channels: Output width of every hidden stage, in order.
        *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.

    Raises:
        AssertionError: If ``len(hidden_channels) != num_blocks``.
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=[1024, 512, 256, 64], *args, **kwargs):
        super().__init__(*args, **kwargs)

        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."

        self.Activation = torch.nn.ReLU
        self.input_channels = input_channels

        def refine(channels):
            # 3x3 convolution that keeps resolution and channel count.
            return [
                torch.nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
                torch.nn.BatchNorm2d(channels),
                self.Activation(),
            ]

        def upsample(src, dst):
            # Transposed convolution that doubles height and width.
            return [
                torch.nn.ConvTranspose2d(src, dst, kernel_size=4, stride=2, padding=1, bias=False),
                torch.nn.BatchNorm2d(dst),
                self.Activation(),
            ]

        layers = refine(input_channels)
        previous = input_channels
        # The output stage has the same layout as the hidden ones, so it is
        # simply appended to the list of stage widths.
        for width in (*hidden_channels, output_channels):
            layers += upsample(previous, width)
            layers += refine(width)
            previous = width

        self.decoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential decoder and return the result."""
        return self.decoder(x)
    
class DynamicDecoder(torch.nn.Module):
    """Lightweight decoder made only of stride-2 transposed convolutions.

    Every hidden stage is ConvTranspose2d -> BatchNorm2d -> ReLU and doubles
    the spatial size; a last transposed convolution maps to
    ``output_channels`` without normalisation or activation.

    Args:
        input_channels: Channels of the incoming feature map.
        output_channels: Channels of the produced map.
        num_blocks: Expected number of hidden stages; must equal
            ``len(hidden_channels)``.
        hidden_channels: Output width of every hidden stage, in order.

    Raises:
        AssertionError: If ``len(hidden_channels) != num_blocks``.
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=[1024, 512, 256, 64]):
        super().__init__()

        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."

        def upsample(src, dst):
            return torch.nn.ConvTranspose2d(src, dst, kernel_size=3, stride=2, padding=1, output_padding=1)

        widths = [input_channels, *hidden_channels]
        stages = []
        for src, dst in zip(widths, widths[1:]):
            stages.extend((upsample(src, dst), torch.nn.BatchNorm2d(dst), torch.nn.ReLU(inplace=True)))

        # Output layer: raw values, no BatchNorm / activation applied.
        stages.append(upsample(widths[-1], output_channels))

        self.decoder = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential decoder and return the result."""
        return self.decoder(x)
    
class DecoderUnet(torch.nn.Module):
    """U-Net style decoder that fuses four encoder skip maps with a bottleneck.

    ``forward`` expects a 5-element sequence
    ``(encoder_l0, encoder_l1, encoder_l2, encoder_l3, bottleneck)``. The skip
    maps must carry 64, 128, 320 and 512 channels respectively (fixed by the
    layer widths below); every skip map is itself upsampled x2 before being
    concatenated with the decoder path.

    Args:
        input_channels: Channels of the bottleneck feature map.
        output_channels: Channels of the produced map.
    """

    def __init__(self, input_channels=2048, output_channels=3):
        super().__init__()

        def up2x(src, dst):
            # Transposed convolution that exactly doubles height and width.
            return torch.nn.ConvTranspose2d(src, dst, kernel_size=2, stride=2)

        self.dropout = torch.nn.Dropout2d(0.4)

        # Decoder path: upsample, then fuse with the (upsampled) skip map.
        self.upConv_0 = up2x(input_channels, 512)
        self.decoder_0 = self._conv_block(1024, 512)

        self.upConv_1 = up2x(512, 320)
        self.decoder_1 = self._conv_block(640, 320)

        self.upConv_2 = up2x(320, 128)
        self.decoder_2 = self._conv_block(256, 128)

        self.upConv_3 = up2x(128, 64)
        self.decoder_3 = self._conv_block(128, 64)

        # Last stage has no skip connection.
        self.upConv_4 = up2x(64, 32)
        self.decoder_4 = self._conv_block(32, 16)

        # Upsamplers applied to the encoder skip maps.
        self.enc3_upsample = up2x(512, 512)
        self.enc2_upsample = up2x(320, 320)
        self.enc1_upsample = up2x(128, 128)
        self.enc0_upsample = up2x(64, 64)

        self.last_conv = torch.nn.Conv2d(16, output_channels, kernel_size=1)

    def _conv_block(self, in_ch, out_ch):
        """Return two 3x3 conv + SiLU layers mapping ``in_ch`` to ``out_ch``."""
        layers = [
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        ]
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` (four skip maps plus bottleneck) into the output map."""
        encoder_l0, encoder_l1, encoder_l2, encoder_l3, hidden = x

        stages = (
            (self.upConv_0, self.enc3_upsample, encoder_l3, self.decoder_0),
            (self.upConv_1, self.enc2_upsample, encoder_l2, self.decoder_1),
            (self.upConv_2, self.enc1_upsample, encoder_l1, self.decoder_2),
            (self.upConv_3, self.enc0_upsample, encoder_l0, self.decoder_3),
        )
        final_stage = len(stages) - 1

        for index, (upsample, skip_upsample, skip, fuse) in enumerate(stages):
            merged = torch.cat((upsample(hidden), skip_upsample(skip)), dim=1)
            hidden = fuse(merged)
            # Dropout follows every fused stage except the last one.
            if index != final_stage:
                hidden = self.dropout(hidden)

        hidden = self.decoder_4(self.upConv_4(hidden))
        return self.last_conv(hidden)

class ModelsList(Enum):
    """Integer identifiers for backbone architectures.

    NOTE(review): the member names suggest torchvision-style backbones; the
    enum is not referenced in the visible part of this file, so how the values
    are consumed should be confirmed against the callers.
    """
    # NOTE(review): "RestNet50" / "RestNet50_Droput" look like typos for
    # "ResNet50" / "ResNet50_Dropout"; renaming would break existing callers.
    RestNet50 = 1
    DenseNet169 = 2
    Efficientnet_b0 = 3
    MobileNetV2 = 4
    Inception3 = 5
    ResNeXt_50_32x4d = 6
    RestNet50_Droput = 7

class TypeLossFunction(Enum):
    """Selector for the loss function a model should be built with.

    The models in this file compare against ``CrossEntropy`` and fall back to
    a binary cross-entropy variant for every other member.
    """
    CrossEntropy = 0
    BCEWithLogits = 1
    HingeLoss = 2
    FocalLoss = 3
    ContrastiveLoss = 4
       
class FeatureAlignmentLoss(torch.nn.Module):
    """Cosine-distance loss between student and teacher feature vectors."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, F_student, F_teacher):
        """Return the mean of ``1 - cos(F_student, F_teacher)``.

        Both inputs are L2-normalised along their last dimension, so the
        similarity is taken over that dimension and averaged over all others.
        """
        def unit(features):
            return torch.nn.functional.normalize(features, p=2, dim=-1)

        cosine = (unit(F_student) * unit(F_teacher)).sum(dim=-1)
        return (1 - cosine).mean()
    
class CosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Cosine learning-rate schedule preceded by a linear warm-up."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.00001,
        eta_min: float = 0.00001,
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer (torch.optim.Optimizer):
                Optimizer instance whose learning rate is scheduled.
            warmup_epochs (int):
                Number of epochs of linear warm-up.
            max_epochs (int):
                Number of training epochs at which the cosine curve ends.
            warmup_start_lr (float):
                Learning rate at epoch 0 of the linear warm-up.
            eta_min (float):
                Lower bound of the cosine curve.
            last_epoch (int):
                Phase offset of the cosine curve.
        The learning rate follows a cosine curve until ``max_epochs``;
        from epoch 0 to ``warmup_epochs`` it is warmed up linearly.
        https://pytorch-lightning-bolts.readthedocs.io/en/stable/schedulers/warmup_cosine_annealing.html
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        # ``verbose`` is deprecated on the base scheduler and ``False`` is its
        # default, so it is no longer passed explicitly.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``self.last_epoch``.

        The values are computed recursively from the current ``group["lr"]``,
        so this is meant to be called by ``step()`` once per epoch.
        """
        cosine_span = self.max_epochs - self.warmup_epochs

        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)

        if self.last_epoch < self.warmup_epochs:
            # Linear warm-up: add one equal increment per epoch.
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs

        if (self.last_epoch - 1 - self.max_epochs) % (2 * cosine_span) == 0:
            # The previous epoch sat at the bottom of the cosine, where the
            # recursive ratio below would divide by zero; restart upwards.
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / cosine_span)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        # Recursive cosine step: scale the distance to eta_min by the ratio of
        # the cosine factors of this epoch and the previous one.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / cosine_span))
            / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / cosine_span))
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

class ContrastiveLoss(torch.nn.Module):
    """Margin-based contrastive loss on the Euclidean distance of two embeddings.

    With this formulation ``label == 0`` penalises the squared distance (pairs
    are pulled together) and ``label == 1`` penalises pairs that are closer
    than ``margin`` (pairs are pushed apart).

    Args:
        margin: Minimum distance enforced for ``label == 1`` pairs.
        distance: Stored as ``TypeDistance``; not used by ``forward``.
    """

    def __init__(self, margin=1.0, distance:int = 0):
        super().__init__()
        self.margin = margin
        self.TypeDistance = distance

    def forward(self, output1, output2, label):
        """Return the batch-mean contrastive loss for the given pairs."""
        # Euclidean distance between the two outputs.
        gap = torch.nn.functional.pairwise_distance(output1, output2)

        # Contrastive loss: attraction term plus hinge-style repulsion term.
        attract = (1 - label) * gap.pow(2)
        repel = (label) * (self.margin - gap).clamp(min=0.0).pow(2)

        return (attract + repel).mean()

class GradCAM:
    """Grad-CAM helper for a model that outputs one logit per sample.

    A forward hook on ``target_layer`` stores the layer output in
    ``self.activations``; a tensor hook on that output stores the gradient of
    the loss w.r.t. it in ``self.gradients`` during ``backward()``.

    Args:
        model: Module whose flattened output is fed to binary cross-entropy.
        target_layer: Sub-module of ``model`` whose output is a 4-D tensor
            (the gradient is averaged over dims 0, 2 and 3).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._handles = []
        self.hook()

    def hook(self):
        """Register the capture hook on ``target_layer``."""
        def capture_gradient(grad):
            self.gradients = grad

        def forward_hook(module, Input, Output):
            self.activations = Output
            # A tensor hook replaces the deprecated Module.register_backward_hook
            # and receives the same gradient w.r.t. the layer output.
            if isinstance(Output, torch.Tensor) and Output.requires_grad:
                Output.register_hook(capture_gradient)

        self._handles.append(self.target_layer.register_forward_hook(forward_hook))

    def remove_hooks(self):
        """Detach every hook this instance registered on ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def generate_cam(self, data:torch.Tensor, target_class):
        """Compute the class activation map for ``data``.

        Args:
            data: Model input; ``requires_grad`` is switched on in place.
            target_class: Float target with the same shape as the flattened
                model output, passed to ``binary_cross_entropy_with_logits``.

        Returns:
            NumPy array with the channel-averaged, ReLU-ed map rescaled to
            [0, 1] (size-1 dimensions squeezed out). A map with no dynamic
            range is returned as all zeros instead of NaN.
        """
        self.model.zero_grad()
        data.requires_grad = True
        output = self.model(data).flatten()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(output, target_class)
        loss.backward()

        # One importance weight per channel, broadcast over batch and space.
        pooled_grad = torch.mean(self.gradients, dim=[0, 2, 3])
        # Work on a detached copy so the captured activations are not mutated.
        weighted = self.activations.detach() * pooled_grad.view(1, -1, 1, 1)

        cam = torch.mean(weighted, dim=1).squeeze()
        cam = np.maximum(cam.cpu().numpy(), 0)

        value_range = cam.max() - cam.min()
        if value_range == 0:
            # Constant map: normalising would divide by zero.
            return np.zeros_like(cam)
        return (cam - cam.min()) / value_range

class UNet_RN18(torch.nn.Module):
    """U-Net style segmentation network on top of a torchvision ResNet-18.

    Forward hooks on ``conv1`` and ``layer1`` .. ``layer4`` of the encoder
    collect the skip feature maps; the decoder upsamples the deepest map and
    concatenates it with the shallower ones stage by stage.

    Args:
        in_channels: Accepted for interface compatibility; the ResNet-18
            encoder is built with its default input stem regardless.
        out_channels: Channels of the produced map (raw, no activation).
        *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.hooks = []
        self.feature_maps = []

        backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        backbone.fc = torch.nn.Identity()
        self.encoder_base = backbone

        # Encoder stages whose outputs are used as skip connections / bottleneck.
        self.list_blocks = [
            backbone.conv1,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        ]
        self._register_encoder_hooks()

        self.dropout = torch.nn.Dropout2d(0.4)

        def up2x(src, dst):
            # Transposed convolution that exactly doubles height and width.
            return torch.nn.ConvTranspose2d(src, dst, kernel_size=2, stride=2)

        self.upConv_0 = up2x(512, 256)
        self.upConv_1 = up2x(256, 128)
        self.upConv_2 = up2x(128, 64)
        self.upConv_3 = up2x(64, 32)

        self.upConv_extra = up2x(32, 16)

        # Input widths are "upsampled channels + skip channels".
        self.decoder_0 = self._conv_block(512, 256)
        self.decoder_1 = self._conv_block(256, 128)
        self.decoder_2 = self._conv_block(128, 64)
        self.decoder_3 = self._conv_block(32 + 64, 32)

        self.decoder_extra = self._conv_block(16, 16)

        self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1)

    def _conv_block(self, in_ch, out_ch):
        """Return two 3x3 conv + SiLU layers mapping ``in_ch`` to ``out_ch``."""
        layers = [
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        ]
        return torch.nn.Sequential(*layers)

    def _hook_fn(self, module, input, output):
        """Forward hook: remember the output of a hooked encoder stage."""
        self.feature_maps.append(output)

    def _register_encoder_hooks(self):
        """Attach ``_hook_fn`` to every stage in ``list_blocks``; keep the handles."""
        for stage in self.list_blocks:
            handle = stage.register_forward_hook(self._hook_fn)
            self.hooks.append(handle)

    def forward(self, x):
        """Encode ``x`` with the ResNet-18 and decode the hooked feature maps."""
        self.feature_maps = []

        # The encoder output itself is discarded; only the hooks matter.
        _ = self.encoder_base(x)

        encoder_l0, encoder_l1, encoder_l2, encoder_l3, hidden = self.feature_maps

        stages = (
            (self.upConv_0, encoder_l3, self.decoder_0),
            (self.upConv_1, encoder_l2, self.decoder_1),
            (self.upConv_2, encoder_l1, self.decoder_2),
            (self.upConv_3, encoder_l0, self.decoder_3),
        )
        final_stage = len(stages) - 1

        for index, (upsample, skip, fuse) in enumerate(stages):
            hidden = fuse(torch.cat((upsample(hidden), skip), dim=1))
            # Dropout follows every fused stage except the last one.
            if index != final_stage:
                hidden = self.dropout(hidden)

        hidden = self.decoder_extra(self.upConv_extra(hidden))
        return self.last_conv(hidden)
    
class RARP_NVB_ROI_Mask_Unet(L.LightningModule):
    """Lightning wrapper that trains ``UNet_RN18`` for binary mask prediction.

    Uses ``BCEWithLogitsLoss`` on the raw network output and tracks the binary
    Jaccard index (IoU) on the sigmoid probabilities. Setting ``Lambda_L1`` to
    a number adds an L1 penalty on decoder / upConv weights during training.
    """

    def __init__(self,*args, **kwargs):
        super().__init__(*args, **kwargs)

        self.model = UNet_RN18(in_channels=3, out_channels=1)

        self.lr = 1E-4
        self.Lambda_L1 = None
        self.lossFN = torch.nn.BCEWithLogitsLoss()

        self.train_IoU = torchmetrics.classification.BinaryJaccardIndex()
        self.val_IoU = torchmetrics.classification.BinaryJaccardIndex()

    def forward(self, data):
        """Cast ``data`` to float and return the raw mask logits."""
        data = data.float()
        pred = self.model(data)
        return pred

    def _shared_step(self, batch, val_step:bool = True):
        """Compute loss and probabilities for one ``(img, mask)`` batch.

        Args:
            batch: Pair ``(img, mask)``; a channel dimension is inserted into
                ``mask`` at position 1 before it is compared to the logits.
            val_step: When False (training) the optional L1 penalty is added.

        Returns:
            Tuple ``(loss, mask, predicted_probabilities)``.
        """
        img, mask = batch

        mask = mask.float()
        mask = mask.unsqueeze(1)
        prediction = self(img)

        loss = self.lossFN(prediction, mask)

        predicted_labels = torch.sigmoid(prediction)

        if not val_step and self.Lambda_L1 is not None:
            loss_l1 = 0
            # Only the decoder side is regularised; the encoder is left alone.
            for name, params in self.model.named_parameters():
                if "decoder" in name or "upConv" in name:
                    loss_l1 += torch.norm(params, p=1)

            loss += self.Lambda_L1 * loss_l1

        return loss, mask, predicted_labels

    def training_step(self, batch, batch_idx):
        """Run one training step, log loss and IoU, and return the loss."""
        loss, true_labels, predicted_labels = self._shared_step(batch, False)

        self.train_IoU.update(predicted_labels, true_labels)

        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc_IoU", self.train_IoU, on_epoch=True, on_step=False)

        return loss

    def on_after_backward(self):
        """Log the global L2 gradient norm and flag suspiciously small values."""
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                total_norm += p.grad.detach().norm(2).item() ** 2
        total_norm = total_norm ** 0.5

        self.log("grad_norm", total_norm)

        vanishing = total_norm < 1e-8
        # self.log() only accepts numeric values; passing the message string
        # itself raises, so a 0/1 flag is logged and the text goes to warnings.
        self.log("grad_warning", float(vanishing))
        if vanishing:
            warnings.warn("Vanishing gradient suspected!", RuntimeWarning)

    def on_train_epoch_start(self):
        """Train the encoder only on even epochs; freeze it on odd ones."""
        encoder_trainable = (self.current_epoch % 2 == 0)
        for parms in self.model.encoder_base.parameters():
            parms.requires_grad = encoder_trainable

    def validation_step(self, batch, batch_idx):
        """Run one validation step and log loss and IoU."""
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.val_IoU.update(predicted_labels, true_labels)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_acc_IoU", self.val_IoU, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Return a single Adam optimizer over the U-Net parameters."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        return [optimizer]

class RARP_NVB_ResNet50_CAM(L.LightningModule):
    """ResNet-50 with a single-logit head that also returns its feature map."""

    def __init__(self) -> None:
        super().__init__()

        backbone = torchvision.models.resnet50()
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=1)
        self.model = backbone

        # All child modules except the last two (pooling and fc head). The
        # modules are shared with ``self.model``, not copied.
        self.feature_map = torch.nn.Sequential(*list(backbone.children())[:-2])

    def forward(self, data):
        """Return ``(logit, feature_map)`` for ``data``."""
        featureMap = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1))
        pred = self.model.fc(pooled.flatten(1))

        return pred, featureMap

class RARP_NVB_VAN_CAM(L.LightningModule): #TODO (unfinished, see review notes below)
    """VAN-B2 backbone with a single-logit head, intended to expose Grad-CAM maps.

    NOTE(review): this class is marked TODO and ``forward`` does not look
    runnable as written; see the comments in ``forward``.
    """
    def __init__(self) -> None:
        super().__init__()
        
        self.model = van.van_b2(pretrained = True)
        # Swap the classification head for a one-output linear layer.
        tempFC_ft = self.model.head.in_features
        self.model.head = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        
        # All child modules of the backbone except the last two.
        # NOTE(review): built here but never used by forward().
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])
        
        
    def forward(self, data, label:torch.Tensor):
        """Return ``(pred, featureMap)`` where ``featureMap`` is the Grad-CAM map.

        Args:
            data: Model input, forwarded to ``GradCAM.generate_cam``.
            label: Target passed to ``GradCAM.generate_cam`` for the loss.
        """
        
        # NOTE(review): a new GradCAM is created on every call, and each one
        # registers additional hooks on the same layer that are never removed.
        cams = GradCAM(self.model, self.model.block4[-1])
        featureMap = cams.generate_cam(data, label)
              
        # NOTE(review): generate_cam returns a NumPy array, which is passed to
        # the model here as if it were an input tensor -- confirm the intent.
        pred = self.model(featureMap)
        
        
        return pred, featureMap
    
class RARP_NVB_ResNet18_CAM(L.LightningModule):
    """ResNet-18 with a single-logit head that also returns its feature map."""

    def __init__(self) -> None:
        super().__init__()

        self.model = torchvision.models.resnet18()
        tempFC_ft = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)

        # All child modules except the last two (pooling and fc head). The
        # modules are shared with ``self.model``, not copied.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        """Return ``(logit, feature_map)`` for ``data``.

        The pooled features are flattened from dimension 1 so the batch
        dimension is preserved; flattening everything would merge the batch
        into the feature axis and break the linear head for batch sizes > 1.
        """
        featureMap = self.feature_map(data)
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1))
        Cont_Net = torch.flatten(Cont_Net, 1)

        pred = self.model.fc(Cont_Net)

        return pred, featureMap
    
class RARP_NVB_MobileNetV2_CAM(L.LightningModule):
    """MobileNetV2 with a single-logit head that also returns its feature map."""

    def __init__(self) -> None:
        super().__init__()

        backbone = torchvision.models.mobilenet_v2()
        # Replace the linear layer at classifier index 1 with a one-output layer.
        head_width = backbone.classifier[1].in_features
        backbone.classifier[1] = torch.nn.Linear(in_features=head_width, out_features=1)
        self.model = backbone

        # Same module object as ``self.model.features`` (shared, not copied).
        self.feature_map = backbone.features

    def forward(self, data):
        """Return ``(logit, feature_map)`` for ``data``."""
        featureMap = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(featureMap, (1, 1)).flatten(1)

        pred = self.model.classifier(pooled)

        return pred, featureMap
    
class RARP_NVB_EfficientNetV2_CAM(L.LightningModule):
    """Pretrained EfficientNetV2-S with a single-logit head and feature-map output."""

    def __init__(self) -> None:
        super().__init__()

        backbone = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        # Replace the linear layer at classifier index 1 with a one-output layer.
        head_width = backbone.classifier[1].in_features
        backbone.classifier[1] = torch.nn.Linear(in_features=head_width, out_features=1)
        self.model = backbone

        # Same module object as ``self.model.features`` (shared, not copied).
        self.feature_map = backbone.features

    def forward(self, data):
        """Return ``(logit, feature_map)`` for ``data``."""
        featureMap = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(featureMap, (1, 1)).flatten(1)

        pred = self.model.classifier(pooled)

        return pred, featureMap

class RARP_NVB_Model_BCEWithLogitsLoss(L.LightningModule):
    """Binary-classification Lightning module trained with ``BCEWithLogitsLoss``.

    ``self.model`` is left as ``None`` here and has to be assigned before the
    module is used. ``x`` is accepted but not used.
    """

    def __init__(self, x=None) -> None:
        super().__init__()

        self.model = None

        self.lossFN = torch.nn.BCEWithLogitsLoss()

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        """Cast ``data`` to float and return the model output."""
        return self.model(data.float())

    def _shared_step(self, batch):
        """Return ``(loss, float_labels, probabilities)`` for an ``(img, label)`` batch."""
        img, label = batch
        target = label.float()
        logits = self.forward(img).flatten()
        loss = self.lossFN(logits, target)

        return loss, target, torch.sigmoid(logits)

    def training_step(self, batch, batch_idx):
        """Log training loss / accuracy and return the loss."""
        loss, targets, probabilities = self._shared_step(batch)

        self.log("train_loss", loss)
        self.train_acc.update(probabilities, targets)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss, accuracy and F1, and return the loss."""
        loss, targets, probabilities = self._shared_step(batch)

        self.log("val_loss", loss)
        self.val_acc(probabilities, targets)
        self.f1Score(probabilities, targets)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Log test accuracy and F1."""
        _, targets, probabilities = self._shared_step(batch)

        self.test_acc(probabilities, targets)
        self.f1ScoreTest(probabilities, targets)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Return a single Adam optimizer (lr 1e-4) over the model parameters."""
        return [torch.optim.Adam(self.model.parameters(), lr=1e-4)]

class RARP_NVB_FOCAL_loss(torch.nn.Module):
    """Module wrapper around ``torchvision.ops.focal_loss.sigmoid_focal_loss``.

    Args:
        alpha: Forwarded as the ``alpha`` argument.
        gamma: Forwarded as the ``gamma`` argument.
        reduction: Forwarded as the ``reduction`` argument.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2, reduction: str = "mean"):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid focal loss of ``input`` logits against ``target``."""
        return torchvision.ops.focal_loss.sigmoid_focal_loss(
            input,
            target,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )

class FocalLoss(torch.nn.Module):
    """Multi-class focal loss computed from raw logits and integer targets.

    Args:
        alpha: Constant factor applied to every loss term.
        gamma: Exponent of the ``(1 - p)`` modulating factor.
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.

    NOTE(review): ``'mean'`` averages over the full (batch, classes) matrix,
    zeros of the non-target classes included, so the value is the per-sample
    mean divided by the number of classes. Confirm this scaling is intended.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (N, C logits) vs ``targets`` (N indices)."""
        functional = torch.nn.functional

        # One-hot mask that selects the target class of every sample.
        target_mask = functional.one_hot(targets, num_classes=inputs.size(1)).float()

        probabilities = functional.softmax(inputs, dim=1)
        log_probabilities = functional.log_softmax(inputs, dim=1)

        # Down-weight well-classified samples.
        modulation = (1 - probabilities) ** self.gamma
        loss = -self.alpha * modulation * target_mask * log_probabilities

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
#TODO
class RARP_NVB_MultiClassModel(L.LightningModule):
    """Multi-class classification Lightning module trained with ``FocalLoss``.

    Args:
        InitWeight: Stored on the instance; not used by the visible methods.
        schedulerLR: When True a cosine-annealing LR scheduler is attached.
        lr: Learning rate of the Adam optimizer.
        Model: Network that maps a float batch to class logits.
        Num_Classes: Number of classes for the accuracy / F1 metrics.
        L1: Weight of the L1 penalty on ``Model`` parameters (training only);
            ``None`` disables it.
        L2: Passed to Adam as ``weight_decay``.
    """

    def __init__(self, 
                 InitWeight = torch.tensor([1,1]),  
                 schedulerLR:bool=False, 
                 lr:float = 1e-4,
                 Model:torch.nn.Module = None,
                 Num_Classes:int = 2,
                 L1:float = 1.31E-04,
                 L2:float = 0
        ) -> None:
        super().__init__()

        self.model = Model
        self.lossFN = FocalLoss()
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.num_classes = Num_Classes

        self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.test_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.f1ScoreTest = torchmetrics.F1Score("multiclass", num_classes=Num_Classes)

        # Set by the step methods; True suppresses the L1 penalty.
        self.val_loop = False

    def forward(self, data):
        """Cast ``data`` to float and return the class logits."""
        return self.model(data.float())

    def _shared_step(self, batch):
        """Return ``(loss, labels, softmax_probabilities)`` for an ``(img, label)`` batch."""
        img, label = batch
        logits = self(img)
        probabilities = torch.softmax(logits, dim=1)
        loss = self.lossFN(logits, label)

        penalise = self.Lambda_L1 is not None and not self.val_loop
        if penalise:
            # L1 norm of all model parameters, accumulated in iteration order.
            l1_norm = sum(torch.sum(torch.abs(weights)) for weights in self.model.parameters())
            loss += self.Lambda_L1 * l1_norm

        return loss, label, probabilities

    def training_step(self, batch, batch_idx):
        """Log training loss / accuracy and return the (L1-penalised) loss."""
        self.val_loop = False
        loss, targets, probabilities = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(probabilities, targets)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy (no L1 penalty)."""
        self.val_loop = True
        loss, targets, probabilities = self._shared_step(batch)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(probabilities, targets)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test accuracy and F1 (no L1 penalty)."""
        self.val_loop = True
        _, targets, probabilities = self._shared_step(batch)

        self.test_acc.update(probabilities, targets)
        self.f1ScoreTest.update(probabilities, targets)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Return Adam, optionally paired with a cosine-annealing scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if not self.scheduler:
            return [optimizer]

        # NOTE(review): ``verbose`` is deprecated in recent PyTorch releases;
        # confirm the pinned version still accepts it.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
            }
        }

class RARP_NVB_Model(L.LightningModule):
    """Base Lightning module for binary classification on a single logit.

    ``self.model`` is left as ``None`` here and has to be assigned before the
    module is used.

    Args:
        InitWeight: Per-class weights; for every batch the loss weight is set
            to ``InitWeight[label]``. ``None`` disables the re-weighting.
        typeLossFN: ``TypeLossFunction.CrossEntropy`` selects
            ``CrossEntropyLoss(label_smoothing=0.5)``; any other value selects
            ``BCEWithLogitsLoss``.
        schedulerLR: When True a cosine-annealing LR scheduler is attached.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, 
                 InitWeight = torch.tensor([1,1]), 
                 typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy, 
                 schedulerLR:bool=True, 
                 lr:float = 1e-4
                ) -> None:
        super().__init__()

        self.model = None
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = None  # set to a float to enable the L1 penalty
        self.Lambda_L2 = 0

        print (f"LR= {self.lr}, L1= {self.Lambda_L1}")

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        """Cast ``data`` to float and return the model output."""
        data = data.float()
        pred = self.model(data)
        return pred

    def _shared_step(self, batch):
        """Return ``(loss, float_labels, probabilities)`` for an ``(img, label)`` batch.

        Only column 0 of the model output is used as the logit.
        """
        img, label = batch
        if self.InitWeight is not None:
            # InitWeight is a plain attribute (not a buffer), so it does not
            # follow the module to the GPU; move it to the label's device
            # before indexing to avoid a CPU/CUDA index mismatch.
            self.lossFN.weight = self.InitWeight.to(label.device)[label]

        label = label.float()
        prediction = self(img)[:,0]
        loss = self.lossFN(prediction, label)

        if self.Lambda_L1 is not None:
            loss_l1 = 0
            for params in self.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1

        predicted_labels = torch.sigmoid(prediction)

        return loss, label, predicted_labels

    def training_step(self, batch, batch_idx):
        """Log training loss / accuracy and return the loss."""
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy (also exported as ``hp_metric``)."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("hp_metric", self.val_acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Log test accuracy and F1."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Return Adam, optionally paired with a cosine-annealing scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            # NOTE(review): ``verbose`` is deprecated in recent PyTorch
            # releases; confirm the pinned version still accepts it.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                }
            }
        else:
            return optimizer
    
class RARP_Ensemble(L.LightningModule):
    """Learned combiner that maps the outputs of several models to one probability.

    Every model in ``ListModels`` receives the same input; their outputs are
    concatenated along dim 1 and fed to a small MLP that ends in a sigmoid.
    The first linear layer has ``len(ListModels)`` input features, so each
    model is presumably expected to emit one column per sample.

    NOTE(review): ``ListModels`` is stored as given. If it is a plain Python
    list rather than a ``torch.nn.ModuleList``, PyTorch does not register its
    members as sub-modules, so they are neither returned by
    ``self.parameters()`` nor moved by ``.to()``. Confirm this is intended.
    """

    def __init__(self, ListModels,  
                 InitWeight = torch.tensor([1,1]), 
                 typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy, 
                 schedulerLR:bool=False, 
                 lr:float = 1e-4) -> None:
        """Create the combiner head, the loss and the metrics.

        Args:
            ListModels: Collection of callables applied to the input in ``forward``.
            InitWeight: Tensor indexed by the integer label to obtain the loss
                weight of each sample, or ``None`` to leave the loss unweighted.
            typeLossFN: ``TypeLossFunction.CrossEntropy`` selects
                ``CrossEntropyLoss(label_smoothing=0.5)``; any other value
                selects ``BCELoss(reduction="sum")``.
            schedulerLR: Enable the cosine-annealing learning-rate schedule.
            lr: Learning rate for Adam.
        """
        super().__init__()
        
        self.ListModels = ListModels
        #for m in self.ListModels:
        #    m.freeze()
        input_p = len(self.ListModels)
        
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_p, out_features=128),
            torch.nn.SiLU(),
            #torch.nn.Dropout(0.1), #0.5
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid()
        )
        
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCELoss(reduction="sum")
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        
    def forward(self, data):
        """Return the combiner output for ``data`` (cast to float first)."""
        data = data.float()
        p = [m(data) for m in self.ListModels]
        p = torch.cat(p, dim=1)
        x = self.classifier(p)
        return x
    
    def _shared_step(self, batch):
        """Compute loss, float labels and predictions for one ``(img, label)`` batch.

        Returns:
            Tuple ``(loss, label, predicted_labels)`` where ``predicted_labels``
            is column 0 of the sigmoid output.
        """
        img, label = batch
        prediction = self(img)[:,0] #.flatten()

        if self.InitWeight is not None:
            # InitWeight is a plain attribute (not a buffer), so it does not
            # follow the module across devices; move it next to the labels and
            # give it the prediction dtype before looking up per-sample weights.
            class_weights = self.InitWeight.to(device=label.device, dtype=prediction.dtype)
            self.lossFN.weight = class_weights[label.long()]
        
        label = label.float()
        loss = self.lossFN(prediction, label)
        predicted_labels = prediction
        
        return loss, label, predicted_labels
    
    def training_step(self, batch, batch_idx):
        """Run one training step; logs ``train_loss`` and ``train_acc``."""
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss
    
    def validation_step(self, batch, batch_idx):
        """Run one validation step; logs ``val_loss``, ``val_acc`` and ``val_f1``."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.f1Score.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Run one test step; logs ``test_acc`` and ``test_f1``."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Return Adam, optionally paired with a cosine-annealing schedule."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 
        if self.scheduler:
            # `verbose` is deliberately not passed: it is deprecated since
            # PyTorch 2.2 (read the rate with `get_last_lr()` instead).
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=1e-8)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return [optimizer]
    
class RARP_NVB_ResNet18(RARP_NVB_Model):
    """RARP_NVB_Model whose network is a torchvision ResNet-18 with a one-output head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # Start from torchvision's default pretrained weights.
        pretrained = torchvision.models.ResNet18_Weights.DEFAULT
        self.model = torchvision.models.resnet18(weights=pretrained)

        # Replace the final layer with a single-output linear layer of the same input width.
        head_width = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=head_width, out_features=1)

class RARP_NVB_DaVit(RARP_NVB_Model):
    """RARP_NVB_Model whose network is timm's ``davit_small.msft_in1k`` with one output."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # timm builds the classification head itself when num_classes is given.
        self.model = timm.create_model(
            "davit_small.msft_in1k",
            num_classes=1,
            pretrained=True,
        )

class RARP_NVB_EfficientNetV2_Deep(RARP_NVB_Model):
    """RARP_NVB_Model on EfficientNetV2-S with a deeper head (features -> 128 -> 8 -> 1)."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        pretrained = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
        self.model = torchvision.models.efficientnet_v2_s(weights=pretrained)

        # Entry 1 of the stock classifier is its linear layer; narrow it to 128 outputs.
        head = self.model.classifier
        head[1] = torch.nn.Linear(in_features=head[1].in_features, out_features=128)

        # Extend the classifier in place with the remaining layers.
        tail = (
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        )
        for layer in tail:
            head.append(layer)

class RARP_NVB_ResNet50_Deep(RARP_NVB_Model):
    """RARP_NVB_Model on ResNet-50 with a dropout + three-layer head (features -> 128 -> 8 -> 1)."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        pretrained = torchvision.models.ResNet50_Weights.DEFAULT
        self.model = torchvision.models.resnet50(weights=pretrained)

        # Replace the stock fully connected layer with the deeper head.
        feature_width = self.model.fc.in_features
        head_layers = [
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=feature_width, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        ]
        self.model.fc = torch.nn.Sequential(*head_layers)

    def forward(self, img):
        """Return the network output for ``img`` after casting it to float."""
        return self.model(img.float())
    
class RARP_NVB_ResNet50_Deep_OPTuna(RARP_NVB_Model):
    """ResNet-50 variant whose head widths and dropout rate are constructor arguments.

    For every width in ``output_dims`` the head gets Linear -> SiLU -> Dropout,
    followed by a final one-output linear layer.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, output_dims=[128, 8], dropuot=0.2, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)

        # widths[i] -> widths[i + 1] describes the i-th hidden linear layer.
        widths = [self.model.fc.in_features, *output_dims]

        head = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            head.extend((
                torch.nn.Linear(fan_in, fan_out),
                torch.nn.SiLU(True),
                torch.nn.Dropout(dropuot),
            ))
        head.append(torch.nn.Linear(widths[-1], 1))

        self.model.fc = torch.nn.Sequential(*head)

    def forward(self, img):
        """Return the network output for ``img`` after casting it to float."""
        return self.model(img.float())

class RARP_NVB_ResNet50_V2(RARP_NVB_Model):
    """ResNet-50 that reduces the image to 8 features and concatenates extra inputs.

    The image branch ends in an 8-wide layer; after SiLU its output is joined
    with ``InputNeurons`` extra values and mapped to one output by ``fc2``.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        self.model = backbone

        feature_width = backbone.fc.in_features
        backbone.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=feature_width, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
        )
        # Extra layer stored on the backbone; it is only used explicitly in forward().
        backbone.fc2 = torch.nn.Linear(8 + InputNeurons, 1)

    def forward(self, data):
        """Return the prediction for ``data = (img, extra)``."""
        img, extra = data
        image_features = torch.nn.functional.silu(self.model(img.float()), True)
        fused = torch.concat((image_features, extra), dim=1)
        return self.model.fc2(fused)
    
class RARP_NVB_ResNet50_V3(RARP_NVB_Model):
    """ResNet-50 with a parallel branch for extra inputs.

    Image and extra inputs are each projected to 128 features, passed through
    SiLU, concatenated (256 features) and mapped to one output by ``fc2``.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        self.model = backbone

        # Image branch: stock fc replaced by a 128-wide projection.
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=128)
        # Extra-input branch and the fusion layer are stored on the backbone.
        backbone.extraFC = torch.nn.Linear(InputNeurons, 128)
        backbone.fc2 = torch.nn.Linear(256, 1)

    def forward(self, data):
        """Return the prediction for ``data = (img, extra)``."""
        img, extra = data
        silu = torch.nn.functional.silu

        image_branch = silu(self.model(img.float()), True)
        extra_branch = silu(self.model.extraFC(extra), True)

        fused = torch.concat((image_branch, extra_branch), dim=1)
        return self.model.fc2(fused)

class RARP_NVB_ResNet50_V1(RARP_NVB_Model):
    """ResNet-50 whose pooled features are concatenated with extra inputs before the final layer.

    ``self.Decoder`` holds every child of the backbone except the last two and
    is used as the feature extractor in ``forward``.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        self.model = backbone

        # The final layer also receives the extra inputs, hence the wider input.
        fused_width = backbone.fc.in_features + InputNeurons
        backbone.fc = torch.nn.Linear(in_features=fused_width, out_features=1)

        # Reuse the backbone children, minus the last two, as the feature extractor.
        trunk = list(backbone.children())[:-2]
        self.Decoder = torch.nn.Sequential(*trunk)

    def forward(self, data):
        """Return the prediction for ``data = (img, extra)``."""
        img, extra = data
        featureMap = self.Decoder(img.float())

        # Global average pooling followed by flattening to (batch, features).
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1))
        pooled = torch.flatten(pooled, 1)

        fused = torch.concat((pooled, extra), dim=1)
        return self.model.fc(fused)
        
class RARP_NVB_VAN(RARP_NVB_Model):
    """RARP_NVB_Model whose network is a pretrained ``van_b2`` with a one-output head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        self.model = van.van_b2(pretrained = True)

        # Replace the head with a single-output linear layer of the same input width.
        head_width = self.model.head.in_features
        self.model.head = torch.nn.Linear(in_features=head_width, out_features=1)
        
class RARP_NVB_ResNet50(RARP_NVB_Model):
    """RARP_NVB_Model whose network is a torchvision ResNet-50 with a one-output head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # Start from torchvision's default pretrained weights.
        pretrained = torchvision.models.ResNet50_Weights.DEFAULT
        self.model = torchvision.models.resnet50(weights=pretrained)

        # Replace the final layer with a single-output linear layer of the same input width.
        head_width = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=head_width, out_features=1)

class RARP_NVB_MLP(torch.nn.Module):
    """Projection head: MLP -> L2 normalisation -> weight-normalised linear layer.

    Args:
        in_dim: Width of the input features.
        out_dim: Width of the output produced by the last layer.
        hidden_dim: Width of the hidden layers.
        bottleneck: Width of the MLP output that gets L2-normalised.
        n_layers: Number of linear layers in the MLP; ``1`` gives a single
            linear layer without activation or dropout.
        norm_last_layer: If true, the magnitude (``weight_g``) of the last
            layer is frozen at 1.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=512, bottleneck=256, n_layers=3, norm_last_layer=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # One shared activation module is reused at every position.
        self.activationFN = torch.nn.GELU()

        if n_layers != 1:
            stack = [
                torch.nn.Linear(in_dim, hidden_dim),
                self.activationFN,
                torch.nn.Dropout(0.30),
            ]
            for _ in range(n_layers - 2):
                stack += [torch.nn.Linear(hidden_dim, hidden_dim), self.activationFN]
            stack.append(torch.nn.Linear(hidden_dim, bottleneck))
            self.mlp = torch.nn.Sequential(*stack)
        else:
            self.mlp = torch.nn.Linear(in_dim, bottleneck)

        # Custom initialisation runs before the last layer exists, so the
        # weight-normalised layer keeps PyTorch's default initialisation.
        self.apply(self._init_weights)

        head = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck, out_dim, bias=False))
        head.weight_g.data.fill_(1)
        if norm_last_layer:
            head.weight_g.requires_grad = False
        self.last_layer = head

    def _init_weights(self, m):
        """Initialise linear layers: weights ~ N(0, 0.02), biases set to 0."""
        if not isinstance(m, torch.nn.Linear):
            return
        torch.nn.init.normal_(m.weight, std=0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project ``x``, L2-normalise along the last dim and apply the last layer."""
        hidden = self.mlp(x)
        unit = torch.nn.functional.normalize(hidden, p=2, dim=-1)
        return self.last_layer(unit)

class RARP_NVB_DINO_Wrapper(torch.nn.Module):
    """Run a backbone and a head over one tensor or a list of crops.

    A list input is concatenated along dim 0, processed in a single pass and
    split back into as many chunks as there were crops.
    """

    def __init__(self, backbone:torch.nn.Module, new_head:torch.nn.Module, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.backbone = backbone
        self.head = new_head

    def forward(self, x):
        """Return a tuple with one output chunk per crop (one chunk for a plain tensor)."""
        if not isinstance(x, list):
            n_crops, stacked = 1, x
        else:
            # Concatenation along dim 0 needs all crops to agree on the other dims.
            n_crops, stacked = len(x), torch.cat(x, dim=0)

        outputs = self.head(self.backbone(stacked))
        return outputs.chunk(n_crops)
    
class RARP_NVB_DINO_Loss(torch.nn.Module):
    """Cross-entropy between centred teacher distributions and student distributions.

    Args:
        out_dim: Width of the outputs; also the width of the ``center`` buffer.
        teacher_Thao: Temperature dividing the centred teacher outputs.
        student_Thao: Temperature dividing the student outputs.
        center_momentum: Weight of the previous centre in its moving average.
    """

    def __init__(self, out_dim:int, teacher_Thao:float = 0.04, student_Thao:float = 0.1, center_momentum:float = 0.9, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.S_Thao = student_Thao
        self.T_Thao = teacher_Thao
        self.C_Momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, s_Output, t_Output):
        """Return the mean cross-entropy over teacher/student pairs and update the centre.

        Pairs with the same index are skipped whenever more than one teacher
        output is given. The teacher probabilities are detached.
        """
        log_softmax = torch.nn.functional.log_softmax
        softmax = torch.nn.functional.softmax

        student_logp = [log_softmax(s / self.S_Thao, dim=-1) for s in s_Output]
        teacher_p = [softmax((t - self.center) / self.T_Thao, dim=-1).detach() for t in t_Output]

        skip_same_index = len(teacher_p) > 1
        terms = [
            torch.sum(-probs * logp, dim=-1).mean()
            for t_ix, probs in enumerate(teacher_p)
            for s_ix, logp in enumerate(student_logp)
            if not (skip_same_index and t_ix == s_ix)
        ]

        total_loss = sum(terms) / len(terms)
        self.update_center(t_Output)

        return total_loss

    @torch.no_grad()
    def update_center(self, t_output):
        """Move ``center`` towards the mean of the concatenated teacher outputs."""
        batch_mean = torch.cat(t_output).mean(dim=0, keepdim=True)
        keep = self.C_Momentum
        self.center = keep * self.center + (1 - keep) * batch_mean
        
class RARP_NVB_DINO_RestNet50_Deep(L.LightningModule):
    """Student/teacher model combining a DINO loss with pseudo-label supervision.

    Components built in ``__init__``:
      * ``teacher_Labels``: a ``RARP_NVB_ResNet50_Deep`` (loaded from the
        ``PseudoEstimator`` checkpoint when given) whose thresholded sigmoid
        output provides pseudo labels.
      * ``student`` / ``teacher_Features``: ResNet-50 backbones wrapped with a
        projection head; the teacher is frozen and updated as a moving
        average of the student after every training batch.
      * ``clasiffier``: small MLP mapping the student projection to one logit.

    NOTE(review): the teacher projection head is initialised independently of
    the student's (no state copy here) -- confirm this is intended.
    """

    def __init__(
        self, 
        PseudoEstimator: str = None, 
        threshold: float = 0.5, 
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
    ) -> None:
        """Build the networks, losses and metrics.

        Args:
            PseudoEstimator: Checkpoint path for the pseudo-label teacher, or
                ``None`` to construct it with default arguments.
            threshold: Cut-off applied to the teacher's sigmoid output.
            TypeLoss: ``TypeLossFunction.CrossEntropy`` selects
                ``CrossEntropyLoss(label_smoothing=0.5)``; any other value
                selects ``BCEWithLogitsLoss``.
            momentum_teacher: Moving-average factor for the feature teacher.
            lr: Learning rate for AdamW.
            L1: Optional factor for an L1 penalty on the student parameters.
            L2: Weight decay passed to AdamW.
        """
        super().__init__()
    
        self.lr =  lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.threshold = threshold
        self.momentum_teacher = momentum_teacher
        self.out_dim = 512
        self.in_dim = 2048

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        
        self.teacher_Labels = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False) if PseudoEstimator is not None else RARP_NVB_ResNet50_Deep()
        self.student = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT) #torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        self.teacher_Features = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        
        # Expose the pooled backbone features by removing the classification layer.
        self.student.fc = torch.nn.Identity()
        self.teacher_Features.fc = torch.nn.Identity()
        
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        
        # Both teachers are excluded from gradient updates.
        for parms in self.teacher_Labels.model.parameters():
            parms.requires_grad = False
            
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        
        # NOTE(review): momentum_teacher is passed as the loss's center_momentum
        # (fourth positional argument) -- confirm the reuse is intended.
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, 0.04, 0.1, momentum_teacher)
        self.lossFN_KD = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(self.out_dim, 128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1)
        )

    def _log_histogram(self, tag, values):
        """Write a histogram through the logger when it supports it; otherwise do nothing."""
        experiment = getattr(self.logger, "experiment", None)
        add_histogram = getattr(experiment, "add_histogram", None)
        if add_histogram is not None:
            add_histogram(tag, values, self.global_step)
        
    def forward(self, data, val_step:bool = True):
        """Run both teachers, the student and the classifier.

        Args:
            data: A single tensor when ``val_step`` is true, otherwise a list
                of crops. Entry 0 feeds the pseudo-label teacher, entries 1-2
                feed the feature teacher and all entries feed the student.
            val_step: Selects between the two input layouts above.

        Returns:
            ``((pred, PseudoLabels, TeacherLabels), (TeacherDino, Student))``
            where the first three entries are flattened tensors and the last
            two are the tuples returned by the wrappers.
        """
        if val_step:
            data = data.float()
            dataClassificator, dataTeacher, dataStudent = data, data, data
        else:
            data = [d.float() for d in data]
            dataClassificator, dataTeacher, dataStudent = data[0], data[1:3], data

        TeacherDino = self.teacher_Features(dataTeacher)
        TeacherLabels = self.teacher_Labels(dataClassificator)
        Student = self.student(dataStudent)
        
        # All student outputs are classified, one block of rows per crop.
        Cont_Net = torch.cat(Student, dim=0)
                
        pred = self.clasiffier(Cont_Net)
        
        if not val_step:
            # One teacher pass per student crop so the pseudo labels line up
            # with the concatenated student predictions.
            # NOTE(review): the passes are kept separate because the teacher
            # contains dropout; they are identical only in eval mode.
            TeacherLabels = [self.teacher_Labels(dataClassificator) for _ in range(len(dataStudent))]
            TeacherLabels = torch.cat(TeacherLabels, dim=0)
        
        TeacherLabelsPred = torch.sigmoid(TeacherLabels.flatten())
        PseudoLabels = (TeacherLabelsPred > self.threshold) * 1.0
                        
        return (pred.flatten(), PseudoLabels, TeacherLabels.flatten()), (TeacherDino, Student)
        
    def _shared_step(self, batch, val_step:bool = False):
        """Compute the total loss, the float labels and the sigmoid predictions.

        The loss is ``1.0 * KD(pred, pseudo labels) + 0.5 * KD(pred, labels)``,
        plus the DINO loss during training, plus an optional L1 penalty on the
        student parameters.
        """
        img, label = batch
        
        if not val_step:
            # Repeat the labels once per crop to match the stacked predictions.
            label = torch.cat([label for _ in range(len(img))], dim=0)
        
        label = label.float()
        KD_Prediction, DINO_Loss = self(img, val_step)
        TeacherF, StudentF = DINO_Loss
        prediction, PseudoLabels, teacherOutputs = KD_Prediction
        
        predicted_labels = torch.sigmoid(prediction)
        
        # Weights of the pseudo-label term and the ground-truth term.
        W_Alpha, W_Beta = (1, 0.5)
        loss = W_Alpha * self.lossFN_KD(prediction, PseudoLabels) + W_Beta * self.lossFN_KD(prediction, label)
        
        loss += (self.lossFN_DINO(StudentF, TeacherF) if not val_step else 0)
        
        if not val_step:
            self._log_histogram("Teacher", TeacherF[0])
            self._log_histogram("Student", StudentF[1])
                
        if self.Lambda_L1 is not None:
            loss_l1 = 0
            for params in self.student.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        
        return loss, label, predicted_labels
        
    def training_step(self, batch, batch_idx):
        """Run one training step; logs ``train_loss`` and ``train_acc``."""
        loss, true_labels, predicted_labels = self._shared_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss
    
    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Update the feature teacher as a moving average of the student."""
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
            
            self._log_histogram("Teacher_Center", self.lossFN_DINO.center)
            
                
    def validation_step(self, batch, batch_idx):
        """Run one validation step; logs ``val_loss`` and ``val_acc``."""
        loss, true_labels, predicted_labels = self._shared_step(batch, True)
        
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        
    def test_step(self, batch, batch_idx):
        """Run one test step; logs ``test_acc`` and ``test_f1``."""
        _, true_labels, predicted_labels = self._shared_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    
    def configure_optimizers(self):
        """Return AdamW over the student parameters.

        NOTE(review): ``self.clasiffier`` is not part of the optimised
        parameters, so the classifier keeps its initial weights -- confirm
        this is intended.
        """
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr, weight_decay=self.Lambda_L2) 
        
        return [optimizer]

class RARP_NVB_DINO_VAN(RARP_NVB_DINO_RestNet50_Deep):
    """Variant of the parent model that swaps both backbones for pretrained ``van_b2``.

    The parent constructor runs first; its student and feature teacher are
    then replaced by VAN backbones with fresh 512-input projection heads.
    """

    def __init__(
        self, 
        PseudoEstimator: str = None, 
        threshold: float = 0.5, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        momentum_teacher: float = 0.9995, 
        lr: float = 0.0001, 
        L1: float = None, 
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)

        # Input width used for the projection heads below.
        self.in_dim = 512

        student_backbone = van.van_b2(pretrained = True, num_classes = -1)
        teacher_backbone = van.van_b2(pretrained = True, num_classes = -1)

        student_head = RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        self.student = RARP_NVB_DINO_Wrapper(student_backbone, student_head)

        teacher_head = RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        self.teacher_Features = RARP_NVB_DINO_Wrapper(teacher_backbone, teacher_head)

        # The replacement teacher must be frozen again.
        for frozen in self.teacher_Features.parameters():
            frozen.requires_grad = False

class RARP_NVB_DINO_ViT(RARP_NVB_DINO_RestNet50_Deep):
    """Variant of the parent model that swaps both backbones for torchvision ViT-B/16.

    The parent constructor runs first; its student and feature teacher are
    then replaced by ViT backbones with fresh 768-input projection heads.
    """

    def __init__(
        self, 
        PseudoEstimator: str = None, 
        threshold: float = 0.5, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        momentum_teacher: float = 0.9995, 
        lr: float = 0.0001, 
        L1: float = None, 
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)

        # Input width used for the projection heads below.
        self.in_dim = 768

        pretrained = torchvision.models.ViT_B_16_Weights.DEFAULT
        student_backbone = torchvision.models.vit_b_16(weights=pretrained)
        teacher_backbone = torchvision.models.vit_b_16(weights=pretrained)

        # Remove the classification heads so the backbones return features.
        student_backbone.heads = torch.nn.Identity()
        teacher_backbone.heads = torch.nn.Identity()

        student_head = RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        self.student = RARP_NVB_DINO_Wrapper(student_backbone, student_head)

        teacher_head = RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        self.teacher_Features = RARP_NVB_DINO_Wrapper(teacher_backbone, teacher_head)

        # The replacement teacher must be frozen again.
        for frozen in self.teacher_Features.parameters():
            frozen.requires_grad = False

class RARP_MAE(L.LightningModule):
    def __init__(
        self, 
        backbone:str, 
        mask_ratio:float = 0.75,
        patch_size: int = 16,
        #img_size: int = 224,
        lr: float = 1e-4,
        hiden_channels = [512, 256, 128, 64, 32],
        img_mean_std = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
        activation_fn = torch.nn.ReLU(True)
    ):
        super().__init__()
        
        self.save_hyperparameters(ignore=["activation_fn", "img_mean_std"])
        
        if img_mean_std is not None:
            mean, std = img_mean_std
            self.register_buffer("mean_IMG", torch.tensor(mean).view(3, 1, 1))
            self.register_buffer("std_IMG", torch.tensor(std).view(3, 1, 1))
        else:
            self.mean_IMG = None
            self.std_IMG = None
        
        self.loss_fn = torch.nn.MSELoss()
        
        match backbone:
            case "resnet":
                model = torchvision.models.resnet18()
                model.fc = torch.nn.Identity()
                self.encoder = model
                self.encoder_out_dim = 512
            case "van":
                model = van.van_b2()
                model.head = torch.nn.Identity()
                self.encoder = model
                self.encoder_out_dim = 512
            case "levit":
                self.encoder = timm.create_model("levit_192.fb_dist_in1k", pretrained=False, num_classes=0)
                self.encoder_out_dim = 384
                hiden_channels = [384, 256, 128, 64, 32]
                

        in_channel = self.encoder_out_dim
        layers = [
            torch.nn.Linear(in_channel, hiden_channels[0]*7*7),
            torch.nn.Unflatten(1, (hiden_channels[0], 7, 7)),
        ]
        
        for out_channel in hiden_channels[1:]:
            layers.append(torch.nn.ConvTranspose2d(in_channel, out_channel, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(torch.nn.BatchNorm2d(out_channel))
            layers.append(activation_fn)
            in_channel = out_channel
            
        #Last layer
        layers.append(torch.nn.ConvTranspose2d(in_channel, 3, kernel_size=3, stride=2, padding=1, output_padding=1))
                
        self.decoder = torch.nn.Sequential(*layers)
    
    def _denormalize(self, tensor:torch.Tensor):
        return tensor * self.std_IMG + self.mean_IMG
    
    def _mask_patches(self, imgs):
        B, C, H, W = imgs.shape
        ph, pw = self.hparams.patch_size, self.hparams.patch_size
        gh, gw = H // ph, W // pw
        
        mask:torch.Tensor = (torch.rand(B, gh * gw, device=imgs.device) >= self.hparams.mask_ratio)
        mask = mask.reshape(B, 1, gh, gw)   # [B,1,gh,gw]
        
        # expand mask to full image
        mask = mask.repeat_interleave(ph, dim=2)    # [B,1,gh*ph,gw]
        mask = mask.repeat_interleave(pw, dim=3)    # [B,1,gh*ph,gw*pw] == [B,1,H,W]
        
        mask = mask.expand(-1, C, -1, -1)
        
        imgs_masked = imgs * mask
        return imgs_masked
    
    def forward(self, data, val_step=False):
        if not val_step:
            data = [d.float() for d in data] #[0] original Img, [1] augmented Img
        else:
            data = [data.float(), data.float()]
        
        imgs_masked = self._mask_patches(data[1])
        latent = self.encoder(imgs_masked)   # [B, 512]
        reconstructed_image = self.decoder(latent)  # [B, 3, 224, 224]
        
        return reconstructed_image, data[0]
    
    def _shared_step(self, batch, val_step=False):
        img, _ = batch
        res, true_img = self(img, val_step)
        
        loss_MSE =  self.loss_fn(res, true_img)
        
        return loss_MSE, res, true_img
    
    def training_step(self, batch, batch_idx):
        loss, img, _ = self._shared_step(batch, False)
        
        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        
        if batch_idx % 500 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(img), 0, 1)
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
            
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, img, target_img = self._shared_step(batch, True)
                
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        
        
        if self.mean_IMG is not None and self.std_IMG is not None:
            img = torch.clip(self._denormalize(img), 0, 1)
            target_img = torch.clip(self._denormalize(target_img), 0, 1)
            
            ssim = piq.ssim(img, target_img, data_range=1.0)
            psnr = piq.psnr(img, target_img, data_range=1.0)
            
            self.log("val_ssim", ssim, on_epoch=True)
            self.log("val_psnr", psnr, on_epoch=True)
            
            if batch_idx % 100 == 0:
                grid = torchvision.utils.make_grid(img)
                self.logger.experiment.add_image('val_reconstructed_images', grid, self.global_step)
            
                
    def on_after_backward(self):
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        
        self.log("grad_norm", total_norm)
        
        if total_norm < 1e-8:
            self.log("grad_warning", "Vanishing gradient suspected!")
        
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)  
        
        return [optimizer]
       
class RARP_Encoder_DINO(L.LightningModule):
    """Student/teacher self-distillation of a ConvNeXt-Small encoder.

    Both networks are ``torchvision.models.convnext_small`` instances whose
    last classifier layer is replaced by ``Identity`` and which are wrapped,
    together with a ``RARP_NVB_MLP`` projection head, in
    ``RARP_NVB_DINO_Wrapper``.  The teacher starts as a copy of the student,
    is frozen, and is updated after every training batch as an exponential
    moving average of the student.  Only the student is optimised.

    Args:
        momentum_teacher: EMA coefficient used for the teacher update.
        lr: learning rate of the AdamW optimiser.
        Teacher_T: teacher temperature forwarded to ``RARP_NVB_DINO_Loss``.
        Student_T: student temperature forwarded to ``RARP_NVB_DINO_Loss``.
        max_epochs: stored as a hyper-parameter only.
        total_steps: stored as a hyper-parameter only.
    """

    def __init__(self, 
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        Teacher_T:float = 0.04,
        Student_T:float = 0.1,
        max_epochs:int = 100,
        total_steps:int = None
    ) -> None:
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        self.momentum_teacher = momentum_teacher
        self.out_dim = 65536
        self.in_dim = 768 #512

        # Per-epoch buffers of validation embeddings (teacher / student).
        self.embs_t = []
        self.embs_s = []

        #self.student = van.van_b2(num_classes = 0)
        #self.teacher = van.van_b2(num_classes = 0)
        #weights=torchvision.models.ConvNeXt_Small_Weights.DEFAULT

        self.student = torchvision.models.convnext_small() 
        self.student.classifier[-1] = torch.nn.Identity()

        self.teacher = torchvision.models.convnext_small()
        self.teacher.classifier[-1] = torch.nn.Identity()

        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True)
        )

        self.teacher = RARP_NVB_DINO_Wrapper(
            self.teacher,
            RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True)
        )

        # The teacher starts identical to the student and never receives
        # gradients; it is only changed by the EMA update.
        self.teacher.load_state_dict(self.student.state_dict())

        for parms in self.teacher.parameters():
            parms.requires_grad = False

        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)

    def forward(self, data):
        """Return ``(teacher_features, student_features)`` for a list of views.

        The teacher receives only ``data[0:3]`` while the student receives
        every view (presumably the leading entries are the global crops --
        confirm against the data pipeline).
        """
        data = [d.float() for d in data]
        dataTeacher, dataStudent = data[0:3], data

        teacher_features = self.teacher(dataTeacher)
        student_features = self.student(dataStudent)

        return teacher_features, student_features

    def _shared_step(self, batch):
        """Compute the DINO loss for one ``(views, _)`` batch."""
        img, _ = batch
        t, s = self(img)

        return self.lossFN_DINO(s, t)

    def training_step(self, batch, batch_idx):
        """Log and return the DINO training loss."""
        loss = self._shared_step(batch)    

        self.log("train_loss", loss, on_epoch=True, sync_dist=True)

        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Update the teacher as an EMA of the student with a fixed momentum.

        NOTE: a cosine momentum schedule based on ``total_steps`` was sketched
        in an earlier revision but is not active; the momentum is constant.
        """
        momentum = self.momentum_teacher
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher.parameters()):
                teacher_ps.mul_(momentum)
                teacher_ps.add_((1 - momentum) * student_ps.detach())

    def validation_step(self, batch, batch_idx):
        """Collect teacher and student embeddings for the epoch-end metric."""
        imgs, _ = batch
        imgs = imgs.float()

        # Keep the buffers on the CPU so that a long validation epoch does
        # not accumulate embeddings in accelerator memory.
        self.embs_t.append(self.teacher(imgs)[0].detach().cpu())
        self.embs_s.append(self.student(imgs)[0].detach().cpu())

    def _log_silhouette(self, metric_name, embeddings, n_clusters=10):
        """Cluster the buffered embeddings and log their silhouette score.

        The buffer is always emptied.  Nothing is logged when the buffer is
        empty or when there are not more samples than clusters, because
        KMeans needs ``n_samples >= n_clusters`` and ``silhouette_score``
        needs fewer labels than samples.
        """
        if not embeddings:
            return

        feats = torch.cat(embeddings, dim=0).detach().cpu().numpy()
        embeddings.clear()

        if feats.shape[0] <= n_clusters:
            return

        kmeans = KMeans(n_clusters=n_clusters, random_state=505).fit(feats)
        self.log(metric_name, silhouette_score(feats, kmeans.labels_))

    def on_validation_epoch_end(self):
        """Log the KMeans silhouette score of teacher and student embeddings."""
        self._log_silhouette("val_silhouette_teacher", self.embs_t)
        self._log_silhouette("val_silhouette_student", self.embs_s)

    def on_after_backward(self):
        """Log the mean per-parameter gradient L2 norm (``avg_grad_norm``)."""
        norms = [p.grad.detach().norm(2).item() for p in self.parameters() if p.grad is not None]
        if not norms:
            # No parameter received a gradient; avoid dividing by zero.
            return

        self.log("avg_grad_norm", sum(norms) / len(norms))

    def configure_optimizers(self):
        """Return a single AdamW optimiser over the student parameters only."""
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr)

        return [optimizer]

class RARP_NVB_Classification_Head(torch.nn.Module):
    """Linear or small MLP classification head.

    Args:
        in_features: size of the input feature vector.
        out_features: number of outputs of the final linear layer.
        layer: widths of the hidden layers.  ``None`` or an empty sequence
            builds a single ``Linear(in_features, out_features)``.
        activation_fn: activation module placed after every hidden linear
            layer (the same instance is reused for each hidden layer).

    With hidden layers the head is
    ``[Linear -> activation -> Dropout(0.4)] * k`` followed by the output
    ``Linear``; the dropout of the last hidden layer uses ``p=0.2``.
    """

    def __init__(self, in_features:int, out_features:int, layer=None, activation_fn:torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = activation_fn

        # ``None`` replaces the former mutable ``[]`` default.
        hidden_sizes = list(layer) if layer is not None else []

        if not hidden_sizes:
            self.head = torch.nn.Linear(in_features, out_features)
            return

        stages = []
        width = in_features
        last_index = len(hidden_sizes) - 1
        for index, size in enumerate(hidden_sizes):
            stages.append(torch.nn.Linear(width, size))
            stages.append(self.activation)
            # The dropout right before the output layer is lighter.
            stages.append(torch.nn.Dropout(0.2 if index == last_index else 0.4))
            width = size

        stages.append(torch.nn.Linear(width, out_features))
        self.head = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Apply the head to ``x`` and return its output."""
        return self.head(x)
   
class RARP_Encoder_DINO_AUX_task(RARP_Encoder_DINO):
    """``RARP_Encoder_DINO`` with an auxiliary rotation-prediction task.

    Each image is rotated by one of ``self.angles`` and a small head on top
    of the student backbone has to predict which rotation was applied.  The
    cross-entropy of that prediction, scaled by ``aux_lambda``, is added to
    the DINO loss.

    Args:
        momentum_teacher, lr, Teacher_T, Student_T, max_epochs, total_steps:
            forwarded unchanged to ``RARP_Encoder_DINO``.
        aux_lambda: weight of the rotation loss (read from ``self.hparams``).
    """

    def __init__(
        self, 
        momentum_teacher = 0.9995, 
        lr = 0.0001, 
        Teacher_T = 0.04, 
        Student_T = 0.1, 
        max_epochs = 100, 
        total_steps = None,
        aux_lambda = 0.3
    ):
        super().__init__(momentum_teacher, lr, Teacher_T, Student_T, max_epochs, total_steps)

        self.angles = [-90, 0, 90, 180]
        # Class index of every angle (LabelEncoder sorts the angle values).
        self.angles_labels = torch.from_numpy(LabelEncoder().fit_transform(self.angles))

        self.classifier = RARP_NVB_Classification_Head(self.in_dim, len(self.angles), layer=[128], activation_fn=torch.nn.GELU())

        self.lossFN_Aux = torch.nn.CrossEntropyLoss()

        self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=len(self.angles))
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=len(self.angles))

    def _rotation_labels_generator(self, batch_size:int, angles_batch=None, rand:bool=False) -> torch.Tensor: 
        """Return ``batch_size`` rotation class labels drawn from ``angles_batch``.

        Args:
            batch_size: number of labels to produce.
            angles_batch: pool of class labels; defaults to
                ``self.angles_labels`` when ``None``.
            rand: draw labels at random (``np.random``) instead of cycling
                through the pool in order.

        Raises:
            ValueError: if the pool is empty.
        """
        if angles_batch is None:
            angles_batch = self.angles_labels

        pool = torch.as_tensor(angles_batch)
        pool_size = len(pool)
        if pool_size == 0:
            raise ValueError("Empty list, the angles list must contain at least one value")

        if rand:
            picks = [np.random.randint(pool_size) for _ in range(batch_size)]
        else:
            picks = [i % pool_size for i in range(batch_size)]

        return pool[torch.tensor(picks, dtype=torch.long)]

    def _rotate_img(self, imgs:torch.Tensor, angles_batch=None) -> torch.Tensor:
        """Rotate ``imgs[i]`` by ``self.angles[angles_batch[i]]`` and stack the results.

        Raises:
            ValueError: if ``angles_batch`` is ``None`` or empty.
        """
        if angles_batch is None or len(angles_batch) == 0:
            raise ValueError("Empty list, the angles list must contain at least one value")

        # Convert once to plain ints so that indexing ``self.angles`` does not
        # trigger one device synchronisation per image.
        angle_indices = torch.as_tensor(angles_batch).tolist()

        rotated = [
            torchvision.transforms.functional.rotate(imgs[i, ...], self.angles[angle_idx])
            for i, angle_idx in enumerate(angle_indices)
        ]

        return torch.stack(rotated, dim=0)

    def forward(self, data, rot_data, val_step:bool=False):
        """Return ``(teacher_features, student_features, rotation_logits)``.

        During training ``data`` is a list of views (teacher gets
        ``data[0:3]``, student gets all of them); during validation it is a
        single tensor fed to both networks.  ``rot_data`` holds the rotated
        images for the auxiliary head.
        """
        if not val_step:
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[0:3], data
        else:
            data = data.float()
            dataTeacher, dataStudent = data, data

        dataAux = rot_data.float()

        teacher_features = self.teacher(dataTeacher)
        student_features = self.student(dataStudent)

        aux_rot = self.student.backbone(dataAux)
        aux_rot = self.classifier(aux_rot)

        return teacher_features, student_features, aux_rot

    def _shared_step(self, batch, val_step:bool=False):
        """Compute the combined loss for one batch.

        Returns:
            ``(loss, labels, rotation_probabilities, (loss_Dino, loss_aux))``
            where ``loss_aux`` is already scaled by ``aux_lambda`` and the
            DINO loss is zero during validation.
        """
        img, _ = batch
        reference = img if val_step else img[0]
        batch_size = reference.size(0)
        current_device = reference.device

        labels = self._rotation_labels_generator(batch_size, self.angles_labels, not val_step).to(current_device)
        rot_img = self._rotate_img(reference, labels).to(current_device)

        t, s, aux_logits = self(img, rot_img, val_step)

        if val_step:
            self.embs_t.append(t[0])
            self.embs_s.append(s[0])

        # CrossEntropyLoss applies log-softmax itself and therefore needs the
        # raw logits; feeding it softmax outputs flattens the gradients.
        loss_aux = self.hparams.aux_lambda * self.lossFN_Aux(aux_logits, labels)
        # Probabilities are only used for the accuracy metrics.
        aux = torch.softmax(aux_logits, dim=1)

        loss_Dino = self.lossFN_DINO(s, t) if not val_step else torch.tensor(0, device=current_device, dtype=torch.float32)

        loss = loss_Dino + loss_aux

        return loss, labels, aux, (loss_Dino, loss_aux)

    def training_step(self, batch, batch_idx):
        """Log training losses and rotation accuracy; return the total loss."""
        loss, true_labels, pred_labels, losses = self._shared_step(batch, False)    

        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        self.train_acc.update(pred_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_AUX", losses[1], on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses and rotation accuracy."""
        loss, true_labels, pred_labels, losses = self._shared_step(batch, True)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(pred_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_AUX", losses[1], on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Return one AdamW optimiser over the student and the rotation head."""
        optimizer = torch.optim.AdamW([
            {"params": self.student.parameters(), "lr": self.lr},
            {"params": self.classifier.parameters(), "lr": self.lr}
        ])

        return [optimizer]
        

class RARP_NVB_DINO_MultiTask(L.LightningModule):
    """Multi-task model: DINO distillation, image reconstruction and binary classification.

    A student and an EMA teacher (``van.van_b2`` backbones wrapped in
    ``RARP_NVB_DINO_Wrapper``) are trained with a DINO loss.  Forward hooks
    on the last ``block4`` module of both backbones capture feature maps;
    their channel-wise concatenation feeds a ``DynamicDecoder`` (image
    reconstruction, MSE loss) and, after global average pooling, a small MLP
    classifier.  The three losses are combined with weights that are
    re-estimated by SoftAdapt every second epoch.

    Args:
        TypeLoss: ``TypeLossFunction.CrossEntropy`` selects
            ``CrossEntropyLoss(label_smoothing=0.5)``; any other value
            selects ``BCEWithLogitsLoss``.
        momentum_teacher: EMA coefficient of the teacher update.
        lr: AdamW learning rate.
        L1: optional L1 penalty factor on the classifier parameters.
        L2: L2 penalty factor on the classifier parameters (``0`` disables).
        std, mean: per-channel normalisation statistics (reshaped to
            ``(3, 1, 1)``); required only for logging reconstructed images.
        SoftAdptAlgo: ``1`` selects ``NormalizedSoftAdapt``; anything else
            ``LossWeightedSoftAdapt``.
        SoftAdptBeta: beta passed to the SoftAdapt object.
        Teacher_T, Student_T: temperatures passed to ``RARP_NVB_DINO_Loss``.
        intermittent: alternate per epoch between training the student
            backbone and training the decoder + classifier.
    """

    # Define a hook function to capture the output
    def _hook_fn_Student(self, module, input, output):
        self.last_conv_output_S = output
        
    def _hook_fn_Teacher(self, module, input, output):
        self.last_conv_output_T = output
    
    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
        std: float = None,
        mean: float = None,
        SoftAdptAlgo:int = 0,
        SoftAdptBeta:float = 0.1,
        Teacher_T:float = 0.04,
        Student_T:float = 0.1,
        intermittent:bool = False
    ) -> None:
        super().__init__()
        
        self.intermittent_train = intermittent
        
        self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
        self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
    
        self.lr =  lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.Teacher_t = Teacher_T
        self.Studet_T = Student_T
        self.momentum_teacher = momentum_teacher
        self.out_dim = 1024
        self.in_dim = 512
        # Loss weights in the order (DINO, reconstruction, classification).
        self.weights = torch.tensor([1,1,1])
        
        self.softAdapt = NormalizedSoftAdapt(SoftAdptBeta) if SoftAdptAlgo == 1 else LossWeightedSoftAdapt(SoftAdptBeta)
        self.loss_history = {
            'loss_DINO': [], 
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        
        self.student = van.van_b2(pretrained = True, num_classes = 0)
        self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0)
        
        self.decoder = DynamicDecoder(input_channels=1024) 
               
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        
        # The teacher starts as a frozen copy of the student.
        self.teacher_Features.load_state_dict(self.student.state_dict())
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.ReconstructionLoss = torch.nn.MSELoss()
        
        # Filled by the forward hooks registered below.
        self.last_conv_output_T = None
        self.last_conv_output_S = None
        
        self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher)
        self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
        
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 128),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.4),
            
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.2),
            
            torch.nn.Linear(8, 1)
        )
                
        print(f"lr={self.lr}, L1={self.Lambda_L1}")
        
    def _denormalize(self, tensor:torch.Tensor):
        """Undo the mean/std normalisation (``tensor * std + mean``)."""
        # Move mean and std to the same device as the input tensor
        mean = self.mean_IMG.to(tensor.device)
        std = self.std_IMG.to(tensor.device)
        return tensor * std + mean
    
    def _calc_L1(self, params):
        """Return ``Lambda_L1`` times the sum of absolute values of ``params``."""
        l1 = 0
        for p in params:
            l1 += torch.sum(torch.abs(p))
        return self.Lambda_L1 * l1
    
    def _calc_weights(self, log_weights:bool = True):
        """Re-estimate the loss weights with SoftAdapt and reset the history.

        Each history is truncated by one entry when its length is even, so
        that an odd number of points is handed to SoftAdapt.
        """
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]),
            torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )

        if log_weights:
            self.log("W_loss_img", self.weights[1], on_epoch=True, on_step=False)
            self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[2], on_epoch=True, on_step=False)

        self.loss_history = {
            'loss_DINO': [], 
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }
        
    def forward(self, data, val_step:bool = True):
        """Return ``(prediction, (student_out, teacher_out), reconstruction)``.

        With ``val_step=True`` ``data`` is a single tensor fed to both
        networks.  Otherwise ``data`` is a list of views: the teacher gets
        ``data[1:3]`` and the student all views, and only the student feature
        maps of views 1 and 2 are kept so that they line up with the
        teacher's.
        """
        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[1:3], data

        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)
                       
        if not val_step:
            NumChunks = len(dataStudent)
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
        
        # Student and teacher feature maps are stacked along the channels.
        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
        
        reconstructed_image = self.decoder(cat_features)
        
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)
                        
        return pred, (Student, TeacherDino), reconstructed_image
        
    def _shared_step(self, batch, val_step:bool = False):
        """Compute the weighted multi-task loss for one batch.

        Returns:
            ``(loss, label, predicted_labels, (w_dino * loss_dino,
            w_cls * loss_cls, w_img * loss_img, reconstructed_image))``.
            During training the un-weighted losses are also appended to
            ``self.loss_history`` for SoftAdapt.
        """
        img, label = batch
        
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, (NOAH, RARP_NVB_Classification_Head)):
            prediction = prediction.flatten()        

        predicted_labels = torch.sigmoid(prediction)
        
        # In training the targets are repeated once per teacher output so
        # that they match the concatenated views.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        
        #DINO Loss
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        #Clasificator
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())
                
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
              
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)
        
    def on_train_epoch_start(self):
        """Refresh the loss weights and, if enabled, toggle the trained parts.

        Weights are recomputed at every even epoch except epoch 0.  With
        ``intermittent`` training, even epochs train the student backbone
        and odd epochs train the decoder and the classifier.
        """
        if self.current_epoch % 2 == 0 and self.current_epoch != 0:
            self._calc_weights()
            
        if self.intermittent_train and self.current_epoch != 0:
            
            par_epoch = (self.current_epoch % 2 == 0)
            
            for parms in self.student.backbone.parameters():
                parms.requires_grad = par_epoch
                
            for parms in self.decoder.parameters():
                parms.requires_grad = not par_epoch
                
            for parms in self.clasiffier.parameters():
                parms.requires_grad = not par_epoch
                
    
    def training_step(self, batch, batch_idx):
        """Log training metrics (and periodically reconstructions); return the loss."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("train_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)
        
        if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            # Reverse the channel order for display (presumably BGR -> RGB).
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)

        return loss
    
    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Update the teacher as an EMA of the student."""
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
            
            #self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center)
    
    def on_after_backward(self):
        """Log the global gradient L2 norm and flag suspected vanishing gradients."""
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        
        self.log("grad_norm", total_norm)
        
        if total_norm < 1e-8:
            # Lightning's ``log`` rejects string values; logging the former
            # warning text raised an error exactly when the warning was
            # needed.  A numeric flag keeps the signal.
            self.log("grad_warning", 1.0)
                
    def validation_step(self, batch, batch_idx):
        """Log validation losses and accuracy."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, True)
        
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        
        self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("val_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)
        
    def test_step(self, batch, batch_idx):
        """Log test accuracy / F1 and an original-vs-reconstruction image grid."""
        _, true_labels, predicted_labels, losses = self._shared_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
        
        if self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            
            imgOrig = torch.clip(self._denormalize(batch[0])/255, 0, 1)
            imgOrig = imgOrig[:, [2, 1, 0], :, :]
            
            # Originals first, reconstructions after, in a single grid.
            imgReconstruction = torch.cat((imgOrig, imgReconstruction), dim=0)
            
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
    
    def configure_optimizers(self):
        """Return a single AdamW optimiser over all parameters."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)  #, weight_decay=self.Lambda_L2
        
        return [optimizer]

class Scalar_AttnPooling(torch.nn.Module):
    """Attention pooling over a sequence of scalar values.

    Every scalar of a ``[B, T]`` input is embedded, scored, and the softmax
    of the scores over ``T`` is used to form a weighted sum of the original
    scalars.  A final ``Linear(1, 1)`` rescales the pooled value.
    """

    def __init__(self, hidden_dim=64):
        super().__init__()

        self.proj = torch.nn.Linear(1, hidden_dim)
        self.attn = torch.nn.Linear(hidden_dim, 1)
        self.pooler_cls = torch.nn.Linear(1, 1)

    def forward(self, logits_bt):
        """Pool ``logits_bt`` of shape ``[B, T]`` into a ``[B, 1]`` tensor."""
        scalars = logits_bt.unsqueeze(-1)                        # [B, T, 1]

        hidden = torch.tanh(self.proj(scalars))                  # [B, T, hidden_dim]
        scores = self.attn(hidden).squeeze(-1)                   # [B, T]

        # One weight per time step; they sum to one along T.
        weights = torch.softmax(scores, dim=1)                   # [B, T]
        pooled = torch.sum(weights.unsqueeze(-1) * scalars, dim=1)  # [B, 1]

        return self.pooler_cls(pooled)

class Scalar_TCN(torch.nn.Module):
    """Dilated 1-D convolution stack over a sequence of scalar values.

    A ``[B, T]`` input is lifted to ``hidden_dim`` channels, passed through
    ``layers`` stages (dilated 3-tap conv, SiLU, dropout, 1x1 conv, SiLU,
    dropout; the dilation doubles per stage), averaged over time and mapped
    to one output value.
    """

    def __init__(self, hidden_dim=64, layers=3):
        super().__init__()

        self.in_proj = torch.nn.Conv1d(1, hidden_dim, kernel_size=1)

        stages = []
        for level in range(layers):
            dilation = 2 ** level
            # padding == dilation keeps the sequence length for kernel size 3
            stages.extend((
                torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
                torch.nn.SiLU(),
                torch.nn.Dropout(0.3),
                torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1),
                torch.nn.SiLU(),
                torch.nn.Dropout(0.2),
            ))

        self.tcn = torch.nn.Sequential(*stages)
        self.head = torch.nn.Linear(hidden_dim, 1)

    def forward (self, logits_bt):
        """Map ``logits_bt`` of shape ``[B, T]`` to a ``[B, 1]`` tensor."""
        sequence = logits_bt.unsqueeze(1)               # [B, 1, T]
        features = self.tcn(self.in_proj(sequence))     # [B, hidden_dim, T]
        pooled = features.mean(dim=2)                   # [B, hidden_dim], average over T

        return self.head(pooled)                        # [B, 1]

class ModuleWrapper(torch.nn.Module):
    """Thin wrapper that forwards ``x`` to ``module`` and ignores a second argument.

    ``dummy_arg`` must be supplied (anything but ``None``) but is never used.
    NOTE(review): presumably this exists so that an extra tensor can be
    passed alongside ``x`` (e.g. for activation checkpointing) -- confirm
    with the call sites.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, dummy_arg=None):
        """Return ``self.module(x)``; asserts that ``dummy_arg`` was given."""
        assert dummy_arg is not None
        return self.module(x)
    
class Chomp1d(torch.nn.Module):
    """
    Drop the last ``chomp_size`` steps of the time axis.

    Used after a convolution that pads both ends of the sequence: removing
    the trailing surplus keeps the output causal and restores the length.
    """
    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # x has shape [B, C, T]
        trim = self.chomp_size
        # A slice ending at -0 would be empty, so zero is passed through.
        return x if trim == 0 else x[:, :, :-trim]
    
class TemporalBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, padding, dropout=0.0):
        """
        Residual TCN block: two dilated convolutions sharing one dilation.

        Each convolution is followed by ``Chomp1d(padding)``, ReLU and
        dropout.  The input is added back (through a 1x1 convolution when
        the channel counts differ) before a final ReLU.
        """
        super().__init__()
        # Both convolutions use the same geometry.
        conv_args = dict(stride=stride, padding=padding, dilation=dilation)

        self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, **conv_args)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(dropout)

        self.conv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size, **conv_args)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(dropout)

        # A 1x1 projection is only needed when the residual changes width.
        if in_channels != out_channels:
            self.downsample = torch.nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.downsample = None
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """
        x: [B, in_channels, T]
        returns: [B, out_channels, T]
        """
        branch = self.dropout1(self.relu1(self.chomp1(self.conv1(x))))
        branch = self.dropout2(self.relu2(self.chomp2(self.conv2(branch))))

        shortcut = self.downsample(x) if self.downsample is not None else x
        return self.relu(branch + shortcut)
    
class TemporalConvNet(torch.nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=3, dropout=0.0):
        """
        Stack of ``TemporalBlock`` layers with exponentially growing dilation.

        num_inputs: number of input channels (features)
        num_channels: list of output channels per layer, e.g. [64, 64, 128]
        kernel_size: convolution kernel size (e.g. 3)
        dropout: dropout rate in blocks
        """
        super().__init__()
        # widths[i] -> widths[i + 1] are the channels of level i
        widths = [num_inputs, *num_channels]
        stages = []
        for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            rate = 2 ** level
            stages.append(
                TemporalBlock(c_in, c_out,
                              kernel_size=kernel_size,
                              stride=1,
                              dilation=rate,
                              # pad so that, after chomping, the length stays T (causal)
                              padding=(kernel_size - 1) * rate,
                              dropout=dropout)
            )
        self.network = torch.nn.Sequential(*stages)

    def forward(self, x):
        """
        x: [B, T, C_in]
        returns: [B, T, C_out_last]
        """
        channels_first = x.permute(0, 2, 1)        # [B, C_in, T]
        encoded = self.network(channels_first)     # [B, C_out, T]

        return encoded.permute(0, 2, 1)            # [B, T, C_out]

class RARP_NVB_DINO_MultiTask_A5_Video(L.LightningModule):
    """Video-level binary classifier on top of per-frame ``RARP_NVB_DINO_MultiTask`` predictions.

    Every frame of a ``[B, T, C, H, W]`` video is scored by the base model
    (in chunks of ``chunks_loading`` frames); the per-frame predictions are
    flattened per video and mapped to one logit by ``Linear(600, 1)``, so the
    flattened predictions must hold 600 values per video.  Only the head is
    optimised.

    ``L1``, ``L2``, ``std``, ``mean`` and ``head_type`` are accepted but not
    used by this class (alternative heads selected by ``head_type`` are
    currently disabled).
    """

    def __init__(
        self, 
        base_model_path = None,
        lr = 0.0001, 
        wd = 0.01,
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None,
        head_type:int = 0, #None = 0, linear = 1, Attn. Pooling = 2, TCN = 3
        chunks_loading:int = 50 
    ):
        super().__init__()

        self.lr = lr
        self.wd = wd
        self.chunks = chunks_loading

        if base_model_path is None:
            # No checkpoint: a freshly constructed (not frozen) base model.
            self.base_model = RARP_NVB_DINO_MultiTask()
        else:
            # Checkpointed base model, switched to eval mode and frozen.
            self.base_model = RARP_NVB_DINO_MultiTask.load_from_checkpoint(base_model_path)
            self.base_model.eval()
            for weight in self.base_model.parameters():
                weight.requires_grad = False

        self.lossFN = torch.nn.BCEWithLogitsLoss()

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

        # Fixed linear head over the 600 flattened per-frame predictions.
        self.head = torch.nn.Linear(600, 1)   

    def _shared_video_step(self, batch:list[torch.Tensor], val_step:bool = False):
        """Score one batch of videos.

        Returns ``(loss, label, predicted_labels)`` where ``predicted_labels``
        are the sigmoid of the video logits.
        """
        clip, target = batch
        n_videos, n_frames, channels, height, width = clip.shape
        clip = clip.float()       # [B, T, C, H, W]
        target = target.float()   # [B]

        span = self.chunks

        def frame_scores(frames):
            # Only the first output of the base model (its prediction) is used.
            return self.base_model(frames)[0]

        per_chunk = []
        for start in tqdm(range(0, n_frames, span), desc=f"Video Analysis in {span} chunk", leave=False):
            stop = min(n_frames, start + span)
            # Fold the time axis into the batch axis for the 2-D base model.
            frames = clip[:, start:stop].reshape(-1, channels, height, width).contiguous(memory_format=torch.channels_last)

            scores = torch_ckp.checkpoint(frame_scores, frames)
            per_chunk.append(scores.view(n_videos, stop - start, -1))

        per_frame = torch.cat(per_chunk, dim=1).flatten(start_dim=1)

        video_logit = self(per_frame, val_step)   # [B, 1]
        video_logit = video_logit.flatten()       # [B] to match labels shape

        predicted_labels = torch.sigmoid(video_logit)
        loss = self.lossFN(video_logit, target)

        return loss, target, predicted_labels

    def forward(self, data, val_step:bool = True):
        """Apply the head to the flattened per-frame predictions (``val_step`` is unused)."""
        return self.head(data)

    def training_step(self, batch, batch_idx):
        """Log training loss / accuracy and return the loss."""
        loss, targets, probabilities = self._shared_video_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(probabilities, targets)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss / accuracy."""
        loss, targets, probabilities = self._shared_video_step(batch, True)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(probabilities, targets)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test accuracy and F1 score."""
        _, targets, probabilities = self._shared_video_step(batch, True)

        self.test_acc.update(probabilities, targets)
        self.f1ScoreTest.update(probabilities, targets)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Return a single AdamW optimiser over the head parameters only."""
        head_optimizer = torch.optim.AdamW(self.head.parameters(), lr=self.lr, weight_decay=self.wd)

        return [head_optimizer]

class RARP_NVB_VIDEO_3D_ResNet(L.LightningModule):
    """Binary video classifier: a 2D ResNet-50 wrapped by ``I3DResNet50``.

    The clip is sent through the backbone in temporal chunks under gradient
    checkpointing, and the per-chunk outputs are averaged into a single
    logit per video. Training uses ``BCEWithLogitsLoss``.

    Args:
        lr: Learning rate passed to AdamW.
        wd: Weight decay passed to AdamW.
        L1, L2, std, mean: Accepted for signature compatibility; not used by
            this class.
        chunks_loading: Number of frames per temporal chunk.
        str_path: Optional path to a ``state_dict`` for a 2D ResNet-50 whose
            ``fc`` has one output. When ``None`` the torchvision default
            weights are used and a new one-output ``fc`` is attached to the
            wrapped model.
    """

    def __init__(
        self, 
        lr = 0.0001, 
        wd = 0.01,
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None,
        chunks_loading:int = 50,
        str_path:str = None
    ):
        super().__init__()

        self.lr = lr
        self.wd = wd
        self.chunks = chunks_loading
        self.check_pt = True

        if str_path is None:
            base_2D_model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)

            self.base_model = I3DResNet50(base_2D_model)
            self.features = self.base_model.fc.in_features
            self.base_model.fc = torch.nn.Linear(self.features, 1)
        else:
            base_2D_model = torchvision.models.resnet50()
            # Keep ``self.features`` available on both construction paths.
            self.features = base_2D_model.fc.in_features
            base_2D_model.fc = torch.nn.Linear(self.features, 1)

            # map_location="cpu" lets a checkpoint saved from GPU tensors be
            # loaded on a host without CUDA; load_state_dict copies the values
            # into the model's own parameters afterwards.
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            base_2D_model.load_state_dict(torch.load(str_path, map_location="cpu"))

            self.base_model = I3DResNet50(base_2D_model)

        self.lossFN = torch.nn.BCEWithLogitsLoss()

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')

        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        # Per-batch test outputs, consumed and emptied by on_test_epoch_end.
        self.test_probs = []
        self.test_targets = []

    def _shared_video_step(self, batch:list[torch.Tensor], val_step:bool = False):
        """Compute loss and probabilities for one ``(video, label)`` batch.

        Args:
            batch: Pair ``(video, label)``; both are cast to float.
            val_step: Accepted for symmetry with the step methods; not read.

        Returns:
            Tuple ``(loss, label, predicted_labels)`` where
            ``predicted_labels`` is the sigmoid of the flattened logits.
        """
        video, label = batch

        video = video.float() #[B, T, C, H, W]
        label = label.float() #[B]

        logits = self(video).flatten() # flattened to match the label shape

        predicted_labels = torch.sigmoid(logits)
        loss = self.lossFN(logits, label)

        return loss, label, predicted_labels

    def forward(self, video:torch.Tensor):
        """Score a clip by averaging the backbone output over temporal chunks.

        Args:
            video: Tensor whose dimension 1 is time; dimensions 1 and 2 are
                swapped before the backbone is called (the step method
                annotates the input as ``[B, T, C, H, W]``).

        Returns:
            The mean over chunks of the backbone outputs.
        """
        T = video.shape[1]
        chunk_T = self.chunks

        # A clip exactly one chunk long is valid: it is simply one full chunk.
        assert T >= chunk_T, f"The Time dim is smaller, than chunk size [T={T}, chunk={chunk_T}]"

        video = video.permute(0, 2, 1, 3, 4)
        chunk_preds = []
        for t0 in tqdm(range(0, T, chunk_T), desc=f"Video Analysis in {chunk_T} chunk", leave=False):
            t1 = min(T, t0 + chunk_T)
            x = video[:, :, t0:t1]

            # Checkpointing recomputes the backbone activations during backward
            # instead of storing them for every chunk.
            chunk_preds.append(torch_ckp.checkpoint(self.base_model, x, use_reentrant=False))

        return torch.stack(chunk_preds, dim=0).mean(dim=0)

    def configure_optimizers(self):
        """Return a one-element list with AdamW over all module parameters."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)

        return [optimizer]

    def on_after_backward(self):
        """Log the mean L2 norm of the gradients of parameters that have one."""
        norms = [p.grad.detach().norm(2).item() for p in self.parameters() if p.grad is not None]
        if not norms:
            # No gradients were produced; there is nothing to average.
            return

        self.log("avg_grad_norm", sum(norms) / len(norms))

    def training_step(self, batch, batch_idx):
        """Run one training batch, log loss/accuracy and return the loss."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log epoch-level loss/accuracy."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, True)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch, store its outputs and log accuracy/F1."""
        _, true_labels, predicted_labels = self._shared_video_step(batch, True)

        self.test_probs.append(predicted_labels)
        self.test_targets.append(true_labels)

        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def _metrics_at_threshold(self, probs, targets, threshold:float):
        """Return ``[acc, precision, recall, f1, specificity]`` as 4-decimal strings."""
        metric_types = (
            torchmetrics.Accuracy,
            torchmetrics.Precision,
            torchmetrics.Recall,
            torchmetrics.F1Score,
            torchmetrics.Specificity,
        )
        return [
            f"{metric_cls('binary', threshold=threshold).to(self.device)(probs, targets).item():.4f}"
            for metric_cls in metric_types
        ]

    def on_test_epoch_end(self):
        """Write a two-row metrics table for the stored test outputs to CSV.

        Row one uses the 0.5 threshold; row two uses the ROC threshold that
        maximises ``tpr - fpr`` (column ``J``). The stored outputs are emptied
        so a later test run starts clean.
        """
        if not self.test_probs:
            # No test batch was seen; torch.cat would fail on an empty list.
            return

        predicted_labels = torch.cat(self.test_probs).to(self.device)
        true_labels = torch.cat(self.test_targets).to(self.device).int()
        # The data now lives in the tensors above; drop the per-batch lists so
        # repeated test runs do not mix in stale predictions.
        self.test_probs.clear()
        self.test_targets.clear()

        auc = torchmetrics.AUROC('binary').to(self.device)(predicted_labels, true_labels)
        auc_text = f"{auc.item():.4f}"

        fpr, tpr, thhols = torchmetrics.ROC("binary").to(self.device)(predicted_labels, true_labels)
        youden = tpr - fpr
        index = torch.argmax(youden)
        th1 = thhols[index].item()

        default_row = self._metrics_at_threshold(predicted_labels, true_labels, 0.5)
        youden_row = self._metrics_at_threshold(predicted_labels, true_labels, th1)

        table = [
            ["0.5000", *default_row[:4], auc_text, default_row[4], f"*{self.current_epoch}*", "0"],
            [f"{th1:.4f}", *youden_row[:4], auc_text, youden_row[4], f"-{self.current_epoch}-", f"{youden[index].item():.4f}"],
        ]

        df = pd.DataFrame(table, columns=["Threshold", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint", "J"])

        out_path = f"{self.logger.log_dir}/roc_points_epoch_{self.global_step}.csv"  
        df.to_csv(out_path, index=False)

class RARP_NVB_DINO_MultiTask_A6_Video(L.LightningModule):
    """Binary video classifier: per-frame VAN-B2 features plus a temporal head.

    Frames are pushed through ``van.van_b2`` in temporal chunks under gradient
    checkpointing; the per-frame outputs are concatenated along time and, when
    a head is configured, passed to it.

    Args:
        lr: Learning rate passed to AdamW.
        wd: Weight decay passed to AdamW.
        L1, L2, std, mean: Accepted for signature compatibility; not used by
            this class.
        head_type: ``1`` builds a ``TemporalConvNet`` head; any other value
            leaves the model without a head (``self.head is None``).
        chunks_loading: Number of frames per temporal chunk.
    """

    def __init__(
        self, 
        lr = 0.0001, 
        wd = 0.01,
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None,
        head_type:int = 0, #None = 0, linear = 1, Attn. Pooling = 2, TCN = 3, Replace Head =4
        chunks_loading:int = 50,
    ):
        super().__init__()

        self.lr = lr
        self.wd = wd
        self.chunks = chunks_loading
        self.head_type = head_type

        self.check_pt = True
        self.base_model = van.van_b2(pretrained = True, num_classes = 0)
        self.num_features_base_model = 512

        self.lossFN = torch.nn.BCEWithLogitsLoss()

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')

        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        # Per-batch test outputs, consumed and emptied by on_test_epoch_end.
        self.test_probs = []
        self.test_targets = []

        # NOTE(review): the signature comment lists 1 as "linear", but the only
        # implemented head is the TemporalConvNet below - confirm the mapping.
        if self.head_type == 1:
            self.head = TemporalConvNet(self.num_features_base_model, [128, 8, 1])
        else:
            self.head = None

    def _shared_video_step(self, batch:list[torch.Tensor], val_step:bool = False):
        """Compute loss and probabilities for one ``(video, label)`` batch.

        The model output is averaged over dimension 1 and flattened before the
        loss. NOTE(review): without a head the model returns backbone features
        rather than one logit per frame, so the flattened output presumably
        will not match the label shape - confirm a head is configured before
        training.

        Returns:
            Tuple ``(loss, label, predicted_labels)`` where
            ``predicted_labels`` is the sigmoid of the flattened logits.
        """
        video, label = batch

        video = video.float() #[B, T, C, H, W]
        label = label.float() #[B]

        pred_video = self(video, val_step)
        pred_video = pred_video.mean(dim=1)
        pred_video = pred_video.flatten() # flattened to match the label shape

        predicted_labels = torch.sigmoid(pred_video)
        loss = self.lossFN(pred_video, label)

        return loss, label, predicted_labels

    def forward(self, video:torch.Tensor, val_step:bool = True):
        """Encode every frame chunk-by-chunk, then apply the head if present.

        Args:
            video: Tensor of shape ``[B, T, C, H, W]``.
            val_step: Accepted for call-site compatibility; not read here.

        Returns:
            The head output when a head is configured; otherwise the per-frame
            backbone outputs reshaped to ``[B, T, -1]``.
        """
        B, T, C, H, W = video.shape
        chunk_T = self.chunks
        pred_bt = []

        for t0 in tqdm(range(0, T, chunk_T), desc=f"Video Analysis in {chunk_T} chunk", leave=False):
            t1 = min(T, t0 + chunk_T)
            # Fold the chunk's frames into the batch dimension for the 2D backbone.
            x = video[:, t0:t1].reshape(-1, C, H, W).contiguous(memory_format=torch.channels_last)

            # use_reentrant=False: the reentrant variant only propagates
            # gradients when an input requires grad, which the raw video does
            # not, so the backbone would otherwise receive no gradients.
            pred = torch_ckp.checkpoint(self.base_model, x, use_reentrant=False)
            pred_bt.append(pred.view(B, t1 - t0, -1))

        pred_video = torch.cat(pred_bt, dim=1)
        if self.head is not None:
            pred_video = self.head(pred_video)

        return pred_video

    def on_after_backward(self):
        """Log the mean L2 norm of the gradients of parameters that have one."""
        norms = [p.grad.detach().norm(2).item() for p in self.parameters() if p.grad is not None]
        if not norms:
            # No gradients were produced; there is nothing to average.
            return

        self.log("avg_grad_norm", sum(norms) / len(norms))

    def training_step(self, batch, batch_idx):
        """Run one training batch, log loss/accuracy and return the loss."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log epoch-level loss/accuracy."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, True)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch, store its outputs and log accuracy/F1."""
        _, true_labels, predicted_labels = self._shared_video_step(batch, True)

        self.test_probs.append(predicted_labels)
        self.test_targets.append(true_labels)

        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def _metrics_at_threshold(self, probs, targets, threshold:float):
        """Return ``[acc, precision, recall, f1, specificity]`` as 4-decimal strings."""
        metric_types = (
            torchmetrics.Accuracy,
            torchmetrics.Precision,
            torchmetrics.Recall,
            torchmetrics.F1Score,
            torchmetrics.Specificity,
        )
        return [
            f"{metric_cls('binary', threshold=threshold).to(self.device)(probs, targets).item():.4f}"
            for metric_cls in metric_types
        ]

    def on_test_epoch_end(self):
        """Write a two-row metrics table for the stored test outputs to CSV.

        Row one uses the 0.5 threshold; row two uses the ROC threshold that
        maximises ``tpr - fpr`` (column ``J``). The stored outputs are emptied
        so a later test run starts clean.
        """
        if not self.test_probs:
            # No test batch was seen; torch.cat would fail on an empty list.
            return

        predicted_labels = torch.cat(self.test_probs).to(self.device)
        true_labels = torch.cat(self.test_targets).to(self.device).int()
        # The data now lives in the tensors above; drop the per-batch lists so
        # repeated test runs do not mix in stale predictions.
        self.test_probs.clear()
        self.test_targets.clear()

        auc = torchmetrics.AUROC('binary').to(self.device)(predicted_labels, true_labels)
        auc_text = f"{auc.item():.4f}"

        fpr, tpr, thhols = torchmetrics.ROC("binary").to(self.device)(predicted_labels, true_labels)
        youden = tpr - fpr
        index = torch.argmax(youden)
        th1 = thhols[index].item()

        default_row = self._metrics_at_threshold(predicted_labels, true_labels, 0.5)
        youden_row = self._metrics_at_threshold(predicted_labels, true_labels, th1)

        table = [
            ["0.5000", *default_row[:4], auc_text, default_row[4], f"*{self.current_epoch}*", "0"],
            [f"{th1:.4f}", *youden_row[:4], auc_text, youden_row[4], f"-{self.current_epoch}-", f"{youden[index].item():.4f}"],
        ]

        df = pd.DataFrame(table, columns=["Threshold", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint", "J"])

        out_path = f"{self.logger.log_dir}/roc_points_epoch_{self.global_step}.csv"  
        df.to_csv(out_path, index=False)

    def configure_optimizers(self):
        """Return a one-element list with AdamW over all module parameters."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)

        return [optimizer]

class RARP_NVB_DINO_MultiTask_Unet(RARP_NVB_DINO_MultiTask):
    """Multi-task DINO model whose decoder consumes intermediate encoder maps.

    Forward hooks on four student-backbone blocks collect their outputs on
    every forward pass; those maps plus the concatenated student/teacher
    features are handed to a ``DecoderUnet``.
    """

    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
        std: float = None,
        mean: float = None,
        SoftAdptAlgo:int = 0,
        SoftAdptBeta:float = 0.1
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)

        self.hooks = []
        self.feature_maps = []

        # Blocks whose outputs are captured for the decoder.
        student_backbone = self.student.backbone
        self.list_blocks = [
            student_backbone.block1[-1],
            student_backbone.block2[-1],
            student_backbone.block3[-1],
            student_backbone.block4[-2],
        ]

        self._register_encoder_hooks()

        self.decoder = DecoderUnet(1024)

    def _encoder_hool_fn(self, module, input, output):
        """Forward hook: store the hooked block's output."""
        self.feature_maps.append(output)

    def _register_encoder_hooks(self):
        """Attach the capture hook to every block in ``self.list_blocks``."""
        for block in self.list_blocks:
            handle = block.register_forward_hook(self._encoder_hool_fn)
            self.hooks.append(handle)

    def forward(self, data, val_step:bool = True):
        """Run teacher and student, then classify and reconstruct.

        Args:
            data: A tensor when ``val_step`` is true; otherwise a sequence of
                views, of which items 1 and 2 are given to the teacher.
            val_step: Selects the single-tensor path.

        Returns:
            Tuple ``(pred, (Student, TeacherDino), reconstructed_image)``.
        """
        # Start from an empty list; the hooks fill it during the student pass.
        self.feature_maps = []
        multi_view = not val_step

        if multi_view:
            data = [view.float() for view in data]
            dataTeacher, dataStudent = data[1:3], data
        else:
            data = data.float()
            dataTeacher, dataStudent = data, data

        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)

        if multi_view:
            # Keep only the chunks at positions 1 and 2 of the student outputs,
            # i.e. the same views that were given to the teacher.
            n_views = len(dataStudent)
            kept_last = self.last_conv_output_S.chunk(n_views)[1:3]
            self.last_conv_output_S = torch.cat(kept_last, dim=0)
            self.feature_maps = [
                torch.cat(fmap.chunk(n_views)[1:3], dim=0)
                for fmap in self.feature_maps
            ]

        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)

        self.feature_maps.append(cat_features)
        reconstructed_image = self.decoder(self.feature_maps)

        pooled = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1))
        pred = self.clasiffier(pooled.flatten(1))

        return pred, (Student, TeacherDino), reconstructed_image

class RARP_NVB_DINO_MultiTask_MultiLabel(RARP_NVB_DINO_MultiTask):
    """Multi-label variant: swaps in a BCE loss, a new head and multilabel metrics.

    Args:
        Num_Lables: Number of output labels of the classifier and of the
            multilabel metrics. The remaining arguments are passed straight to
            the parent constructor.
    """

    def __init__(
        self, 
        TypeLoss=TypeLossFunction.BCEWithLogits, 
        momentum_teacher = 0.9995, 
        lr = 0.0001, 
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None, 
        SoftAdptAlgo = 0, 
        SoftAdptBeta = 0.1,
        Num_Lables = 2
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)

        self.lossFN = torch.nn.BCEWithLogitsLoss()

        # 1024 -> 128 -> 8 -> Num_Lables, dropout before each hidden activation.
        head_layers = [
            torch.nn.Linear(1024, 128),
            torch.nn.Dropout(0.4),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, Num_Lables),
        ]
        self.clasiffier = torch.nn.Sequential(*head_layers)

        multilabel_args = {"num_labels": Num_Lables}
        self.train_acc = torchmetrics.Accuracy('multilabel', **multilabel_args)
        self.val_acc = torchmetrics.Accuracy('multilabel', **multilabel_args)
        self.test_acc = torchmetrics.Accuracy('multilabel', **multilabel_args)
        self.f1ScoreTest = torchmetrics.F1Score('multilabel', **multilabel_args)

class RARP_NVB_DINO_MultiTask_v2(RARP_NVB_DINO_MultiTask):
    """Four-class variant of the multi-task DINO model.

    Replaces the metrics, the classification loss (``CrossEntropyLoss``), the
    SoftAdapt weighting object and the classifier head of the parent class.
    """

    def __init__(
        self, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        momentum_teacher = 0.9995, 
        lr = 0.0001, 
        L1 = None, 
        L2 = 0, 
        std = None, 
        mean = None
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean)

        self.train_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.val_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.test_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.f1ScoreTest = torchmetrics.F1Score('multiclass', num_classes=4)

        self.lossFN = torch.nn.CrossEntropyLoss()

        self.softAdapt =  LossWeightedSoftAdapt(0.1) #NormalizedSoftAdapt(0.1)

        # 1024 -> 512 -> 256 -> 128 -> 8 -> 4; dropout only after the first layer.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 512),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),

            torch.nn.Linear(512, 256),
            torch.nn.SiLU(True),

            torch.nn.Linear(256, 128),
            torch.nn.SiLU(True),

            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),

            torch.nn.Linear(8, 4)
        )

    def _shared_step(self, batch, val_step:bool = False):
        """Compute the weighted multi-task loss for one ``(img, label)`` batch.

        Args:
            batch: Pair ``(img, label)``. In training ``img`` is a sequence of
                views; in validation it is a single tensor.
            val_step: When true the DINO loss is replaced by zero and no
                regularisation or loss history is recorded.

        Returns:
            Tuple ``(loss, label, predicted_labels, details)`` where
            ``predicted_labels`` are softmax probabilities and ``details`` is
            ``(weighted DINO loss, weighted classification loss,
            weighted reconstruction loss, reconstructed image)``.
        """
        img, label = batch

        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features

        # Probabilities are used only for the metrics returned to the caller.
        predicted_labels = torch.softmax(prediction,dim=1)

        if val_step:
            orignalImg = img.float()
        else:
            # Repeat the reference image and the labels once per teacher output.
            n_teacher = len(TeacherF)
            orignalImg = torch.cat([img[0].float()] * n_teacher, dim=0)
            label = torch.cat([label] * n_teacher, dim=0)

        #DINO Loss
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        # Classifier: CrossEntropyLoss applies log-softmax internally, so it
        # must receive the raw logits rather than the softmax probabilities.
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        if not val_step:
            if self.Lambda_L1 is not None:
                loss_l1 = sum(torch.sum(torch.abs(params)) for params in self.clasiffier.parameters())
                loss_HL = loss_HL + self.Lambda_L1 * loss_l1

            if self.Lambda_L2 > 0:
                l2_reg = sum(torch.norm(param, 2) ** 2 for param in self.clasiffier.parameters())
                loss_HL = loss_HL + self.Lambda_L2 * l2_reg

            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())

        # self.weights order: [DINO, reconstruction, classification].
        weighted_dino = self.weights[0] * loss_Dino
        weighted_img = self.weights[1] * loss_img
        weighted_hl = self.weights[2] * loss_HL
        loss = weighted_dino + weighted_img + weighted_hl

        return loss, label, predicted_labels, (weighted_dino, weighted_hl, weighted_img, new_image)

class RARP_NVB_RN50_VAN_V2 (RARP_NVB_Model):
    """VAN-B2 student guided by a frozen ResNet-50 teacher, with a decoder.

    The student's last captured feature map and the teacher's feature map are
    averaged and decoded into an image. The training loss is the sum of a
    feature-alignment term, the classification loss and an MSE reconstruction
    term, plus an optional L1 penalty on the student.

    Args:
        PseudoEstimator: Optional checkpoint path for the teacher
            ``RARP_NVB_ResNet50``; a fresh one is built when ``None``.
        threshold, PseudoLables: Accepted for signature compatibility; not
            used by this class.
        InitWeight, TypeLoss, **kwargs: Forwarded to the parent constructor.
        HParameter: Optional dict; keys ``"L1"`` and ``"lr"`` override the
            defaults ``1.31E-04`` and ``1.0E-4``.
        std, mean: Optional per-channel statistics (3 values each) used to
            de-normalise reconstructions before they are logged as images.
    """

    def _hook_fn(self, module, input, output):
        """Forward hook: remember the hooked layer's output."""
        self.last_conv_output = output

    def __init__(
        self, 
        PseudoEstimator:str = None, 
        threshold:float = 0.5, 
        InitWeight=None, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        PseudoLables:bool=True,
        HParameter = None,
        std: float = None,
        mean: float = None,
        **kwargs
    ) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # A None default avoids sharing one dict object between all instances.
        if HParameter is None:
            HParameter = {}

        self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
        self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None

        # Teacher: frozen ResNet-50 cut before its last two children, so it
        # outputs a feature map.
        self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
        self.RARP_RestNet50.model.fc = torch.nn.Identity()  
        for parms in self.RARP_RestNet50.model.parameters():
            parms.requires_grad = False      

        self.RARP_RestNet50 = torch.nn.Sequential(*list(self.RARP_RestNet50.model.children())[:-2])

        self.decoder = Decoder()

        self.FeaturesLoss = FeatureAlignmentLoss()
        self.ReconstructionLoss = torch.nn.MSELoss()

        # The original matched on the builtin ``TypeError`` instead of the
        # ``TypeLoss`` argument, so this branch could never be selected.
        if TypeLoss == TypeLossFunction.ContrastiveLoss:
            self.lossFN = ContrastiveLoss()

        self.model = van.van_b2(pretrained = True, num_classes = 1)  

        self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04) 
        self.lr = getNVL(HParameter, "lr", 1.0E-4)

        print(f"lr={self.lr}, L1={self.Lambda_L1}")

        # Filled by _hook_fn on every student forward pass.
        self.last_conv_output = None
        self.model.block4[-1].mlp.dwconv.register_forward_hook(self._hook_fn)

    def forward(self, data):
        """Run teacher and student and decode their averaged feature maps.

        Args:
            data: Triple ``(teacher_input, student_input, _)``; the third item
                is ignored here.

        Returns:
            Tuple ``(pred, (student_vec, teacher_vec), reconstructed_image)``
            where the two vectors are globally average-pooled feature maps.
        """
        dataTeacher, dataStudent, _ = data
        dataStudent = dataStudent.float()
        dataTeacher = dataTeacher.float()

        feature_Teacher = self.RARP_RestNet50 (dataTeacher)
        pred = self.model(dataStudent)  
        # Captured by the forward hook during the student call above.
        feature_Student = self.last_conv_output

        cat_features = (feature_Student + feature_Teacher) / 2

        reconstructed_image = self.decoder(cat_features) 

        feature_Student = torch.nn.functional.adaptive_avg_pool2d(feature_Student, (1, 1)).flatten(1)
        feature_Teacher = torch.nn.functional.adaptive_avg_pool2d(feature_Teacher, (1, 1)).flatten(1)

        return pred, (feature_Student, feature_Teacher), reconstructed_image

    def _shared_step(self, batch):
        """Compute the combined loss for one ``(img, label)`` batch.

        Returns:
            Tuple ``(loss, label, predicted_labels, details)`` with ``details``
            being ``(alignment loss, classification loss, reconstruction loss,
            reconstructed image)``.
        """
        img, label = batch
        label = label.float()
        # The third element of the input triple is the reconstruction target.
        orignalImg = img[2].float()

        prediction, features, new_image = self(img)

        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)

        loss_cosine = self.FeaturesLoss(features[0], features[1])
        loss_hl = self.lossFN(prediction, label)
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()

        loss = loss_cosine + loss_hl + loss_img

        if self.Lambda_L1 is not None:
            loss_l1 = sum(torch.sum(torch.abs(params)) for params in self.model.parameters())
            loss = loss + self.Lambda_L1 * loss_l1

        return loss, label, predicted_labels, (loss_cosine, loss_hl, loss_img, new_image)

    def _denormalize(self, tensor:torch.Tensor):
        """Undo normalisation: ``tensor * std + mean`` on the tensor's device."""
        mean = self.mean_IMG.to(tensor.device)
        std = self.std_IMG.to(tensor.device)
        return tensor * std + mean

    def _reconstruction_grid(self, reconstruction:torch.Tensor):
        """Return an image grid of the de-normalised, channel-reversed reconstructions."""
        images = torch.clip(self._denormalize(reconstruction) / 255, 0, 1)
        images = images[:, [2, 1, 0], :, :]
        return torchvision.utils.make_grid(images)

    def training_step(self, batch, batch_idx):
        """Run one training batch, log its losses and return the total loss.

        Every 50th batch the reconstructions are also logged as an image grid,
        provided ``mean`` and ``std`` were supplied.
        """
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("train_loss_cosSim", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)

        if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            grid = self._reconstruction_grid(losses[3])
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)

        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log epoch-level losses/accuracy."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("val_loss_cosSim", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch, log accuracy/F1 and the reconstructions."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

        if self.mean_IMG is not None and self.std_IMG is not None:
            grid = self._reconstruction_grid(losses[3])
            self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)

    def configure_optimizers(self):
        """Return Adam over all parameters with ``self.Lambda_L2`` as weight decay."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)

        return optimizer
                
class RARP_NVB_ResNet50_VAN(RARP_NVB_Model):
    """VAN-B2 student distilled from a frozen ``RARP_NVB_ResNet50`` teacher.

    The loss is ``Alpha * hard-label loss + (1 - Alpha) * soft-label BCE``
    between temperature-scaled student and teacher outputs, plus an optional
    L1 penalty on the student.

    Args:
        PseudoEstimator: Optional checkpoint path for the teacher; a fresh
            teacher is built when ``None``.
        threshold: Cut-off applied to the teacher probability to form pseudo
            labels.
        InitWeight, TypeLoss, **kwargs: Forwarded to the parent constructor.
        PseudoLables: When true the metrics are computed against the pseudo
            labels instead of the ground truth.
        HParameter: Optional dict with keys ``"L1"``, ``"lr"``, ``"Alpha"``
            and ``"Thao"`` overriding the defaults ``1.31E-04``, ``1.0E-4``,
            ``0.60`` and ``5``.
    """

    def __init__(
        self, 
        PseudoEstimator:str = None, 
        threshold:float = 0.5, 
        InitWeight=None, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        PseudoLables:bool=True,
        HParameter = None,
        **kwargs
    ) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        # A None default avoids sharing one dict object between all instances.
        if HParameter is None:
            HParameter = {}

        self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
        self.threshold = threshold
        self.PseudoLables = PseudoLables

        # The original matched on the builtin ``TypeError`` instead of the
        # ``TypeLoss`` argument, so this branch could never be selected.
        if TypeLoss == TypeLossFunction.ContrastiveLoss:
            self.lossFN = ContrastiveLoss()

        # The teacher is never trained.
        for parms in self.RARP_RestNet50.model.parameters():
            parms.requires_grad = False

        self.model = van.van_b2(pretrained = True, num_classes = 1)        
        self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04) 
        self.lr = getNVL(HParameter, "lr", 1.0E-4)
        self.W_Apha = getNVL(HParameter, "Alpha", 0.60)
        self.W_Beta = 1 - self.W_Apha
        self.thao_KD = getNVL(HParameter, "Thao", 5)

        print(f"A={self.W_Apha}; B={self.W_Beta}, T={self.thao_KD}, lr={self.lr}, L1={self.Lambda_L1}")

    def forward(self, data):
        """Run the teacher and the student.

        Args:
            data: Either one tensor, given to both networks, or a
                ``(teacher_input, student_input)`` tuple or list.

        Returns:
            Tuple ``(pred, PseudoLabels, RN50Pred)``: student logits, teacher
            probabilities thresholded to 0/1, and the teacher probabilities.

        Raises:
            TypeError: If ``data`` is neither a tensor nor a 2-item tuple/list.
        """
        if isinstance(data, torch.Tensor):
            dataTeacher, dataStudent = data, data
        elif isinstance(data, (tuple, list)):
            # Batches may arrive as lists as well as tuples.
            dataTeacher, dataStudent = data
        else:
            raise TypeError(f"Unsupported input type for forward: {type(data).__name__}")

        RN50Pred = torch.sigmoid(self.RARP_RestNet50(dataTeacher).flatten())
        PseudoLabels = (RN50Pred > self.threshold) * 1.0

        dataStudent = dataStudent.float()
        pred = self.model(dataStudent)

        return pred, PseudoLabels, RN50Pred

    def _shared_step(self, batch):
        """Compute the distillation loss for one ``(img, label)`` batch.

        Returns:
            Tuple ``(loss, reference_labels, predicted_labels, (loss_hl,
            loss_sl))`` where ``reference_labels`` are the pseudo labels when
            ``self.PseudoLables`` is true and the ground truth otherwise.
        """
        img, label = batch
        if self.InitWeight is not None:
            # Per-sample loss weights looked up by class index.
            self.lossFN.weight = self.InitWeight[label]

        label = label.float()
        prediction, PseudoLabels, predictionRN50 = self(img)
        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)

        # Active variant: hard-label loss + BCE between softened outputs.
        # NOTE(review): predictionRN50 is already a sigmoid probability (see
        # forward), so the teacher is passed through sigmoid twice here while
        # the student is softened from logits - confirm this is intended.
        softTeacher = torch.sigmoid(predictionRN50/self.thao_KD)
        softStudent = torch.sigmoid(prediction/self.thao_KD)

        loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
        loss_hl = self.lossFN(prediction, label)

        loss = self.W_Apha * loss_hl + self.W_Beta * loss_sl

        if self.Lambda_L1 is not None:
            loss_l1 = sum(torch.sum(torch.abs(params)) for params in self.model.parameters())
            loss = loss + self.Lambda_L1 * loss_l1

        return loss, (PseudoLabels if self.PseudoLables else label), predicted_labels, (loss_hl, loss_sl)

    def training_step(self, batch, batch_idx):
        """Run one training batch, log its losses and return the total loss."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)

        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_lh", losses[0], on_step=False, on_epoch=True)
        self.log("train_ld", losses[1], on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log epoch-level losses/accuracy."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)

        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_lh", losses[0], on_step=False, on_epoch=True)
        self.log("val_ld", losses[1], on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and log epoch-level accuracy and F1."""
        _, true_labels, predicted_labels, _ = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Return Adam over the student, with a scheduler when ``self.scheduler`` is set."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            scheduler = CosineAnnealingLR(
                optimizer, 
                max_epochs=150,
                warmup_epochs=8,
                warmup_start_lr=0.001,
                eta_min=0.0001
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                }
            }
        return optimizer

class RARP_NVB_SSL_RestNet50_Deep(RARP_NVB_ResNet50_VAN):
    """Variant of ``RARP_NVB_ResNet50_VAN`` with a frozen ``RARP_NVB_ResNet50_Deep``
    estimator and a ResNet-50 backbone with a three-layer single-output head.

    Args:
        PseudoEstimator: Checkpoint path for the ``RARP_NVB_ResNet50_Deep``
            estimator. When ``None`` a freshly constructed estimator is used.
        threshold, InitWeight, TypeLoss, PseudoLables, **kwargs: Forwarded
            unchanged to the parent constructor.
    """

    def __init__(
        self, 
        PseudoEstimator: str = None, 
        threshold: float = 0.5, 
        InitWeight=None, 
        TypeLoss=TypeLossFunction.CrossEntropy, 
        PseudoLables: bool = True,
        **kwargs
    ) -> None:
        # The parent is given None instead of the checkpoint path; the
        # estimator is created here instead.
        super().__init__(None, threshold, InitWeight, TypeLoss, PseudoLables, **kwargs)

        if PseudoEstimator is None:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep()
        else:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False)

        # The estimator's wrapped network is not trained: disable gradients
        # for all of its parameters.
        self.RARP_RestNet50.model.requires_grad_(False)

        # Alternative backbone (disabled):
        #   torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)

        head_in = 2048  # input width of the replaced fc layer
        head_layers = [
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=head_in, out_features=128),
            torch.nn.SiLU(inplace=True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(inplace=True),
            torch.nn.Linear(8, 1),
        ]
        self.model.fc = torch.nn.Sequential(*head_layers)

class RARP_NVB_MobileNetV2(RARP_NVB_Model):
    """``RARP_NVB_Model`` backed by torchvision's pretrained MobileNetV2.

    The second entry of the network's ``classifier`` is replaced by a linear
    layer with a single output.

    Alternative base class (disabled): RARP_NVB_Model_BCEWithLogitsLoss.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        pretrained = torchvision.models.MobileNet_V2_Weights.DEFAULT
        self.model = torchvision.models.mobilenet_v2(weights=pretrained)

        # Reuse the input width of the stock classification layer.
        stock_head = self.model.classifier[1]
        self.model.classifier[1] = torch.nn.Linear(
            in_features=stock_head.in_features,
            out_features=1,
        )

class RARP_NVB_EfficientNetV2(RARP_NVB_Model):
    """``RARP_NVB_Model`` backed by torchvision's pretrained EfficientNetV2-S.

    The second entry of the network's ``classifier`` is replaced by a linear
    layer with a single output.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)

        pretrained = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
        self.model = torchvision.models.efficientnet_v2_s(weights=pretrained)

        # Reuse the input width of the stock classification layer.
        stock_head = self.model.classifier[1]
        self.model.classifier[1] = torch.nn.Linear(
            in_features=stock_head.in_features,
            out_features=1,
        )
        
class RARP_NVB_Vit_b_16(RARP_NVB_Model):
    """``RARP_NVB_Model`` backed by torchvision's pretrained ViT-B/16.

    The first entry of the network's ``heads`` container is replaced by a
    linear layer with a single output.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy) -> None:
        super().__init__(InitWeight, TypeLoss)

        pretrained = torchvision.models.ViT_B_16_Weights.DEFAULT
        self.model = torchvision.models.vit_b_16(weights=pretrained)

        # Reuse the input width of the stock head.
        stock_head = self.model.heads[0]
        self.model.heads[0] = torch.nn.Linear(
            in_features=stock_head.in_features,
            out_features=1,
        )
        
class RARP_NVB_DenseNet169(RARP_NVB_Model):
    """``RARP_NVB_Model`` backed by torchvision's pretrained DenseNet-169.

    The network's ``classifier`` layer is replaced by a linear layer with a
    single output.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy) -> None:
        super().__init__(InitWeight, TypeLoss)

        pretrained = torchvision.models.DenseNet169_Weights.DEFAULT
        self.model = torchvision.models.densenet169(weights=pretrained)

        # Reuse the input width of the stock classifier.
        stock_head = self.model.classifier
        self.model.classifier = torch.nn.Linear(
            in_features=stock_head.in_features,
            out_features=1,
        )

class RARP_NVB_RestNet50_old(L.LightningModule):
    """Legacy ResNet-50 classifier with a single-output head and BCE-with-logits loss.

    The ImageNet-pretrained torchvision ResNet-50 has its ``fc`` layer
    replaced by a one-output linear layer. Training and validation accuracy
    are tracked by two separate ``torchmetrics.Accuracy('binary')`` instances.

    Args:
        InitialWeigth: Value used as the ``pos_weight`` of the loss.
    """

    def __init__(self, InitialWeigth = 1) -> None:
        super().__init__()

        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        tempFC_ft = self.model.fc.in_features 
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)

        # The optimizer is created eagerly and handed out by configure_optimizers().
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4) 

        # The default weight is the int 1, which torch.tensor would turn into
        # an int64 tensor; build the loss weight explicitly as float32.
        pos_weight = torch.tensor([InitialWeigth], dtype=torch.float32)
        self.lossFN = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')

    def forward(self, data):
        """Cast ``data`` to float and return the wrapped network's output."""
        data = data.float()
        pred = self.model(data)
        return pred

    def _loss_and_step_acc(self, batch, accuracy):
        """Compute the loss and the batch accuracy for an ``(img, label)`` batch.

        Calling ``accuracy`` both returns the accuracy of this batch and adds
        the batch to the metric's accumulated state.
        """
        img, label = batch
        label = label.float()
        # Column 0 of the model output is used as the logit.
        prediction = self(img)[:,0]
        loss = self.lossFN(prediction, label)
        step_acc = accuracy(torch.sigmoid(prediction), label.int())
        return loss, step_acc

    def training_step(self, batch, batch_idx):
        """Log the loss and batch accuracy of one training batch; return the loss."""
        loss, step_acc = self._loss_and_step_acc(batch, self.train_acc)

        self.log("Train Loss", loss)
        self.log("Step Train Acc", step_acc)

        return loss
    
    def on_train_epoch_end(self):
        """Log the accuracy accumulated over the epoch, then clear the metric."""
        self.log("Train Acc", self.train_acc.compute())
        # Only computed values are logged (never the metric object), so the
        # metric is not reset automatically. Without this reset the value
        # would be a running average over all epochs so far.
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        """Log the loss and batch accuracy of one validation batch; return the loss."""
        loss, step_acc = self._loss_and_step_acc(batch, self.val_acc)

        self.log("Val Loss", loss)
        self.log("Step Val Acc", step_acc)

        return loss
    
    def on_validation_epoch_end(self):
        """Log the accuracy accumulated over the validation run, then clear the metric."""
        self.log("Val Acc", self.val_acc.compute())
        # Reset so the next validation run starts from an empty state.
        self.val_acc.reset()

    def configure_optimizers(self):
        """Return the Adam optimizer created in ``__init__``."""
        return [self.optimizer]