import warnings
warnings.simplefilter("ignore")
import math
from typing import Any, Union
import torch
import torch.utils.checkpoint as torch_ckp
import torchvision
import torchmetrics
import torchmetrics.classification
import lightning as L
from enum import Enum
import timm
import van
import numpy as np
from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt
from noah import NOAH
import piq
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import pandas as pd
from pathlib import Path
from I3D_RestNet50 import I3DResNet50
def js_divergence_sigmoid(p_logits, q_logits):
    """Element-wise Jensen-Shannon divergence between the Bernoulli
    distributions sigmoid(p_logits) and sigmoid(q_logits).

    JS(p, q) = 0.5 * (KL(p || m) + KL(q || m)) with m = 0.5 * (p + q).
    With cross-entropy H(a, b), KL(a || b) = H(a, b) - H(a), so the entropy
    terms H(p) and H(q) must be subtracted; the previous version omitted
    them, which made the "divergence" nonzero even when p == q.

    Returns a tensor of the same shape as the inputs (no reduction).
    """
    p_probs = torch.sigmoid(p_logits)
    q_probs = torch.sigmoid(q_logits)
    m = 0.5 * (p_probs + q_probs)
    # Cross entropies H(p, m) and H(q, m): BCE with m as input, p/q as target.
    h_p_m = torch.nn.functional.binary_cross_entropy(m, p_probs, reduction='none')
    h_q_m = torch.nn.functional.binary_cross_entropy(m, q_probs, reduction='none')
    # Bernoulli entropies H(p) and H(q): cross-entropy of a distribution with itself.
    h_p = torch.nn.functional.binary_cross_entropy(p_probs, p_probs, reduction='none')
    h_q = torch.nn.functional.binary_cross_entropy(q_probs, q_probs, reduction='none')
    # JS = 0.5 * (KL(p || m) + KL(q || m)).
    js_div = 0.5 * (h_p_m - h_p) + 0.5 * (h_q_m - h_q)
    return js_div
def getNVL(obj:dict, key, default):
    """Return obj[key], or `default` when the key is absent or maps to None.

    Works like SQL's NVL(): unlike a plain dict.get, a stored None value
    also falls back to `default`. The previous version performed the
    dictionary lookup twice.
    """
    value = obj.get(key)
    return default if value is None else value
class Decoder(torch.nn.Module):
    """Convolutional upsampling decoder.

    The pipeline is: one stride-1 refinement conv at the input width, then
    for each entry of `hidden_channels` a ConvTranspose2d (stride 2, x2
    upsample) followed by a stride-1 refinement conv, and finally one more
    x2 ConvTranspose2d + refinement conv down to `output_channels`.
    Spatial size therefore grows by 2 ** (num_blocks + 1). Every conv is
    followed by BatchNorm + activation — including the last one, so the
    output is a normalized/activated feature map, not a raw image.
    """
    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=None, *args, **kwargs):
        """
        Args:
            input_channels: channels of the incoming feature map.
            output_channels: channels of the decoded output.
            num_blocks: number of upsampling stages (must equal len(hidden_channels)).
            hidden_channels: per-stage output widths; defaults to
                [1024, 512, 256, 64]. Passed as None (not a mutable default)
                so instances never share the list.
        """
        super().__init__(*args, **kwargs)
        # Fix: a mutable default argument ([1024, 512, 256, 64]) was shared
        # across calls; build the fallback per instance instead.
        if hidden_channels is None:
            hidden_channels = [1024, 512, 256, 64]
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."
        self.Activation = torch.nn.ReLU  # torch.nn.GELU
        self.input_channels = input_channels
        blocks = []
        inCh = input_channels
        # Stride-1 refinement at the bottleneck width before any upsampling.
        blocks.append(torch.nn.Conv2d(inCh, inCh, kernel_size=3, stride=1, padding=1, bias=False))
        blocks.append(torch.nn.BatchNorm2d(inCh))
        blocks.append(self.Activation())
        for i, outCh in enumerate(hidden_channels):
            # x2 upsample, then a stride-1 conv to refine the upsampled map.
            blocks.append(torch.nn.ConvTranspose2d(inCh, outCh, kernel_size=4, stride=2, padding=1, bias=False))
            blocks.append(torch.nn.BatchNorm2d(outCh))
            blocks.append(self.Activation())
            blocks.append(torch.nn.Conv2d(outCh, outCh, kernel_size=3, stride=1, padding=1, bias=False))
            blocks.append(torch.nn.BatchNorm2d(outCh))
            blocks.append(self.Activation())
            inCh = outCh
        # Final x2 upsample + refinement down to the requested output width.
        blocks.append(torch.nn.ConvTranspose2d(inCh, output_channels, kernel_size=4, stride=2, padding=1, bias=False))
        blocks.append(torch.nn.BatchNorm2d(output_channels))
        blocks.append(self.Activation())
        blocks.append(torch.nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=1, padding=1, bias=False))
        blocks.append(torch.nn.BatchNorm2d(output_channels))
        blocks.append(self.Activation())
        self.decoder = torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Decode a (B, input_channels, H, W) map to (B, output_channels, H*2**(num_blocks+1), W*2**(num_blocks+1))."""
        x = self.decoder(x)
        return x
class DynamicDecoder(torch.nn.Module):
    """Simple stacked-ConvTranspose2d decoder.

    Each hidden stage is ConvTranspose2d (stride 2) + BatchNorm + ReLU; a
    final ConvTranspose2d maps to `output_channels`, so spatial size grows
    by 2 ** (num_blocks + 1). No output activation is applied (the Sigmoid
    was intentionally left out), so the output is an unbounded tensor.
    """
    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=None):
        """
        Args:
            input_channels: channels of the incoming feature map.
            output_channels: channels of the decoded output.
            num_blocks: number of hidden upsampling stages
                (must equal len(hidden_channels)).
            hidden_channels: per-stage output widths; defaults to
                [1024, 512, 256, 64]. Passed as None (not a mutable default)
                so instances never share the list.
        """
        super(DynamicDecoder, self).__init__()
        # Fix: the previous mutable default argument was shared across calls.
        if hidden_channels is None:
            hidden_channels = [1024, 512, 256, 64]
        # Ensure the number of hidden channels matches the number of blocks
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."
        layers = []
        in_channels = input_channels
        # Loop to create the decoder blocks
        for out_channels in hidden_channels:
            layers.append(torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(torch.nn.BatchNorm2d(out_channels))
            layers.append(torch.nn.ReLU(inplace=True))
            in_channels = out_channels
        # Final layer to get the output image
        layers.append(torch.nn.ConvTranspose2d(in_channels, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
        # Combine all layers into a Sequential module
        self.decoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Decode a (B, input_channels, H, W) map; spatial dims grow by 2**(num_blocks+1)."""
        return self.decoder(x)
class DecoderUnet(torch.nn.Module):
    """U-Net-style decoder head over externally supplied encoder features.

    forward() expects a 5-tuple
    (encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck). The conv
    block input widths imply channel counts of 64 / 128 / 320 / 512 for
    encoder levels 0..3 and `input_channels` for the bottleneck. Each
    encoder skip is itself upsampled x2 (ConvTranspose2d) before the
    concatenation — presumably because the encoder maps are at half the
    resolution of the matching decoder stage; TODO confirm against the
    encoder that produces these features.
    """
    def __init__(self, input_channels=2048, output_channels=3):
        super().__init__()
        # Shared spatial dropout applied after the first three decoder stages.
        self.dropout = torch.nn.Dropout2d(0.4)
        # Decoder stages: upConv_i halves channels and doubles resolution,
        # decoder_i fuses the concatenated (decoder + skip) channels.
        self.upConv_0 = torch.nn.ConvTranspose2d(input_channels, 512, kernel_size=2, stride=2)
        self.decoder_0 = self._conv_block(1024, 512)   # 512 (up) + 512 (skip l3)
        self.upConv_1 = torch.nn.ConvTranspose2d(512, 320, kernel_size=2, stride=2)
        self.decoder_1 = self._conv_block(640, 320)    # 320 (up) + 320 (skip l2)
        self.upConv_2 = torch.nn.ConvTranspose2d(320, 128, kernel_size=2, stride=2)
        self.decoder_2 = self._conv_block(256, 128)    # 128 (up) + 128 (skip l1)
        self.upConv_3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder_3 = self._conv_block(128, 64)     # 64 (up) + 64 (skip l0)
        # Extra stage with no skip connection: one more x2 upsample.
        self.upConv_4 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder_4 = self._conv_block(32, 16)
        # x2 upsamplers for the encoder skips (channel counts unchanged).
        self.enc3_upsample = torch.nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.enc2_upsample = torch.nn.ConvTranspose2d(320, 320, kernel_size=2, stride=2)
        self.enc1_upsample = torch.nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.enc0_upsample = torch.nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        # 1x1 projection to the requested output channels (no activation).
        self.last_conv = torch.nn.Conv2d(16, output_channels, kernel_size=1)

    def _conv_block(self, in_ch, out_ch):
        """Two 3x3 convs with SiLU activations (standard U-Net fuse block)."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        )

    def forward(self, x):
        """Decode the 5-tuple of encoder feature maps into an image-shaped map.

        Args:
            x: (encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck).
        Returns:
            Tensor with `output_channels` channels; no final activation.
        """
        encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck = x
        # Stage 0: bottleneck -> 512ch, fuse with upsampled skip l3.
        decoder_l3 = self.upConv_0(btlneck)
        encoder_l3 = self.enc3_upsample(encoder_l3)
        decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1)
        decoder_l3 = self.decoder_0(decoder_l3)
        decoder_l3 = self.dropout(decoder_l3)
        # Stage 1: -> 320ch, fuse with upsampled skip l2.
        decoder_l2 = self.upConv_1(decoder_l3)
        encoder_l2 = self.enc2_upsample(encoder_l2)
        decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1)
        decoder_l2 = self.decoder_1(decoder_l2)
        decoder_l2 = self.dropout(decoder_l2)
        # Stage 2: -> 128ch, fuse with upsampled skip l1.
        decoder_l1 = self.upConv_2(decoder_l2)
        encoder_l1 = self.enc1_upsample(encoder_l1)
        decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1)
        decoder_l1 = self.decoder_2(decoder_l1)
        decoder_l1 = self.dropout(decoder_l1)
        # Stage 3: -> 64ch, fuse with upsampled skip l0 (no dropout here).
        decoder_l0 = self.upConv_3(decoder_l1)
        encoder_l0 = self.enc0_upsample(encoder_l0)
        decoder_l0 = torch.cat((decoder_l0, encoder_l0), dim=1)
        decoder_l0 = self.decoder_3(decoder_l0)
        # Final skip-free stage and 1x1 projection.
        decoder_last = self.upConv_4(decoder_l0)
        decoder_last = self.decoder_4(decoder_last)
        return self.last_conv(decoder_last)
class ModelsList(Enum):
    """Identifiers for the backbone architectures used in experiments.

    NOTE(review): 'RestNet50' and 'Droput' look like typos of
    'ResNet50' / 'Dropout'; the member names are kept as-is because they
    are part of the public interface.
    """
    RestNet50 = 1
    DenseNet169 = 2
    Efficientnet_b0 = 3
    MobileNetV2 = 4
    Inception3 = 5
    ResNeXt_50_32x4d = 6
    RestNet50_Droput = 7
class TypeLossFunction(Enum):
    """Selector for the classification loss used by the Lightning models
    below (only CrossEntropy vs. the alternatives is actually branched on
    in this file)."""
    CrossEntropy = 0
    BCEWithLogits = 1
    HingeLoss = 2
    FocalLoss = 3
    ContrastiveLoss = 4
class FeatureAlignmentLoss(torch.nn.Module):
    """Cosine-alignment loss between student and teacher feature vectors.

    Both inputs are L2-normalized along the last dimension; the loss is
    the mean of (1 - cosine similarity), i.e. 0 for perfectly aligned
    features and 2 for exactly opposite ones.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, F_student, F_teacher):
        student_unit = torch.nn.functional.normalize(F_student, p=2, dim=-1)
        teacher_unit = torch.nn.functional.normalize(F_teacher, p=2, dim=-1)
        # Dot product of unit vectors == cosine similarity per sample.
        alignment = (student_unit * teacher_unit).sum(dim=-1)
        return (1 - alignment).mean()
class CosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.00001,
        eta_min: float = 0.00001,
        last_epoch: int = -1,
    ):
        """
        Cosine-annealing learning-rate schedule with linear warmup.

        The learning rate follows a cosine curve until ``max_epochs``;
        from epoch 0 up to ``warmup_epochs`` it instead increases
        linearly from ``warmup_start_lr`` to each group's base lr.
        Ported from:
        https://pytorch-lightning-bolts.readthedocs.io/en/stable/schedulers/warmup_cosine_annealing.html

        Args:
            optimizer (torch.optim.Optimizer):
                Wrapped optimizer instance.
            warmup_epochs (int):
                Number of epochs of linear warmup.
            max_epochs (int):
                Total number of epochs used as the end of the cosine curve.
            warmup_start_lr (float):
                Learning rate at warmup epoch 0.
            eta_min (float):
                Lower bound of the cosine curve.
            last_epoch (int):
                Phase offset of the cosine curve (-1 starts fresh).
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch, verbose=False)
        return None

    def get_lr(self):
        """Compute the per-group learning rates for the current epoch.

        Uses the chained (recursive) formulation: each value is derived
        from the group's current ``lr`` rather than recomputed in closed
        form, mirroring torch's built-in CosineAnnealingLR.
        """
        # Epoch 0: start of the linear warmup.
        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        # Warmup phase: add a constant linear increment each epoch.
        if self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Exactly at the warmup/cosine boundary: hand back the base lrs.
        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        # Cosine-cycle boundary correction term (restart point of the
        # 2*(max_epochs - warmup_epochs) period), as in torch's scheduler.
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # General cosine step: ratio of successive cosine factors applied
        # to the distance from eta_min.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)))
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over a pair of embeddings.

    label == 0 marks a similar pair (penalized by squared Euclidean
    distance, pulling the embeddings together); label == 1 marks a
    dissimilar pair (squared hinge pushing them at least `margin` apart).
    """
    def __init__(self, margin=1.0, distance:int = 0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        # Kept for interface compatibility; not consulted in forward().
        self.TypeDistance = distance

    def forward(self, output1, output2, label):
        # Euclidean distance between the two embeddings.
        pair_dist = torch.nn.functional.pairwise_distance(output1, output2)
        similar_term = (1 - label) * pair_dist.pow(2)
        hinge = torch.clamp(self.margin - pair_dist, min=0.0)
        dissimilar_term = label * hinge.pow(2)
        return torch.mean(similar_term + dissimilar_term)
class GradCAM:
    """Grad-CAM heat-map generator for a binary (single-logit) classifier.

    Hooks `target_layer` to capture its forward activations and the
    gradient of the loss w.r.t. its output, then weights the activation
    channels by the spatially pooled gradients.
    """
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # grad of loss w.r.t. target_layer output (set by backward hook)
        self.activations = None  # target_layer output from the last forward pass
        self.hook()

    def hook(self):
        """Attach forward/backward capture hooks to the target layer."""
        def forward_hook(module, Input, Output):
            self.activations = Output
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
        self.target_layer.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated and can report incorrect
        # gradients on modules with multiple inputs; use the full variant.
        self.target_layer.register_full_backward_hook(backward_hook)

    def generate_cam(self, data:torch.Tensor, target_class):
        """Return a (H, W) numpy CAM normalized to [0, 1].

        Args:
            data: input batch (CPU tensor); note requires_grad is enabled
                in place on the caller's tensor.
            target_class: float target(s) for the model's single logit.
        """
        self.model.zero_grad()
        data.requires_grad = True
        output = self.model(data).flatten()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(output, target_class)
        loss.backward()
        # Channel weights: gradients averaged over batch and spatial dims.
        pooled_grad = torch.mean(self.gradients, dim=[0, 2, 3])
        # Weight channels on a detached copy so we do not mutate a tensor
        # that is still part of the autograd graph (the original code
        # modified self.activations in place).
        activations = self.activations.detach().clone()
        activations *= pooled_grad.view(1, -1, 1, 1)
        cam = torch.mean(activations, dim=1).squeeze()
        cam = np.maximum(cam.numpy(), 0)
        # Guard against a flat map: (max - min) == 0 produced NaNs before.
        denom = cam.max() - cam.min()
        if denom > 0:
            cam = (cam - cam.min()) / denom
        else:
            cam = np.zeros_like(cam)
        return cam
class UNet_RN18(torch.nn.Module):
    """U-Net with an ImageNet-pretrained ResNet-18 encoder.

    Skip features are captured via forward hooks on conv1 (pre-BN/ReLU,
    64 ch) and layer1..layer4 (64/128/256/512 ch — widths implied by the
    decoder block input sizes below). NOTE(review): feature maps are
    collected into `self.feature_maps`, which forward() resets each call,
    so instances are not safe for concurrent forward passes.
    """
    def _conv_block(self, in_ch, out_ch):
        """Two 3x3 convs with SiLU activations (decoder fuse block)."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        )

    def _hook_fn(self, module, input, output):
        # Capture each hooked encoder stage's output, in registration order.
        self.feature_maps.append(output)

    def _register_encoder_hooks(self):
        """Hook every encoder stage in list_blocks; handles kept in self.hooks
        (they are never removed for the lifetime of the instance)."""
        for layer in self.list_blocks:
            self.hooks.append(layer.register_forward_hook(self._hook_fn))

    def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hooks = []
        self.feature_maps = []
        # Pretrained ResNet-18 encoder; the classification head is disabled
        # so encoder_base(x) just drives the hooks.
        self.encoder_base = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        self.encoder_base.fc = torch.nn.Identity()
        self.list_blocks = [
            self.encoder_base.conv1,
            self.encoder_base.layer1,
            self.encoder_base.layer2,
            self.encoder_base.layer3,
            self.encoder_base.layer4
        ]
        self._register_encoder_hooks()
        self.dropout = torch.nn.Dropout2d(0.4)
        # Decoder: each upConv halves channels and doubles resolution, then
        # the matching decoder block fuses (upsampled + skip) channels.
        self.upConv_0 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upConv_3 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.upConv_extra = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.decoder_0 = self._conv_block(512, 256) #0  256 (up) + 256 (layer3)
        self.decoder_1 = self._conv_block(256, 128) #1  128 (up) + 128 (layer2)
        self.decoder_2 = self._conv_block(128, 64) #2   64 (up) + 64 (layer1)
        self.decoder_3 = self._conv_block(32 + 64, 32) #3  32 (up) + 64 (conv1)
        self.decoder_extra = self._conv_block(16, 16)
        # 1x1 projection to out_channels; no final activation (logits).
        self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1)

    def forward(self, x):
        """Segment input x; returns per-pixel logits with out_channels channels."""
        self.feature_maps = []
        _ = self.encoder_base(x) # forward to encoder and call hooks
        encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck = self.feature_maps
        decoder_l3 = self.upConv_0(btlneck)
        decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1)
        decoder_l3 = self.decoder_0(decoder_l3)
        decoder_l3 = self.dropout(decoder_l3)
        decoder_l2 = self.upConv_1(decoder_l3)
        decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1)
        decoder_l2 = self.decoder_1(decoder_l2)
        decoder_l2 = self.dropout(decoder_l2)
        decoder_l1 = self.upConv_2(decoder_l2)
        decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1)
        decoder_l1 = self.decoder_2(decoder_l1)
        decoder_l1 = self.dropout(decoder_l1)
        decoder_l0 = self.upConv_3(decoder_l1)
        decoder_l0 = torch.cat((decoder_l0, encoder_l0), dim=1)
        decoder_l0 = self.decoder_3(decoder_l0)
        # Final skip-free stage back to full input resolution.
        decoder_last = self.upConv_extra(decoder_l0)
        decoder_last = self.decoder_extra(decoder_last)
        return self.last_conv(decoder_last)
class RARP_NVB_ROI_Mask_Unet(L.LightningModule):
    """Lightning module that trains UNet_RN18 for binary ROI mask
    segmentation with BCE-with-logits loss and IoU tracking."""
    def __init__(self,*args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = UNet_RN18(in_channels=3, out_channels=1)
        self.lr = 1E-4
        # Optional L1 penalty weight over decoder/upConv parameters (None = off).
        self.Lambda_L1 = None
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        self.train_IoU = torchmetrics.classification.BinaryJaccardIndex()
        self.val_IoU = torchmetrics.classification.BinaryJaccardIndex()

    def forward(self, data):
        data = data.float()
        pred = self.model(data)
        return pred

    def _shared_step(self, batch, val_step:bool = True):
        """Return (loss, target mask, sigmoid probabilities) for one batch.

        When val_step is False (training) the optional L1 penalty on the
        decoder/upsampling weights is added to the loss.
        """
        img, mask = batch
        mask = mask.float()
        mask = mask.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W) to match the model output
        prediction = self(img)
        loss = self.lossFN(prediction, mask)
        predicted_labels = torch.sigmoid(prediction)
        if not val_step:
            if self.Lambda_L1 is not None:
                loss_l1 = 0
                # Penalize only the decoder/upsampling weights, not the encoder.
                for name, params in self.model.named_parameters():
                    if "decoder" in name or "upConv" in name:
                        loss_l1 += torch.norm(params, p=1)
                loss += self.Lambda_L1 * loss_l1
        return loss, mask, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch, False)
        self.train_IoU.update(predicted_labels, true_labels)
        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc_IoU", self.train_IoU, on_epoch=True, on_step=False)
        return loss

    def on_after_backward(self):
        """Log the global L2 gradient norm and flag suspected vanishing gradients."""
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        self.log("grad_norm", total_norm)
        if total_norm < 1e-8:
            # Fix: self.log() only accepts numeric values/metrics — logging
            # the previous warning *string* raised at runtime. Log a flag.
            self.log("grad_warning", 1.0)

    def on_train_epoch_start(self):
        # Alternate freezing/unfreezing the encoder every other epoch.
        for parms in self.model.encoder_base.parameters():
            parms.requires_grad = (self.current_epoch % 2 == 0)

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.val_IoU.update(predicted_labels, true_labels)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_acc_IoU", self.val_IoU, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return [optimizer]
class RARP_NVB_ResNet50_CAM(L.LightningModule):
    """ResNet-50 with a single-logit head that also returns the last
    convolutional feature map (for CAM visualizations)."""
    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.resnet50()
        head_in = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=head_in, out_features=1)
        # Every layer except avgpool + fc, kept as a separate conv trunk.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        conv_features = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=conv_features, output_size=(1, 1))
        pooled = torch.flatten(pooled, 1)
        logits = self.model.fc(pooled)
        return logits, conv_features
class RARP_NVB_VAN_CAM(L.LightningModule): #TODO
    """VAN-B2 backbone with a single-logit head (work in progress).

    NOTE(review): forward() feeds the numpy CAM produced by
    GradCAM.generate_cam straight back into self.model, which takes an
    image tensor — this path looks unfinished (class is marked #TODO);
    confirm the intended input before using it.
    """
    def __init__(self) -> None:
        super().__init__()
        self.model = van.van_b2(pretrained = True)
        tempFC_ft = self.model.head.in_features
        self.model.head = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        # Backbone minus the last two children, kept for feature extraction.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data, label:torch.Tensor):
        # A new GradCAM (with fresh hooks) is registered on every call.
        cams = GradCAM(self.model, self.model.block4[-1])
        featureMap = cams.generate_cam(data, label)
        pred = self.model(featureMap)
        return pred, featureMap
class RARP_NVB_ResNet18_CAM(L.LightningModule):
    """ResNet-18 with a single-logit head that also returns the last
    convolutional feature map (for CAM visualizations)."""
    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.resnet18()
        tempFC_ft = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        # Every layer except avgpool + fc, kept as a separate conv trunk.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        featureMap = self.feature_map(data)
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1))
        # Fix: flatten from dim 1 to preserve the batch dimension.
        # torch.flatten(x) collapsed the batch too, breaking any batch > 1
        # (and diverging from the ResNet-50 variant above).
        Cont_Net = torch.flatten(Cont_Net, 1)
        pred = self.model.fc(Cont_Net)
        return pred, featureMap
class RARP_NVB_MobileNetV2_CAM(L.LightningModule):
    """MobileNetV2 with a single-logit head that also returns the last
    convolutional feature map (for CAM visualizations)."""
    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.mobilenet_v2()
        head_in = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=1)
        # Convolutional trunk, exposed separately from the classifier.
        self.feature_map = self.model.features

    def forward(self, data):
        conv_features = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(conv_features, (1, 1))
        pooled = torch.flatten(pooled, 1)
        logits = self.model.classifier(pooled)
        return logits, conv_features
class RARP_NVB_EfficientNetV2_CAM(L.LightningModule):
    """ImageNet-pretrained EfficientNetV2-S with a single-logit head that
    also returns the last convolutional feature map (for CAMs)."""
    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        head_in = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=1)
        # Convolutional trunk, exposed separately from the classifier.
        self.feature_map = self.model.features

    def forward(self, data):
        conv_features = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(conv_features, (1, 1))
        pooled = torch.flatten(pooled, 1)
        logits = self.model.classifier(pooled)
        return logits, conv_features
class RARP_NVB_Model_BCEWithLogitsLoss(L.LightningModule):
    """Base Lightning module for binary classifiers trained with
    BCE-with-logits; subclasses are expected to assign `self.model`."""
    def __init__(self, x=None) -> None:
        super().__init__()
        self.model = None  # provided by subclasses
        self.lossFN = torch.nn.BCEWithLogitsLoss() #pos_weight=torch.tensor([2.73])
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        return self.model(data.float())

    def _shared_step(self, batch):
        """Return (loss, float targets, sigmoid probabilities) for one batch."""
        images, targets = batch
        targets = targets.float()
        logits = self.forward(images).flatten()
        step_loss = self.lossFN(logits, targets)
        probabilities = torch.sigmoid(logits)
        return step_loss, targets, probabilities

    def training_step(self, batch, batch_idx):
        step_loss, targets, probabilities = self._shared_step(batch)
        self.log("train_loss", step_loss)
        self.train_acc.update(probabilities, targets)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return step_loss

    def validation_step(self, batch, batch_idx):
        step_loss, targets, probabilities = self._shared_step(batch)
        self.log("val_loss", step_loss)
        self.val_acc(probabilities, targets)
        self.f1Score(probabilities, targets)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)
        return step_loss

    def test_step(self, batch, batch_idx):
        step_loss, targets, probabilities = self._shared_step(batch)
        self.test_acc(probabilities, targets)
        self.f1ScoreTest(probabilities, targets)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        return [torch.optim.Adam(self.model.parameters(), lr=1e-4)]
class RARP_NVB_FOCAL_loss(torch.nn.Module):
    """Thin nn.Module wrapper around torchvision's sigmoid focal loss so it
    can be dropped in wherever a loss module is expected."""
    def __init__(self, alpha: float = 0.25, gamma: float = 2, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha          # weight of the positive class
        self.gamma = gamma          # focusing exponent on (1 - p_t)
        self.reduction = reduction  # 'none' | 'mean' | 'sum'

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torchvision.ops.focal_loss.sigmoid_focal_loss(
            input,
            target,
            self.alpha,
            self.gamma,
            self.reduction,
        )
class FocalLoss(torch.nn.Module):
    """Multi-class focal loss on raw logits.

    Per element: -alpha * (1 - p)^gamma * log(p), masked to the true class
    via one-hot targets. `reduction` ('mean'/'sum'/other) is applied over
    the whole (batch, num_classes) tensor.
    """
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        num_classes = inputs.size(1)
        one_hot = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()
        probs = torch.nn.functional.softmax(inputs, dim=1)
        log_probs = torch.nn.functional.log_softmax(inputs, dim=1)
        # Modulating factor down-weights well-classified examples.
        modulating = (1 - probs) ** self.gamma
        loss = -self.alpha * modulating * one_hot * log_probs
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
#TODO
class RARP_NVB_MultiClassModel(L.LightningModule):
    """Multi-class classifier wrapper trained with FocalLoss.

    The `val_loop` flag gates the optional L1 penalty: training_step sets
    it False and validation/test set it True, so the penalty is only
    added to the training loss.
    """
    def __init__(self,
            InitWeight = torch.tensor([1,1]),
            schedulerLR:bool=False,
            lr:float = 1e-4,
            Model:torch.nn.Module = None,
            Num_Classes:int = 2,
            L1:float = 1.31E-04,
            L2:float = 0
            ) -> None:
        """
        Args:
            InitWeight: stored but not used in this class's steps.
            schedulerLR: enable the cosine-annealing LR scheduler.
            lr: Adam learning rate.
            Model: the backbone to train (assigned to self.model).
            Num_Classes: number of output classes for loss/metrics.
            L1: L1 penalty weight (None disables it).
            L2: weight decay passed to Adam.
        """
        super().__init__()
        self.model = Model
        self.lossFN = FocalLoss() #torch.nn.CrossEntropyLoss()
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.num_classes = Num_Classes
        self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.test_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.f1ScoreTest = torchmetrics.F1Score("multiclass", num_classes=Num_Classes)
        # True while in a validation/test loop; disables the L1 penalty.
        self.val_loop = False

    def forward(self, data):
        data = data.float()
        pred = self.model(data)
        return pred

    def _shared_step(self, batch):
        """Return (loss, integer labels, softmax probabilities) for one batch."""
        img, label = batch
        prediction = self(img)
        predicted_labels = torch.softmax(prediction, dim=1)
        loss = self.lossFN(prediction, label)
        # L1 penalty over all model parameters, training loop only.
        if self.Lambda_L1 is not None and not self.val_loop:
            loss_l1 = 0
            for params in self.model.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        return loss, label, predicted_labels

    def training_step(self, batch, batch_idx):
        self.val_loop = False
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        self.val_loop = True
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        self.val_loop = True
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Adam (+L2 via weight_decay), optionally with cosine annealing (T_max=6)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1, verbose=True, factor=0.1)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return [optimizer]
class RARP_NVB_Model(L.LightningModule):
    """Base Lightning module for the binary classifiers below; subclasses
    assign `self.model` with a single-logit head."""
    def __init__(self,
            InitWeight = torch.tensor([1,1]),
            typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy,
            schedulerLR:bool=True,
            lr:float = 1e-4
            ) -> None:
        """
        Args:
            InitWeight: per-class weights indexed by label to build
                per-sample loss weights (None disables weighting).
            typeLossFN: CrossEntropy (with label smoothing 0.5) or
                BCEWithLogits for every other value.
            schedulerLR: enable the cosine-annealing LR scheduler.
            lr: Adam learning rate.
        """
        super().__init__()
        self.model = None  # provided by subclasses
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = None #1.31E-04 #
        self.Lambda_L2 = 0
        print (f"LR= {self.lr}, L1= {self.Lambda_L1}")
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        data = data.float()
        pred = self.model(data)
        return pred

    def _shared_step(self, batch):
        """Return (loss, float labels, sigmoid probabilities) for one batch."""
        img, label = batch
        if self.InitWeight is not None:
            # Per-sample weights gathered by integer label, assigned directly
            # onto the loss module before each step.
            # NOTE(review): assumes InitWeight is on the same device as label
            # — confirm for GPU training.
            self.lossFN.weight = self.InitWeight[label]
        label = label.float()
        # Model outputs shape (B, 1); keep the single-logit column.
        prediction = self(img)[:,0] #.flatten()
        loss = self.lossFN(prediction, label)
        # Optional L1 penalty over ALL parameters (disabled by default).
        if self.Lambda_L1 is not None:
            loss_l1 = 0
            for params in self.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        predicted_labels = torch.sigmoid(prediction)
        return loss, label, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        # "hp_metric" is picked up by TensorBoard's hyperparameter view.
        self.log("hp_metric", self.val_acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Adam (+L2 via weight_decay), optionally with cosine annealing (T_max=6)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1, verbose=True, factor=0.1)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return optimizer
class RARP_Ensemble(L.LightningModule):
    """Stacking ensemble over pre-trained base models.

    Each base model maps an image batch to a single score; the scores are
    concatenated and fed to a small trainable head ending in a Sigmoid,
    so forward() returns probabilities (not logits).
    """
    def __init__(self, ListModels,
            InitWeight = torch.tensor([1,1]),
            typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy,
            schedulerLR:bool=False,
            lr:float = 1e-4) -> None:
        super().__init__()
        # Fix: register the base models in a ModuleList so Lightning moves
        # them to the training device and includes them in checkpoints.
        # A plain Python list left them on the CPU and out of the state_dict.
        self.ListModels = torch.nn.ModuleList(ListModels)
        # Trainable stacking head over the len(ListModels) base scores.
        input_p = len(self.ListModels)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_p, out_features=128),
            torch.nn.SiLU(),
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid()
        )
        # Sigmoid output -> BCELoss (not BCEWithLogits) for the non-CE branch.
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCELoss(reduction="sum")
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        data = data.float()
        # One scalar per base model, concatenated along the feature dim.
        p = [m(data) for m in self.ListModels]
        p = torch.cat(p, dim=1)
        x = self.classifier(p)
        return x

    def _shared_step(self, batch):
        """Return (loss, float labels, predicted probabilities) for one batch."""
        img, label = batch
        if self.InitWeight is not None:
            # Per-sample loss weights gathered by integer label.
            # NOTE(review): assumes InitWeight is on the same device as label.
            self.lossFN.weight = self.InitWeight[label]
        label = label.float()
        prediction = self(img)[:,0] #.flatten()
        loss = self.lossFN(prediction, label)
        # The head already ends in Sigmoid, so the output is a probability.
        predicted_labels = prediction
        return loss, label, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.f1Score.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        # Optimize only the stacking head: before the ModuleList fix the base
        # models were invisible to self.parameters(), so restricting to
        # classifier.parameters() preserves the original training behavior
        # (base models stay frozen).
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=self.lr)
        if self.scheduler:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, 1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                }
            }
        else:
            return [optimizer]
class RARP_NVB_ResNet18(RARP_NVB_Model):
    """ImageNet-pretrained ResNet-18 with a single-logit head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=1)
        self.model = backbone
class RARP_NVB_DaVit(RARP_NVB_Model):
    """Pretrained timm DaViT-small backbone with a single-logit head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = timm.create_model(
            "davit_small.msft_in1k",
            pretrained=True,
            num_classes=1,
        )
class RARP_NVB_EfficientNetV2_Deep(RARP_NVB_Model):
    """Pretrained EfficientNetV2-S with a deeper 128 -> 8 -> 1 head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        head_in = self.model.classifier[1].in_features
        # Replace the stock classifier output with a 128-wide layer, then
        # extend the head down to a single logit.
        self.model.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=128)
        for extra_layer in (
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        ):
            self.model.classifier.append(extra_layer)
class RARP_NVB_ResNet50_Deep(RARP_NVB_Model):
    """Pretrained ResNet-50 whose fc is replaced by a dropout + 128 -> 8 -> 1
    MLP head producing a single logit."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        head_in = self.model.fc.in_features
        head_layers = [
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=head_in, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        ]
        self.model.fc = torch.nn.Sequential(*head_layers)

    def forward(self, img):
        """Run the backbone on a float-cast image batch."""
        return self.model(img.float())
class RARP_NVB_ResNet50_Deep_OPTuna(RARP_NVB_Model):
    """Pretrained ResNet-50 whose MLP head layer widths and dropout are
    hyperparameters (intended for Optuna searches).

    NOTE: the `dropuot` parameter name is a typo of `dropout`, kept
    because it is part of the public interface (existing callers/trials).
    """
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, output_dims=None, dropuot=0.2, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        # Fix: the previous mutable default argument ([128, 8]) was shared
        # across calls; build the fallback per instance instead.
        if output_dims is None:
            output_dims = [128, 8]
        layers = []
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        inputDim = self.model.fc.in_features
        # One Linear + SiLU + Dropout stage per requested hidden width.
        for outputDim in output_dims:
            layers.append(torch.nn.Linear(inputDim, outputDim))
            layers.append(torch.nn.SiLU(True))
            layers.append(torch.nn.Dropout(dropuot))
            inputDim = outputDim
        # Single-logit output layer.
        layers.append(torch.nn.Linear(inputDim, 1))
        self.model.fc = torch.nn.Sequential(*layers)

    def forward(self, img):
        """Run the backbone on a float-cast image batch."""
        img = img.float()
        pred = self.model(img)
        return pred
class RARP_NVB_ResNet50_V2(RARP_NVB_Model):
    """ResNet-50 that fuses 8 image features with extra tabular inputs.

    The backbone's fc projects to 8 features; these are concatenated with
    ``InputNeurons`` external values and classified by a final linear layer.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        backbone_features = self.model.fc.in_features
        self.model.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=backbone_features, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8)
        )
        # Fusion layer: 8 image features + the external feature vector.
        self.model.fc2 = torch.nn.Linear(8 + InputNeurons, 1)

    def forward(self, data):
        """`data` is (image_batch, extra_features); returns one raw logit each."""
        img, extra = data
        image_feats = torch.nn.functional.silu(self.model(img.float()), True)
        fused = torch.concat((image_feats, extra), dim=1)
        return self.model.fc2(fused)
class RARP_NVB_ResNet50_V3(RARP_NVB_Model):
    """ResNet-50 with parallel 128-d projections of image and extra features.

    Image features and the external vector are each projected to 128 dims,
    SiLU-activated, concatenated (256 dims) and classified.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        self.model.fc = torch.nn.Linear(in_features=self.model.fc.in_features, out_features=128)
        self.model.extraFC = torch.nn.Linear(InputNeurons, 128)
        self.model.fc2 = torch.nn.Linear(256, 1)

    def forward(self, data):
        """`data` is (image_batch, extra_features); returns one raw logit each."""
        img, extra = data
        image_branch = torch.nn.functional.silu(self.model(img.float()), True)
        extra_branch = torch.nn.functional.silu(self.model.extraFC(extra), True)
        fused = torch.concat((image_branch, extra_branch), dim=1)
        return self.model.fc2(fused)
class RARP_NVB_ResNet50_V1(RARP_NVB_Model):
    """ResNet-50 that appends extra features right before the final fc.

    ``self.Decoder`` (kept under its historical name) is actually the
    convolutional backbone: all ResNet children except avgpool and fc.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        fused_features = self.model.fc.in_features + InputNeurons
        self.model.fc = torch.nn.Linear(in_features=fused_features, out_features=1)
        self.Decoder = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        """`data` is (image_batch, extra_features); returns one raw logit each."""
        img, extra = data
        feature_map = self.Decoder(img.float())
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=feature_map, output_size=(1, 1))
        pooled = torch.flatten(pooled, 1)
        fused = torch.concat((pooled, extra), dim=1)
        return self.model.fc(fused)
class RARP_NVB_VAN(RARP_NVB_Model):
    """Pretrained VAN-B2 backbone with a single-logit classification head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = van.van_b2(pretrained = True)
        # Swap the multi-class head for a single binary logit.
        self.model.head = torch.nn.Linear(in_features=self.model.head.in_features, out_features=1)
class RARP_NVB_ResNet50(RARP_NVB_Model):
    """Pretrained ResNet-50 with a single-logit classification head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # Swap the 1000-class fc for a single binary logit.
        self.model.fc = torch.nn.Linear(in_features=self.model.fc.in_features, out_features=1)
class RARP_NVB_MLP(torch.nn.Module):
    """DINO-style projection head: MLP -> L2-normalize -> weight-normed linear.

    The final layer's ``weight_g`` is set to 1 and, when ``norm_last_layer``
    is True, frozen — as in the original DINO recipe.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=512, bottleneck=256, n_layers=3, norm_last_layer=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.activationFN = torch.nn.GELU()
        if n_layers == 1:
            # Degenerate case: a single projection to the bottleneck.
            self.mlp = torch.nn.Linear(in_dim, bottleneck)
        else:
            stack = [torch.nn.Linear(in_dim, hidden_dim)]
            #stack.append(torch.nn.BatchNorm1d(hidden_dim)) # Add
            stack += [self.activationFN, torch.nn.Dropout(0.30)]
            for _ in range(n_layers - 2):
                stack += [torch.nn.Linear(hidden_dim, hidden_dim), self.activationFN]
            stack.append(torch.nn.Linear(hidden_dim, bottleneck))
            self.mlp = torch.nn.Sequential(*stack)
        self.apply(self._init_weights)
        # Weight-normalized output layer, created after apply() so its
        # init (weight_g = 1) is not overwritten by _init_weights.
        self.last_layer = torch.nn.utils.weight_norm(
            torch.nn.Linear(bottleneck, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        # Normal(0, 0.02) init for linear weights, zero biases.
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        projected = self.mlp(x)
        projected = torch.nn.functional.normalize(projected, dim=-1, p=2)
        return self.last_layer(projected)
class RARP_NVB_DINO_Wrapper(torch.nn.Module):
    """Pairs a backbone with a projection head and handles multi-crop input.

    A list input is treated as N crops: they are concatenated into one
    batch, pushed through backbone + head once, and the logits are chunked
    back into a tuple of N per-crop tensors. A plain tensor yields a
    1-element tuple.
    """

    def __init__(self, backbone:torch.nn.Module, new_head:torch.nn.Module, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.backbone = backbone
        self.head = new_head

    def forward(self, x):
        is_multicrop = isinstance(x, list)
        n_crops = len(x) if is_multicrop else 1
        batch = torch.cat(x, dim=0) if is_multicrop else x
        logits = self.head(self.backbone(batch))
        return logits.chunk(n_crops)
class RARP_NVB_DINO_Loss(torch.nn.Module):
    """DINO cross-entropy between teacher and student multi-crop outputs.

    Teacher logits are centered (running EMA ``center``) and sharpened with
    a low temperature; student logits use a higher temperature. The loss
    averages -p_teacher * log p_student over every (teacher, student) view
    pair, skipping identical views when more than one teacher view exists.
    """

    def __init__(self, out_dim:int, teacher_Thao:float = 0.04, student_Thao:float = 0.1, center_momentum:float = 0.9, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.S_Thao = student_Thao
        self.T_Thao = teacher_Thao
        self.C_Momentum = center_momentum
        # Running estimate of the teacher output mean, used for centering.
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, s_Output, t_Output):
        student_logp = [
            torch.nn.functional.log_softmax(s / self.S_Thao, dim=-1) for s in s_Output
        ]
        teacher_prob = [
            torch.nn.functional.softmax((t - self.center) / self.T_Thao, dim=-1).detach()
            for t in t_Output
        ]
        total_loss = 0
        n_loss_terms = 0
        for t_ix, t_p in enumerate(teacher_prob):
            for s_ix, s_lp in enumerate(student_logp):
                # With multiple teacher views, skip the matching student view.
                if (t_ix == s_ix) and (len(teacher_prob) > 1):
                    continue
                total_loss += torch.sum(-t_p * s_lp, dim=-1).mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(t_Output)
        return total_loss

    @torch.no_grad()
    def update_center(self, t_output):
        """EMA update of the teacher-output center from this batch's mean."""
        batch_mean = torch.cat(t_output).mean(dim=0, keepdim=True)
        self.center = self.center * self.C_Momentum + batch_mean * (1 - self.C_Momentum)
class RARP_NVB_DINO_RestNet50_Deep(L.LightningModule):
    """Semi-supervised binary classifier: DINO self-distillation + pseudo-label KD.

    Three networks cooperate:
      * ``teacher_Labels``   - a frozen ``RARP_NVB_ResNet50_Deep`` (optionally
        loaded from a checkpoint) that produces pseudo-labels by thresholding
        its sigmoid output at ``threshold``.
      * ``teacher_Features`` / ``student`` - a DINO pair (ResNet-50 backbone
        + ``RARP_NVB_MLP`` head); the teacher is never trained directly and
        is updated only by EMA of the student (see ``on_train_batch_end``).
      * ``clasiffier``       - a small MLP head on the student embeddings
        producing the final binary logit.
    """
    def __init__(
        self,
        PseudoEstimator: str = None,   # checkpoint path for the label teacher
        threshold: float = 0.5,        # sigmoid cut-off for pseudo-labels
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,  # EMA momentum for the DINO teacher
        lr:float = 1e-4,
        L1:float = None,               # optional L1 penalty on student weights
        L2:float = 0,                  # weight decay handed to AdamW
    ) -> None:
        super().__init__()
        self.lr = lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.threshold = threshold
        self.momentum_teacher = momentum_teacher
        self.out_dim = 512   # projection-head output dimension
        self.in_dim = 2048   # ResNet-50 penultimate feature size
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        # Frozen pseudo-label teacher; strict=False tolerates head mismatches.
        self.teacher_Labels = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False) if PseudoEstimator is not None else RARP_NVB_ResNet50_Deep()
        self.student = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT) #torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        self.teacher_Features = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # Strip the classification fc so the backbones emit raw embeddings.
        self.student.fc = torch.nn.Identity()
        self.teacher_Features.fc = torch.nn.Identity()
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        # Both teachers are frozen for the optimizer; the feature teacher is
        # updated only via EMA in on_train_batch_end.
        for parms in self.teacher_Labels.model.parameters():
            parms.requires_grad = False
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, 0.04, 0.1, momentum_teacher)
        self.lossFN_KD = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        #self.lossFH_KLDiv = torch.nn.KLDivLoss(reduction="batchmean")
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(self.out_dim, 128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1)
        )
    def forward(self, data, val_step:bool = True):
        """Return ((logits, pseudo_labels, teacher_logits), (teacher_feats, student_feats)).

        Validation: ``data`` is one tensor used for every role. Training:
        ``data`` is a list of crops — crop 0 feeds the label teacher, crops
        1-2 (presumably the global views — TODO confirm against the
        dataloader) feed the DINO teacher, and all crops feed the student.
        """
        if val_step:
            data = data.float()
            dataClassificator, dataTeacher, dataStudent = data, data, data
        else:
            data = [d.float() for d in data]
            dataClassificator, dataTeacher, dataStudent = data[0], data[1:3], data
        TeacherDino = self.teacher_Features(dataTeacher)
        TeacherLabels = self.teacher_Labels(dataClassificator)
        Student = self.student(dataStudent)
        # All of the student's outputs are evaluated here.
        #if isinstance(dataStudent, list):
        #    #index = np.random.randint(0, len(dataStudent))
        #    temp = self.student(dataStudent)
        #    CatS_Classifier = torch.cat(temp, dim=0)
        #    meanS_Classifier = torch.zeros(self.in_dim)
        #    for dataS in temp:
        #        meanS_Classifier += dataS
        #    #S_Classifier = self.student(dataStudent[index])
        #    S_Classifier = meanS_Classifier / len(dataStudent)
        #else:
        #    S_Classifier = Student
        Cont_Net = torch.cat(Student, dim=0)
        pred = self.clasiffier(Cont_Net)
        if not val_step:
            # Re-run the label teacher once per student crop so the pseudo-label
            # count matches the concatenated student predictions.
            TeacherLabels = [self.teacher_Labels(dataClassificator) for _ in range(len(dataStudent))]
            TeacherLabels = torch.cat(TeacherLabels, dim=0)
        TeacherLabelsPred = torch.sigmoid(TeacherLabels.flatten())
        PseudoLabels = (TeacherLabelsPred > self.threshold) * 1.0
        return (pred.flatten(), PseudoLabels, TeacherLabels.flatten()), (TeacherDino, Student)
    def _shared_step(self, batch, val_step:bool = False):
        """Shared train/val/test logic; returns (loss, labels, sigmoid_preds)."""
        img, label = batch
        if not val_step:
            # Training replicates the ground-truth labels once per crop.
            label = torch.cat([label for _ in range(len(img))], dim=0)
        label = label.float()
        KD_Prediction, DINO_Loss = self(img, val_step)
        TeacherF, StudentF = DINO_Loss
        prediction, PseudoLabels, teacherOutputs = KD_Prediction
        predicted_labels = torch.sigmoid(prediction)
        ##verstion 1
        # Weighted mix of pseudo-label loss (alpha) and ground-truth loss (beta).
        W_Alpha, W_Beta = (1, 0.5)#(1, 0.5)
        loss = W_Alpha * self.lossFN_KD(prediction, PseudoLabels) + W_Beta * self.lossFN_KD(prediction, label)
        #version 2
        #thao_KD = 1#5.0
        #W_Alpha, W_Beta = (0.6, 0.4)
        #softTeacher = torch.sigmoid(teacherOutputs/thao_KD)
        #softStudent = torch.sigmoid(prediction/thao_KD)
        #loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
        #loss_hl = self.lossFN_KD(prediction, label)
        #loss = W_Alpha * loss_hl + W_Beta * loss_sl
        #loss = W_Alpha * self.lossFN_KD(prediction, label) + W_Beta * (self.lossFH_KLDiv(softStudent, softTeacher) * (thao_KD ** 2))
        # DINO term only applies during training (multi-crop input required).
        loss += (self.lossFN_DINO(StudentF, TeacherF) if not val_step else 0)
        if not val_step:
            self.logger.experiment.add_histogram ("Teacher", TeacherF[0])
            self.logger.experiment.add_histogram ("Student", StudentF[1])
            if self.Lambda_L1 is not None:
                # Optional L1 regularization over all student parameters.
                loss_l1 = 0
                for params in self.student.parameters(): # aqui
                    loss_l1 += torch.sum(torch.abs(params))
                loss += self.Lambda_L1 * loss_l1
        #return loss, PseudoLabels, predicted_labels
        return loss, label, predicted_labels
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch, False)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # EMA update of the feature teacher: t <- m*t + (1-m)*s.
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
        self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center)
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels = self._shared_step(batch, True)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def configure_optimizers(self):
        # Only the student is optimized; teachers are frozen / EMA-updated.
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        return [optimizer]
class RARP_NVB_DINO_VAN(RARP_NVB_DINO_RestNet50_Deep):
    """DINO KD model using VAN-B2 for both student and feature teacher.

    The parent constructor first builds ResNet-50 student/teacher; they are
    replaced here with VAN-B2 wrappers and ``in_dim`` drops to 512.
    """

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher: float = 0.9995,
        lr: float = 0.0001,
        L1: float = None,
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)
        self.in_dim = 512
        # num_classes = -1 makes the VAN backbone emit raw embeddings.
        student_backbone = van.van_b2(pretrained = True, num_classes = -1)
        teacher_backbone = van.van_b2(pretrained = True, num_classes = -1)
        self.student = RARP_NVB_DINO_Wrapper(
            student_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        # The feature teacher is only ever updated by EMA, never the optimizer.
        for frozen_param in self.teacher_Features.parameters():
            frozen_param.requires_grad = False
class RARP_NVB_DINO_ViT(RARP_NVB_DINO_RestNet50_Deep):
    """DINO KD model using ViT-B/16 for both student and feature teacher.

    The parent constructor first builds ResNet-50 student/teacher; they are
    replaced here with ViT wrappers and ``in_dim`` becomes 768.
    """

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher: float = 0.9995,
        lr: float = 0.0001,
        L1: float = None,
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)
        self.in_dim = 768
        student_backbone = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        teacher_backbone = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        # Drop the classification heads so the backbones emit raw embeddings.
        student_backbone.heads = torch.nn.Identity()
        teacher_backbone.heads = torch.nn.Identity()
        self.student = RARP_NVB_DINO_Wrapper(
            student_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        # The feature teacher is only ever updated by EMA, never the optimizer.
        for frozen_param in self.teacher_Features.parameters():
            frozen_param.requires_grad = False
class RARP_MAE(L.LightningModule):
def __init__(
self,
backbone:str,
mask_ratio:float = 0.75,
patch_size: int = 16,
#img_size: int = 224,
lr: float = 1e-4,
hiden_channels = [512, 256, 128, 64, 32],
img_mean_std = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
activation_fn = torch.nn.ReLU(True)
):
super().__init__()
self.save_hyperparameters(ignore=["activation_fn", "img_mean_std"])
if img_mean_std is not None:
mean, std = img_mean_std
self.register_buffer("mean_IMG", torch.tensor(mean).view(3, 1, 1))
self.register_buffer("std_IMG", torch.tensor(std).view(3, 1, 1))
else:
self.mean_IMG = None
self.std_IMG = None
self.loss_fn = torch.nn.MSELoss()
match backbone:
case "resnet":
model = torchvision.models.resnet18()
model.fc = torch.nn.Identity()
self.encoder = model
self.encoder_out_dim = 512
case "van":
model = van.van_b2()
model.head = torch.nn.Identity()
self.encoder = model
self.encoder_out_dim = 512
case "levit":
self.encoder = timm.create_model("levit_192.fb_dist_in1k", pretrained=False, num_classes=0)
self.encoder_out_dim = 384
hiden_channels = [384, 256, 128, 64, 32]
in_channel = self.encoder_out_dim
layers = [
torch.nn.Linear(in_channel, hiden_channels[0]*7*7),
torch.nn.Unflatten(1, (hiden_channels[0], 7, 7)),
]
for out_channel in hiden_channels[1:]:
layers.append(torch.nn.ConvTranspose2d(in_channel, out_channel, kernel_size=3, stride=2, padding=1, output_padding=1))
layers.append(torch.nn.BatchNorm2d(out_channel))
layers.append(activation_fn)
in_channel = out_channel
#Last layer
layers.append(torch.nn.ConvTranspose2d(in_channel, 3, kernel_size=3, stride=2, padding=1, output_padding=1))
self.decoder = torch.nn.Sequential(*layers)
def _denormalize(self, tensor:torch.Tensor):
return tensor * self.std_IMG + self.mean_IMG
def _mask_patches(self, imgs):
B, C, H, W = imgs.shape
ph, pw = self.hparams.patch_size, self.hparams.patch_size
gh, gw = H // ph, W // pw
mask:torch.Tensor = (torch.rand(B, gh * gw, device=imgs.device) >= self.hparams.mask_ratio)
mask = mask.reshape(B, 1, gh, gw) # [B,1,gh,gw]
# expand mask to full image
mask = mask.repeat_interleave(ph, dim=2) # [B,1,gh*ph,gw]
mask = mask.repeat_interleave(pw, dim=3) # [B,1,gh*ph,gw*pw] == [B,1,H,W]
mask = mask.expand(-1, C, -1, -1)
imgs_masked = imgs * mask
return imgs_masked
def forward(self, data, val_step=False):
if not val_step:
data = [d.float() for d in data] #[0] original Img, [1] augmented Img
else:
data = [data.float(), data.float()]
imgs_masked = self._mask_patches(data[1])
latent = self.encoder(imgs_masked) # [B, 512]
reconstructed_image = self.decoder(latent) # [B, 3, 224, 224]
return reconstructed_image, data[0]
def _shared_step(self, batch, val_step=False):
img, _ = batch
res, true_img = self(img, val_step)
loss_MSE = self.loss_fn(res, true_img)
return loss_MSE, res, true_img
def training_step(self, batch, batch_idx):
loss, img, _ = self._shared_step(batch, False)
self.log("train_loss", loss, on_epoch=True, sync_dist=True)
if batch_idx % 500 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
imgReconstruction = torch.clip(self._denormalize(img), 0, 1)
grid = torchvision.utils.make_grid(imgReconstruction)
self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
return loss
def validation_step(self, batch, batch_idx):
loss, img, target_img = self._shared_step(batch, True)
self.log("val_loss", loss, on_epoch=True, prog_bar=True)
if self.mean_IMG is not None and self.std_IMG is not None:
img = torch.clip(self._denormalize(img), 0, 1)
target_img = torch.clip(self._denormalize(target_img), 0, 1)
ssim = piq.ssim(img, target_img, data_range=1.0)
psnr = piq.psnr(img, target_img, data_range=1.0)
self.log("val_ssim", ssim, on_epoch=True)
self.log("val_psnr", psnr, on_epoch=True)
if batch_idx % 100 == 0:
grid = torchvision.utils.make_grid(img)
self.logger.experiment.add_image('val_reconstructed_images', grid, self.global_step)
def on_after_backward(self):
total_norm = 0.0
for p in self.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
self.log("grad_norm", total_norm)
if total_norm < 1e-8:
self.log("grad_warning", "Vanishing gradient suspected!")
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
return [optimizer]
class RARP_Encoder_DINO(L.LightningModule):
    """Self-supervised DINO pretraining of a ConvNeXt-Small encoder.

    Student and teacher share the architecture (backbone + RARP_NVB_MLP
    projection head); the teacher starts as an exact copy of the student
    and is then updated only by EMA (``on_train_batch_end``). Validation
    quality is tracked by k-means silhouette scores on the embeddings.
    """
    def __init__(self,
        momentum_teacher:float = 0.9995,  # EMA momentum for the teacher
        lr:float = 1e-4,
        Teacher_T:float = 0.04,           # teacher softmax temperature
        Student_T:float = 0.1,            # student softmax temperature
        max_epochs:int = 100,
        total_steps:int = None
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.momentum_teacher = momentum_teacher
        self.out_dim = 65536   # DINO prototype count
        self.in_dim = 768 #512  # ConvNeXt-Small embedding size
        # Embedding accumulators for the validation silhouette metric.
        self.embs_t = []
        self.embs_s = []
        #self.student = van.van_b2(num_classes = 0)
        #self.teacher = van.van_b2(num_classes = 0)
        #weights=torchvision.models.ConvNeXt_Small_Weights.DEFAULT
        self.student = torchvision.models.convnext_small()
        self.student.classifier[-1] = torch.nn.Identity()
        self.teacher = torchvision.models.convnext_small()
        self.teacher.classifier[-1] = torch.nn.Identity()
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True)
        )
        self.teacher = RARP_NVB_DINO_Wrapper(
            self.teacher,
            RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True)
        )
        # Teacher starts identical to the student and is frozen; only EMA
        # updates it from then on.
        self.teacher.load_state_dict(self.student.state_dict())
        for parms in self.teacher.parameters():
            parms.requires_grad = False
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)
    def forward(self, data):
        """`data` is the multi-crop list; teacher sees crops 0-2 (presumably
        the global views — TODO confirm against the dataloader), student all."""
        data = [d.float() for d in data]
        dataTeacher, dataStudent = data[0:3], data
        teacher_features = self.teacher(dataTeacher)
        student_features = self.student(dataStudent)
        return teacher_features, student_features
    def _shared_step(self, batch):
        # Labels are ignored — purely self-supervised.
        img, _ = batch
        t, s = self(img)
        loss_Dino = self.lossFN_DINO(s, t)
        return loss_Dino
    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        return loss
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Cosine momentum schedule kept for reference (note: has a typo,
        # self.hparamstotal_steps — fix before re-enabling).
        #step = self.global_step
        #
        #m = 1.0 - (1.0 - self.hparams.momentum_teacher) * (
        #    (1 + math.cos(math.pi * step / self.hparamstotal_steps)) / 2
        #)
        # EMA update of the teacher: t <- m*t + (1-m)*s.
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
    def validation_step(self, batch, batch_idx):
        # Collect first-view embeddings from both networks for clustering.
        imgs, _ = batch
        embs = self.teacher(imgs.float())[0]
        self.embs_t.append(embs)
        embs = self.student(imgs.float())[0]
        self.embs_s.append(embs)
    def on_validation_epoch_end(self):
        # Silhouette score of k-means clusters as a proxy for embedding quality.
        emds = torch.cat(self.embs_t, dim=0).cpu().numpy()
        self.embs_t.clear()
        #silhouette
        kmeans = KMeans(n_clusters=10, random_state=505).fit(emds)
        sil = silhouette_score(emds, kmeans.labels_)
        self.log("val_silhouette_teacher", sil)
        emds = torch.cat(self.embs_s, dim=0).cpu().numpy()
        self.embs_s.clear()
        #silhouette
        kmeans = KMeans(n_clusters=10, random_state=505).fit(emds)
        sil = silhouette_score(emds, kmeans.labels_)
        self.log("val_silhouette_student", sil)
    def on_after_backward(self):
        # NOTE(review): divides by len(norms) — raises ZeroDivisionError if no
        # parameter has a gradient; confirm this can't happen in practice.
        norms = [p.grad.data.norm(2).item() for p in self.parameters() if p.grad is not None]
        avg_layer_norm = sum(norms) / len(norms)
        self.log("avg_grad_norm", avg_layer_norm)
    def configure_optimizers(self):
        # Only the student is optimized; the teacher follows via EMA.
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr)
        return [optimizer]
class RARP_NVB_Classification_Head(torch.nn.Module):
    """Configurable MLP classification head.

    With an empty ``layer`` list the head is a single Linear; otherwise each
    entry adds Linear -> activation -> Dropout(0.4), except that the dropout
    of the last hidden block is reduced to 0.2, followed by the output Linear.

    Args:
        in_features: input feature size.
        out_features: number of output logits.
        layer: hidden layer sizes; default is no hidden layers.
        activation_fn: activation module shared by all hidden layers
            (default: a fresh ReLU per instance).
    """
    def __init__(self, in_features:int, out_features:int, layer:list=None, activation_fn:torch.nn.Module = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # FIX: the original used a mutable default ([]) and a module instance
        # shared across *all* head instances as defaults; use None sentinels.
        layer = [] if layer is None else layer
        self.activation = torch.nn.ReLU() if activation_fn is None else activation_fn
        if len(layer) == 0:
            self.head = torch.nn.Linear(in_features, out_features)
        else:
            blocks = []
            next_input = in_features
            for num in layer:
                blocks.append(torch.nn.Linear(next_input, num))
                blocks.append(self.activation)
                blocks.append(torch.nn.Dropout(0.4))
                next_input = num
            # Lighter dropout right before the output projection.
            blocks[-1] = torch.nn.Dropout(0.2)
            blocks.append(torch.nn.Linear(next_input, out_features))
            self.head = torch.nn.Sequential(*blocks)
    def forward(self, x):
        return self.head(x)
class RARP_Encoder_DINO_AUX_task(RARP_Encoder_DINO):
    """DINO encoder with an auxiliary rotation-prediction task.

    On top of the parent's DINO objective, inputs are rotated by one of
    ``self.angles`` and a small head on the student backbone must predict
    the rotation class. The auxiliary cross-entropy is scaled by
    ``aux_lambda``.
    """
    def __init__(
        self,
        momentum_teacher = 0.9995,
        lr = 0.0001,
        Teacher_T = 0.04,
        Student_T = 0.1,
        max_epochs = 100,
        total_steps = None,
        aux_lambda = 0.3
    ):
        super().__init__(momentum_teacher, lr, Teacher_T, Student_T, max_epochs, total_steps)
        self.angles = [-90, 0, 90, 180]
        # LabelEncoder sorts ascending, so class ids 0..3 line up with
        # self.angles' existing order.
        self.angles_labels = torch.from_numpy(LabelEncoder().fit_transform(self.angles))
        self.classifier = RARP_NVB_Classification_Head(self.in_dim, len(self.angles), layer=[128], activation_fn=torch.nn.GELU())
        self.lossFN_Aux = torch.nn.CrossEntropyLoss()
        self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=len(self.angles))
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=len(self.angles))
    def _rotation_labels_generator(self, batch_size:int, angles_batch=None, rand:bool=False) -> torch.Tensor:
        """Return one rotation-class id per sample (random or round-robin)."""
        assert len(angles_batch) > 0, "Empty list, angles list shuld have more than 2 values"
        if rand:
            return torch.tensor([angles_batch[np.random.randint(len(angles_batch))] for _ in range(batch_size)])
        else:
            return torch.tensor([angles_batch[i % len(angles_batch)] for i in range(batch_size)])
    def _rotate_img(self, imgs:torch.Tensor, angles_batch=[]) -> torch.Tensor:
        """Rotate each image by the angle its class id indexes in self.angles."""
        assert len(angles_batch) > 0, "Empty list, angles list shuld have more than 2 values"
        list_imgs = []
        for i, angle_idx in enumerate(angles_batch):
            list_imgs.append(torchvision.transforms.functional.rotate(imgs[i, ...], self.angles[angle_idx]))
        return torch.stack(list_imgs, dim=0)
    def forward(self, data, rot_data, val_step:bool=False):
        """Return (teacher_features, student_features, rotation_logits)."""
        if not val_step:
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[0:3], data
        else:
            data = data.float()
            dataTeacher, dataStudent = data, data
        dataAux = rot_data.float()
        teacher_features = self.teacher(dataTeacher)
        student_features = self.student(dataStudent)
        # Rotation head runs on the raw backbone embedding (no DINO head).
        aux_rot = self.student.backbone(dataAux)
        aux_rot = self.classifier(aux_rot)
        return teacher_features, student_features, aux_rot
    def _shared_step(self, batch, val_step:bool=False):
        img, _ = batch
        batch_size = img.size(0) if val_step else img[0].size(0)
        current_device = img.device if val_step else img[0].device
        labels = self._rotation_labels_generator(batch_size, self.angles_labels, not val_step).to(current_device)
        rot_img = self._rotate_img(img if val_step else img[0], labels).to(current_device)
        t, s, aux = self(img, rot_img, val_step)
        if val_step:
            # Collect first-view embeddings for the silhouette metric.
            self.embs_t.append(t[0])
            self.embs_s.append(s[0])
        # FIX: CrossEntropyLoss expects raw logits; the original applied
        # softmax first (double-softmax), which flattens gradients. The loss
        # now uses logits, while softmax probabilities are still returned for
        # the accuracy metrics (argmax is unchanged).
        loss_aux = self.lossFN_Aux(aux, labels)
        loss_aux = self.hparams.aux_lambda * loss_aux
        aux = torch.softmax(aux, dim=1)
        loss_Dino = self.lossFN_DINO(s, t) if not val_step else torch.tensor(0, device=current_device, dtype=torch.float32)
        loss = loss_Dino + loss_aux
        return loss, labels, aux, (loss_Dino, loss_aux)
    def training_step(self, batch, batch_idx):
        loss, true_labels, pred_labels, losses = self._shared_step(batch, False)
        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        self.train_acc.update(pred_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_AUX", losses[1], on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        loss, true_labels, pred_labels, losses = self._shared_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(pred_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_AUX", losses[1], on_epoch=True, on_step=False)
    def configure_optimizers(self):
        # Student and rotation head train; the teacher follows via EMA.
        optimizer = torch.optim.AdamW([
            {"params": self.student.parameters(), "lr": self.lr},
            {"params": self.classifier.parameters(), "lr": self.lr}
        ])
        return [optimizer]
class RARP_NVB_DINO_MultiTask(L.LightningModule):
# Define a hook function to capture the output
def _hook_fn_Student(self, module, input, output):
    # Forward hook: cache the student backbone's last conv feature map so
    # forward() can feed it to the reconstruction decoder.
    self.last_conv_output_S = output
def _hook_fn_Teacher(self, module, input, output):
    # Forward hook: cache the teacher backbone's last conv feature map so
    # forward() can feed it to the reconstruction decoder.
    self.last_conv_output_T = output
def __init__(
    self,
    TypeLoss=TypeLossFunction.CrossEntropy,
    momentum_teacher:float = 0.9995,   # EMA momentum for the feature teacher
    lr:float = 1e-4,
    L1:float = None,                   # optional L1 penalty on the classifier
    L2:float = 0,                      # optional L2 penalty on the classifier
    std: float = None,                 # per-channel std for de-normalization
    mean: float = None,                # per-channel mean for de-normalization
    SoftAdptAlgo:int = 0,              # 1 = NormalizedSoftAdapt, else LossWeighted
    SoftAdptBeta:float = 0.1,
    Teacher_T:float = 0.04,            # DINO teacher temperature
    Student_T:float = 0.1,             # DINO student temperature
    intermittent:bool = False          # alternate backbone vs heads training
) -> None:
    """Multi-task DINO model: self-distillation + image reconstruction +
    binary classification, with SoftAdapt-balanced loss weights."""
    super().__init__()
    self.intermittent_train = intermittent
    self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
    self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
    self.lr = lr
    self.Lambda_L1 = L1
    self.Lambda_L2 = L2
    self.Teacher_t = Teacher_T
    self.Studet_T = Student_T
    self.momentum_teacher = momentum_teacher
    self.out_dim = 1024   # DINO projection-head output size
    self.in_dim = 512     # VAN-B2 embedding size
    # Per-task loss weights [DINO, reconstruction, binary]; refreshed by
    # SoftAdapt every other epoch (see _calc_weights).
    self.weights = torch.tensor([1,1,1])
    self.softAdapt = NormalizedSoftAdapt(SoftAdptBeta) if SoftAdptAlgo == 1 else LossWeightedSoftAdapt(SoftAdptBeta)
    self.loss_history = {
        'loss_DINO': [],
        'loss_Reconstruction': [],
        'loss_Binary': [],
    }
    self.train_acc = torchmetrics.Accuracy('binary')
    self.val_acc = torchmetrics.Accuracy('binary')
    self.test_acc = torchmetrics.Accuracy('binary')
    self.f1ScoreTest = torchmetrics.F1Score('binary')
    self.student = van.van_b2(pretrained = True, num_classes = 0)
    self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0)
    # Decoder input is the concatenation of student + teacher feature maps
    # (512 + 512 channels); DynamicDecoder is defined elsewhere in the file.
    self.decoder = DynamicDecoder(input_channels=1024)
    self.student = RARP_NVB_DINO_Wrapper(
        self.student,
        RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
    )
    self.teacher_Features = RARP_NVB_DINO_Wrapper(
        self.teacher_Features,
        RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
    )
    # Teacher starts as a copy of the student and is frozen; only EMA
    # updates it afterwards (on_train_batch_end).
    self.teacher_Features.load_state_dict(self.student.state_dict())
    for parms in self.teacher_Features.parameters():
        parms.requires_grad = False
    self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)
    self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
    self.ReconstructionLoss = torch.nn.MSELoss()
    # Feature maps captured by forward hooks on the last VAN block.
    self.last_conv_output_T = None
    self.last_conv_output_S = None
    self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher)
    self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
    self.clasiffier = torch.nn.Sequential(
        torch.nn.Linear(1024, 128),
        torch.nn.SiLU(True),
        torch.nn.Dropout(0.4),
        torch.nn.Linear(128, 8),
        torch.nn.SiLU(True),
        torch.nn.Dropout(0.2),
        torch.nn.Linear(8, 1)
    )
    print(f"lr={self.lr}, L1={self.Lambda_L1}")
def _denormalize(self, tensor:torch.Tensor):
    """Undo the per-channel normalization, on the tensor's own device."""
    device = tensor.device
    return tensor * self.std_IMG.to(device) + self.mean_IMG.to(device)
def _calc_L1(self, params):
    """Return the L1 penalty (sum of absolute parameter values) scaled by Lambda_L1."""
    penalty = sum(torch.sum(torch.abs(p)) for p in params)
    return self.Lambda_L1 * penalty
def _calc_weights(self, log_weights:bool = True):
    """Recompute the per-task loss weights with SoftAdapt, then reset history.

    Each history is trimmed to odd length before being handed to SoftAdapt
    (NOTE(review): presumably get_component_weights needs an odd/fixed
    number of loss points — confirm against the softadapt package docs).
    """
    self.weights = self.softAdapt.get_component_weights(
        torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]),
        torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
        torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
        verbose=False
    )
    if log_weights:
        # weights[0] = DINO, weights[1] = reconstruction, weights[2] = binary GT.
        self.log("W_loss_img", self.weights[1], on_epoch=True, on_step=False)
        self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False)
        self.log("W_loss_GT", self.weights[2], on_epoch=True, on_step=False)
    # Start a fresh history window for the next weight update.
    self.loss_history = {
        'loss_DINO': [],
        'loss_Reconstruction': [],
        'loss_Binary': [],
    }
def forward(self, data, val_step:bool = True):
    """Return (classifier_logits, (student_feats, teacher_feats), reconstruction).

    Validation: ``data`` is a single tensor fed to both networks. Training:
    ``data`` is the multi-crop list — crops 1-2 (presumably the global
    views — TODO confirm against the dataloader) go to the teacher, all
    crops to the student. The forward hooks populate last_conv_output_S/T
    as a side effect of these calls.
    """
    if val_step:
        data = data.float()
        dataTeacher, dataStudent = data, data
    else:
        data = [d.float() for d in data]
        dataTeacher, dataStudent = data[1:3], data
    TeacherDino = self.teacher_Features(dataTeacher)
    Student = self.student(dataStudent)
    if not val_step:
        # The student hook saw all crops concatenated; keep only the two
        # chunks matching the teacher's views so channels can be concatenated.
        NumChunks = len(dataStudent)
        S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
        self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
    # Fuse student + teacher feature maps along channels (512 + 512 = 1024).
    cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
    reconstructed_image = self.decoder(cat_features)
    Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
    pred = self.clasiffier(Cont_Net)
    return pred, (Student, TeacherDino), reconstructed_image
def _shared_step(self, batch, val_step:bool = False):
    """Shared train/val/test logic for the three-task objective.

    Returns (weighted_total_loss, labels, sigmoid_preds,
    (w*dino_loss, w*binary_loss, w*reconstruction_loss, reconstruction)).
    """
    img, label = batch
    prediction, features, new_image = self(img, val_step)
    StudentF, TeacherF = features
    # Flatten only when the head emits a single logit per sample.
    if isinstance(self.clasiffier, torch.nn.Sequential):
        if self.clasiffier[-1].out_features == 1:
            prediction = prediction.flatten()
    elif isinstance(self.clasiffier, (NOAH, RARP_NVB_Classification_Head)):
        prediction = prediction.flatten()
    predicted_labels = torch.sigmoid(prediction)
    # During training, targets are replicated once per teacher view to match
    # the concatenated predictions/reconstructions.
    orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
    label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
    #DINO Loss (training only; validation has no multi-crop input)
    loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
    #Clasificator (binary ground-truth loss)
    loss_HL = self.lossFN(prediction, label)
    #Reconstruction
    loss_img = self.ReconstructionLoss(new_image, orignalImg)
    loss_img = loss_img.float()
    if not val_step:
        # Optional L1/L2 regularization applied to the classifier head only.
        if self.Lambda_L1 is not None:
            loss_HL += self._calc_L1(self.clasiffier.parameters())
        if self.Lambda_L2 > 0:
            l2_reg = 0.0
            for param in self.clasiffier.parameters():
                l2_reg += torch.norm(param, 2) ** 2
            loss_HL += self.Lambda_L2 * l2_reg
        # Record raw losses for the next SoftAdapt weight update.
        self.loss_history["loss_DINO"].append(loss_Dino.item())
        self.loss_history["loss_Reconstruction"].append(loss_img.item())
        self.loss_history["loss_Binary"].append(loss_HL.item())
    # weights: [0] DINO, [1] reconstruction, [2] binary GT.
    loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
    return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)
def on_train_epoch_start(self):
    """Refresh the loss weights and, optionally, alternate which parts train.

    Every second epoch (except epoch 0) the loss weights are recomputed.
    With ``intermittent_train`` enabled, even epochs train the student
    backbone while odd epochs train the decoder and classifier instead.
    """
    epoch = self.current_epoch
    is_even_epoch = (epoch % 2 == 0)
    if is_even_epoch and epoch != 0:
        self._calc_weights()
    if self.intermittent_train and epoch != 0:
        def _set_trainable(module, enabled):
            # Toggle requires_grad for every parameter of *module*.
            for p in module.parameters():
                p.requires_grad = enabled
        _set_trainable(self.student.backbone, is_even_epoch)
        _set_trainable(self.decoder, not is_even_epoch)
        _set_trainable(self.clasiffier, not is_even_epoch)
def training_step(self, batch, batch_idx):
    """One optimisation step: combined loss, metric logging, periodic image grids."""
    loss, targets, probs, parts = self._shared_step(batch, False)
    dino_part, gt_part, img_part, reconstruction = parts
    self.log("train_loss", loss, on_epoch=True)
    self.train_acc.update(probs, targets)
    self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
    self.log("train_loss_img", img_part, on_epoch=True, on_step=False)
    self.log("train_loss_DINO", dino_part, on_epoch=True, on_step=False)
    self.log("train_loss_GT", gt_part, on_epoch=True, on_step=False)
    # Every 50 batches push a grid of reconstructed images to the logger
    # (only possible when de-normalisation statistics were configured).
    can_denormalize = self.mean_IMG is not None and self.std_IMG is not None
    if batch_idx % 50 == 0 and can_denormalize:
        images = torch.clip(self._denormalize(reconstruction) / 255, 0, 1)
        images = images[:, [2, 1, 0], :, :]  # swap first/last colour channels
        grid = torchvision.utils.make_grid(images)
        self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
    return loss
def on_train_batch_end(self, outputs, batch, batch_idx):
    """EMA update of the teacher from the student after every training batch.

    teacher <- m * teacher + (1 - m) * student, with m = ``self.momentum_teacher``.
    """
    m = self.momentum_teacher
    with torch.no_grad():
        pairs = zip(self.student.parameters(), self.teacher_Features.parameters())
        for s_param, t_param in pairs:
            t_param.data.mul_(m)
            t_param.data.add_((1 - m) * s_param.detach().data)
    #self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center)
def on_after_backward(self):
    """Log the global L2 gradient norm and flag suspected vanishing gradients.

    Fix: ``LightningModule.log`` only accepts numeric / tensor / Metric
    values, so the previous string payload ("Vanishing gradient suspected!")
    raised at runtime; a numeric flag (1.0) is logged instead.
    """
    total_norm = 0.0
    for p in self.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    self.log("grad_norm", total_norm)
    if total_norm < 1e-8:
        # Numeric sentinel: 1.0 means "vanishing gradient suspected".
        self.log("grad_warning", 1.0)
def validation_step(self, batch, batch_idx):
    """Validation pass: log the combined loss, accuracy and per-task losses."""
    loss, targets, probs, parts = self._shared_step(batch, True)
    dino_part, gt_part, img_part, _ = parts
    self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
    self.val_acc.update(probs, targets)
    self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
    self.log("val_loss_img", img_part, on_epoch=True, on_step=False)
    self.log("val_loss_DINO", dino_part, on_epoch=True, on_step=False)
    self.log("val_loss_GT", gt_part, on_epoch=True, on_step=False)
def test_step(self, batch, batch_idx):
    """Test pass: update metrics and log an original-vs-reconstructed image grid."""
    _, targets, probs, parts = self._shared_step(batch, True)
    self.test_acc.update(probs, targets)
    self.f1ScoreTest.update(probs, targets)
    self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
    self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    if self.mean_IMG is None or self.std_IMG is None:
        return
    # Stack originals above reconstructions in one grid for visual comparison.
    recon_imgs = torch.clip(self._denormalize(parts[3]) / 255, 0, 1)
    recon_imgs = recon_imgs[:, [2, 1, 0], :, :]  # swap first/last colour channels
    orig_imgs = torch.clip(self._denormalize(batch[0]) / 255, 0, 1)
    orig_imgs = orig_imgs[:, [2, 1, 0], :, :]
    grid = torchvision.utils.make_grid(torch.cat((orig_imgs, recon_imgs), dim=0))
    self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
def configure_optimizers(self):
    """Create the single AdamW optimiser over all trainable parameters.

    Weight decay is not passed here -- presumably handled by the explicit
    L2 penalty inside the loss; confirm before changing.
    """
    return [torch.optim.AdamW(self.parameters(), lr=self.lr)]
class Scalar_AttnPooling(torch.nn.Module):
    """Attention pooling over a sequence of per-frame scalar logits.

    Learns a soft weighting over the time axis so the most informative
    frames dominate the pooled video-level logit.
    """

    def __init__(self, hidden_dim=64):
        super().__init__()
        self.proj = torch.nn.Linear(1, hidden_dim)
        self.attn = torch.nn.Linear(hidden_dim, 1)
        self.pooler_cls = torch.nn.Linear(1, 1)

    def forward(self, logits_bt):
        """Pool [B, T] frame logits into a [B, 1] video logit."""
        per_frame = logits_bt.unsqueeze(-1)                       # [B, T, 1]
        hidden = torch.tanh(self.proj(per_frame))                 # [B, T, hidden_dim]
        scores = self.attn(hidden).squeeze(-1)                    # [B, T]
        weights = torch.nn.functional.softmax(scores, 1)          # attention over frames
        pooled = (weights.unsqueeze(-1) * per_frame).sum(dim=1)   # [B, 1]
        return self.pooler_cls(pooled)
class Scalar_TCN(torch.nn.Module):
    """Dilated temporal-convolution stack over per-frame scalar logits."""

    def __init__(self, hidden_dim=64, layers=3):
        super().__init__()
        self.in_proj = torch.nn.Conv1d(1, hidden_dim, kernel_size=1)
        stages = []
        for level in range(layers):
            dilation = 2 ** level  # exponentially growing receptive field
            stages.extend([
                torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3,
                                padding=dilation, dilation=dilation),
                torch.nn.SiLU(),
                torch.nn.Dropout(0.3),
                torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1),
                torch.nn.SiLU(),
                torch.nn.Dropout(0.2),
            ])
        self.tcn = torch.nn.Sequential(*stages)
        self.head = torch.nn.Linear(hidden_dim, 1)

    def forward(self, logits_bt):
        """Map [B, T] frame logits to a single [B, 1] video logit."""
        feats = logits_bt.unsqueeze(1)   # [B, 1, T]: convolve along time
        feats = self.in_proj(feats)      # [B, hidden_dim, T]
        feats = self.tcn(feats)          # [B, hidden_dim, T]
        feats = feats.mean(dim=2)        # global average over time
        return self.head(feats)          # [B, 1]
class ModuleWrapper(torch.nn.Module):
    """Pass-through wrapper that requires an extra dummy argument.

    The mandatory ``dummy_arg`` is presumably a grad-requiring tensor so
    gradient checkpointing treats the call as differentiable -- confirm at
    the call sites.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, dummy_arg=None):
        assert dummy_arg is not None
        return self.module(x)
class Chomp1d(torch.nn.Module):
    """Trim trailing time steps from a [B, C, T] tensor.

    Left-padded causal convolutions produce extra samples on the right;
    chomping them off restores the original sequence length.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Drop the last ``chomp_size`` time steps (no-op when zero)."""
        return x if self.chomp_size == 0 else x[:, :, :-self.chomp_size]
class TemporalBlock(torch.nn.Module):
    """Residual TCN block: two dilated causal convolutions plus a skip path."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, padding, dropout=0.0):
        """Both conv layers share the same dilation; the ``padding`` added on
        the left is chomped off afterwards so the output stays length-T and
        causal."""
        super().__init__()
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation)
        self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, **conv_kwargs)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size, **conv_kwargs)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(dropout)
        # 1x1 conv aligns channel counts on the residual path when needed.
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = torch.nn.Conv1d(in_channels, out_channels, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """x: [B, in_channels, T] -> [B, out_channels, T]."""
        out = self.dropout1(self.relu1(self.chomp1(self.conv1(x))))
        out = self.dropout2(self.relu2(self.chomp2(self.conv2(out))))
        residual = x if self.downsample is None else self.downsample(x)
        return self.relu(out + residual)
class TemporalConvNet(torch.nn.Module):
    """Stack of causal TemporalBlocks with exponentially increasing dilation.

    Args:
        num_inputs: number of input feature channels.
        num_channels: output channels per level, e.g. [64, 64, 128].
        kernel_size: convolution kernel size.
        dropout: dropout rate inside each block.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=3, dropout=0.0):
        super().__init__()
        blocks = []
        previous = num_inputs
        for level, channels in enumerate(num_channels):
            dilation = 2 ** level
            # Causal padding keeps the output length equal to T.
            blocks.append(TemporalBlock(previous, channels,
                                        kernel_size=kernel_size,
                                        stride=1,
                                        dilation=dilation,
                                        padding=(kernel_size - 1) * dilation,
                                        dropout=dropout))
            previous = channels
        self.network = torch.nn.Sequential(*blocks)

    def forward(self, x):
        """x: [B, T, C_in] -> [B, T, C_out_last]."""
        x = x.permute(0, 2, 1)   # [B, C_in, T]
        x = self.network(x)      # [B, C_out, T]
        return x.permute(0, 2, 1)
class RARP_NVB_DINO_MultiTask_A5_Video(L.LightningModule):
    """Video-level classifier on top of a (frozen) frame-level DINO multi-task model.

    Per-frame logits are produced chunk-by-chunk by ``base_model`` and pooled
    by a small trainable head into a single video-level prediction.
    """
    def __init__(
        self,
        base_model_path = None,
        lr = 0.0001,
        wd = 0.01,
        L1 = None,
        L2 = 0,
        std = None,
        mean = None,
        head_type:int = 0, #None = 0, linear = 1, Attn. Pooling = 2, TCN = 3
        chunks_loading:int = 50
    ):
        """base_model_path: checkpoint for the frame model; when given the
        model is frozen (eval mode, requires_grad=False).

        NOTE(review): ``head_type`` is currently ignored -- the selection
        logic below is commented out and the head is always Linear(600, 1).
        ``L1``/``L2``/``std``/``mean`` are likewise unused here.
        """
        super().__init__()
        self.lr = lr
        self.wd = wd
        self.chunks = chunks_loading  # frames scored per forward chunk
        if base_model_path is not None:
            self.base_model = RARP_NVB_DINO_MultiTask.load_from_checkpoint(base_model_path)
            self.base_model.eval()
            for param in self.base_model.parameters():
                param.requires_grad = False
        else:
            self.base_model = RARP_NVB_DINO_MultiTask()
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        #match(head_type):
        #    case 1:
        #        #Linear
        #        self.head = torch.nn.Linear(600, 1)
        #    case 2:
        #        #Attn. pooling
        #        self.head = Scalar_AttnPooling(128)
        #    case 3:
        #        #TCN
        #        self.head = Scalar_TCN(64, 1)
        #    case _:
        # Assumes the flattened per-frame logits total 600 features (e.g. 600
        # frames x 1 logit) -- TODO confirm against the dataset.
        self.head = torch.nn.Linear(600, 1)
    def _shared_video_step(self, batch:list[torch.Tensor], val_step:bool = False):
        """Score every frame chunk-by-chunk, pool with the head, compute BCE loss.

        Returns ``(loss, label, predicted_labels)``.
        """
        video, label = batch
        B, T, C, H, W = video.shape
        video = video.float() #[B, T, C, H, W]
        label = label.float() #[B]
        chunk_T = self.chunks
        pred_bt = []
        def _fn(inp):
            # Keep only the classification logits of the multi-task model.
            pred, *_ = self.base_model(inp)
            return pred
        for t0 in tqdm(range(0, T, chunk_T), desc=f"Video Analysis in {chunk_T} chunk", leave=False):
            t1 = min(T, t0 + chunk_T)
            # Fold the chunk's time axis into the batch axis for the 2D model.
            x = video[:, t0:t1].reshape(-1, C, H, W).contiguous(memory_format=torch.channels_last)
            # Gradient checkpointing trades compute for memory.
            # NOTE(review): no ``use_reentrant`` argument here while other
            # classes in this file pass use_reentrant=False -- confirm intent.
            pred = torch_ckp.checkpoint(_fn, x)
            pred_bt.append(pred.view(B, t1-t0, -1))
        pred_video = torch.cat(pred_bt, dim=1).flatten(start_dim=1)
        pred_video = self(pred_video, val_step) #[B, 1]
        pred_video = pred_video.flatten() #[B] to match labels shape
        predicted_labels = torch.sigmoid(pred_video)
        loss = self.lossFN(pred_video, label)
        return loss, label, predicted_labels
    def forward(self, data, val_step:bool = True):
        """Pool the flattened per-frame logits into one video logit.

        ``val_step`` is accepted for interface symmetry but unused here.
        """
        #if self.head is None:
        #    pred_video = data.mean(dim=1)
        #else:
        pred_video = self.head(data)
        return pred_video
    def training_step(self, batch, batch_idx):
        """One training step: video-level BCE loss plus accuracy logging."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, False)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        """Validation pass: log loss and accuracy."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
    def test_step(self, batch, batch_idx):
        """Test pass: update accuracy and F1 metrics."""
        _, true_labels, predicted_labels = self._shared_video_step(batch, True)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def configure_optimizers(self):
        """AdamW over the pooling head only -- the base model stays frozen."""
        optimizer = torch.optim.AdamW(self.head.parameters(), lr=self.lr, weight_decay=self.wd) #, weight_decay=self.Lambda_L2
        return [optimizer]
class RARP_NVB_VIDEO_3D_ResNet(L.LightningModule):
    """Video classifier: an inflated (I3D) ResNet-50 scored over temporal chunks.

    The video is split along time into chunks of ``chunks_loading`` frames;
    each chunk is scored by the 3D backbone under gradient checkpointing and
    the chunk logits are averaged into a single video-level logit.

    Fixes vs. the previous revision:
      * ``on_test_epoch_end`` now clears ``test_probs``/``test_targets`` so a
        second test run does not mix in stale results.
      * ``on_after_backward`` guards against an empty gradient list
        (ZeroDivisionError).
      * the chunk-size assert accepts the valid ``T == chunk`` case.
    """

    def __init__(
        self,
        lr = 0.0001,
        wd = 0.01,
        L1 = None,                # unused here -- kept for config compatibility
        L2 = 0,                   # unused here -- kept for config compatibility
        std = None,               # unused here -- kept for config compatibility
        mean = None,              # unused here -- kept for config compatibility
        chunks_loading:int = 50,
        str_path:str = None       # optional 2D ResNet-50 state_dict to inflate from
    ):
        super().__init__()
        self.lr = lr
        self.wd = wd
        self.chunks = chunks_loading
        if str_path is None:
            # Inflate an ImageNet-pretrained 2D ResNet-50 into 3D, then swap
            # in a fresh single-logit head.
            base_2D_model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
            self.check_pt = True
            self.base_model = I3DResNet50(base_2D_model)
            self.features = self.base_model.fc.in_features
            self.base_model.fc = torch.nn.Linear(self.features, 1)
        else:
            # Load task-specific 2D weights (already single-logit) before inflating.
            base_2D_model = torchvision.models.resnet50()
            in_f = base_2D_model.fc.in_features
            base_2D_model.fc = torch.nn.Linear(in_f, 1)
            base_2D_model.load_state_dict(torch.load(str_path))
            self.check_pt = True
            self.base_model = I3DResNet50(base_2D_model)
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        # Per-batch test outputs accumulated for epoch-level ROC analysis.
        self.test_probs = []
        self.test_targets = []

    def _shared_video_step(self, batch:list[torch.Tensor], val_step:bool = False):
        """Compute the video-level BCE loss and sigmoid probabilities."""
        video, label = batch
        video = video.float()  #[B, T, C, H, W]
        label = label.float()  #[B]
        logits = self(video).flatten()  #[B] to match labels shape
        predicted_labels = torch.sigmoid(logits)
        loss = self.lossFN(logits, label)
        return loss, label, predicted_labels

    def forward(self, video:torch.Tensor):
        """Average the 3D backbone's logits over temporal chunks.

        video: [B, T, C, H, W]; returns chunk-averaged logits of shape [B, 1].
        """
        T = video.shape[1]
        chunk_T = self.chunks
        pred_bt = []
        # FIX: ``>=`` -- a video of exactly one full chunk is valid input.
        assert T >= chunk_T, f"The Time dim is smaller, than chunk size [T={T}, chunk={chunk_T}]"
        def _fn(inp):
            return self.base_model(inp)
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, T, H, W] for Conv3d
        for t0 in tqdm(range(0, T, chunk_T), desc=f"Video Analysis in {chunk_T} chunk", leave=False):
            t1 = min(T, t0 + chunk_T)
            x = video[:, :, t0:t1]
            # Gradient checkpointing keeps memory bounded per chunk.
            pred = torch_ckp.checkpoint(_fn, x, use_reentrant=False)
            pred_bt.append(pred)  #[B, 1]
        return torch.stack(pred_bt, dim=0).mean(dim=0)

    def configure_optimizers(self):
        """AdamW over all parameters with the configured weight decay."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd)
        return [optimizer]

    def on_after_backward(self):
        """Log the mean per-parameter gradient norm (skips when no grads exist)."""
        norms = [p.grad.data.norm(2).item() for p in self.parameters() if p.grad is not None]
        if norms:  # FIX: avoid ZeroDivisionError when nothing has a gradient
            self.log("avg_grad_norm", sum(norms) / len(norms))

    def training_step(self, batch, batch_idx):
        """One training step: video-level BCE loss plus accuracy logging."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, False)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation pass: log loss and accuracy."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Test pass: update metrics and buffer outputs for the epoch-end ROC."""
        _, true_labels, predicted_labels = self._shared_video_step(batch, True)
        self.test_probs.append(predicted_labels)
        self.test_targets.append(true_labels)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def on_test_epoch_end(self):
        """Epoch-level test metrics at threshold 0.5 and at the Youden-J optimal
        threshold; both rows are written to a CSV next to the logs."""
        predicted_labels = torch.cat(self.test_probs).to(self.device)
        true_labels = torch.cat(self.test_targets).to(self.device).int()
        acc = torchmetrics.Accuracy('binary').to(self.device)(predicted_labels, true_labels)
        precision = torchmetrics.Precision('binary').to(self.device)(predicted_labels, true_labels)
        recall = torchmetrics.Recall('binary').to(self.device)(predicted_labels, true_labels)
        auc = torchmetrics.AUROC('binary').to(self.device)(predicted_labels, true_labels)
        f1Score = torchmetrics.F1Score('binary').to(self.device)(predicted_labels, true_labels)
        specificty = torchmetrics.Specificity("binary").to(self.device)(predicted_labels, true_labels)
        table = [
            ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", f"*{self.current_epoch}*", "0"]
        ]
        # Youden's J statistic selects the threshold maximising TPR - FPR.
        aucCurve = torchmetrics.ROC("binary").to(self.device)
        fpr, tpr, thhols = aucCurve(predicted_labels, true_labels)
        index = torch.argmax(tpr - fpr)
        th1 = thhols[index].item()
        accY = torchmetrics.Accuracy('binary', threshold=th1).to(self.device)(predicted_labels, true_labels)
        precisionY = torchmetrics.Precision('binary', threshold=th1).to(self.device)(predicted_labels, true_labels)
        recallY = torchmetrics.Recall('binary', threshold=th1).to(self.device)(predicted_labels, true_labels)
        specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(self.device)(predicted_labels, true_labels)
        f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(self.device)(predicted_labels, true_labels)
        table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", f"-{self.current_epoch}-", f"{(tpr - fpr)[index].item():.4f}"])
        df = pd.DataFrame(table, columns=["Threshold", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint", "J"])
        out_path = f"{self.logger.log_dir}/roc_points_epoch_{self.global_step}.csv"
        df.to_csv(out_path, index=False)
        # FIX: reset the buffers so a subsequent test run starts from empty lists.
        self.test_probs.clear()
        self.test_targets.clear()
class RARP_NVB_DINO_MultiTask_A6_Video(L.LightningModule):
    """Video classifier: a VAN-B2 frame encoder whose per-frame features are
    aggregated over time by an optional temporal head.

    NOTE(review): with the default ``head_type`` (0) ``self.head`` is ``None``
    while ``forward`` calls ``self.head(pred_video)`` unconditionally, which
    raises ``TypeError`` -- only ``head_type=1`` (TemporalConvNet) works.
    """
    def __init__(
        self,
        lr = 0.0001,
        wd = 0.01,
        L1 = None,
        L2 = 0,
        std = None,
        mean = None,
        head_type:int = 0, #None = 0, linear = 1, Attn. Pooling = 2, TCN = 3, Replace Head =4
        chunks_loading:int = 50,
    ):
        """NOTE(review): ``L1``/``L2``/``std``/``mean`` are unused here --
        presumably kept for configuration compatibility."""
        super().__init__()
        self.lr = lr
        self.wd = wd
        self.chunks = chunks_loading  # frames scored per forward chunk
        self.head_type = head_type
        self.check_pt = True
        # VAN-B2 backbone as a pure feature extractor (num_classes=0).
        self.base_model = van.van_b2(pretrained = True, num_classes = 0)
        self.num_features_base_model = 512
        #self.base_model_wrapper = ModuleWrapper(self.base_model)
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        # Per-batch test outputs accumulated for epoch-level ROC analysis.
        # NOTE(review): never cleared -- a second test run would mix stale data.
        self.test_probs = []
        self.test_targets = []
        match(self.head_type):
            case 1:
                self.head = TemporalConvNet(self.num_features_base_model, [128, 8, 1])
            case _:
                self.head = None
    def _shared_video_step(self, batch:list[torch.Tensor], val_step:bool = False):
        """Video-level BCE loss: per-frame logits averaged over time.

        Returns ``(loss, label, predicted_labels)``.
        """
        video, label = batch
        video = video.float() #[B, T, C, H, W]
        label = label.float() #[B]
        pred_video = self(video, val_step) #[B, T, 1]
        pred_video = pred_video.mean(dim=1) #[B, 1]
        pred_video = pred_video.flatten() #[B] to match labels shape
        predicted_labels = torch.sigmoid(pred_video)
        loss = self.lossFN(pred_video, label)
        return loss, label, predicted_labels
    def forward(self, video:torch.Tensor, val_step:bool = True):
        """Encode every frame chunk-by-chunk, then run the temporal head.

        ``val_step`` is accepted for interface symmetry but unused here.
        """
        B, T, C, H, W = video.shape
        chunk_T = self.chunks
        pred_bt = []
        def _fn(inp):
            pred = self.base_model(inp)
            return pred
        for t0 in tqdm(range(0, T, chunk_T), desc=f"Video Analysis in {chunk_T} chunk", leave=False): # Loop for each chunk
            t1 = min(T, t0 + chunk_T)
            # Fold the chunk's time axis into the batch axis; channels_last for GPU efficiency.
            x = video[:, t0:t1].reshape(-1, C, H, W).contiguous(memory_format=torch.channels_last)
            # Gradient checkpointing trades compute for memory.
            # NOTE(review): no ``use_reentrant`` argument here while the 3D
            # ResNet class passes use_reentrant=False -- confirm intent.
            pred = torch_ckp.checkpoint(_fn, x)
            pred_bt.append(pred.view(B, t1-t0, -1)) # back to [B, chunk_T, features]
        pred_video = torch.cat(pred_bt, dim=1) # concat chunks -> [B, T, features]
        # NOTE(review): raises TypeError when self.head is None (head_type != 1).
        pred_video = self.head(pred_video)
        return pred_video
    def on_after_backward(self):
        """Log the mean per-parameter gradient norm.

        NOTE(review): divides by ``len(norms)`` -- would raise if no parameter
        has a gradient.
        """
        norms = [p.grad.data.norm(2).item() for p in self.parameters() if p.grad is not None]
        avg_layer_norm = sum(norms) / len(norms)
        self.log("avg_grad_norm", avg_layer_norm)
    def training_step(self, batch, batch_idx):
        """One training step: video-level BCE loss plus accuracy logging."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, False)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        """Validation pass: log loss and accuracy."""
        loss, true_labels, predicted_labels = self._shared_video_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
    def test_step(self, batch, batch_idx):
        """Test pass: update metrics and buffer outputs for the epoch-end ROC."""
        _, true_labels, predicted_labels = self._shared_video_step(batch, True)
        self.test_probs.append(predicted_labels)
        self.test_targets.append(true_labels)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def on_test_epoch_end(self):
        """Epoch-level test metrics at threshold 0.5 and at the Youden-J optimal
        threshold; both rows are written to a CSV next to the logs."""
        predicted_labels = torch.cat(self.test_probs).to(self.device)
        true_labels = torch.cat(self.test_targets).to(self.device).int()
        acc = torchmetrics.Accuracy('binary').to(self.device)(predicted_labels, true_labels)
        precision = torchmetrics.Precision('binary').to(self.device)(predicted_labels, true_labels)
        recall = torchmetrics.Recall('binary').to(self.device)(predicted_labels, true_labels)
        auc = torchmetrics.AUROC('binary').to(self.device)(predicted_labels, true_labels)
        f1Score = torchmetrics.F1Score('binary').to(self.device)(predicted_labels, true_labels)
        specificty = torchmetrics.Specificity("binary").to(self.device)(predicted_labels, true_labels)
        table = [
            ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", f"*{self.current_epoch}*", "0"]
        ]
        # Youden's J statistic selects the threshold maximising TPR - FPR.
        aucCurve = torchmetrics.ROC("binary").to(self.device)
        fpr, tpr, thhols = aucCurve(predicted_labels, true_labels)
        index = torch.argmax(tpr - fpr)
        th1 = thhols[index].item()
        accY = torchmetrics.Accuracy('binary', threshold=th1).to(self.device)(predicted_labels, true_labels)
        precisionY = torchmetrics.Precision('binary', threshold=th1).to(self.device)(predicted_labels, true_labels)
        recallY = torchmetrics.Recall('binary', threshold=th1).to(self.device)(predicted_labels, true_labels)
        specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(self.device)(predicted_labels, true_labels)
        f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(self.device)(predicted_labels, true_labels)
        table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", f"-{self.current_epoch}-", f"{(tpr - fpr)[index].item():.4f}"])
        df = pd.DataFrame(table, columns=["Threshold", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint", "J"])
        out_path = f"{self.logger.log_dir}/roc_points_epoch_{self.global_step}.csv"
        df.to_csv(out_path, index=False)
    def configure_optimizers(self):
        """AdamW over all parameters with the configured weight decay."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd) #, weight_decay=self.Lambda_L2
        return [optimizer]
class RARP_NVB_DINO_MultiTask_Unet(RARP_NVB_DINO_MultiTask):
    """Variant whose decoder is a U-Net fed with encoder skip connections.

    Forward hooks on selected student-backbone blocks capture intermediate
    feature maps; those, plus the fused student/teacher features, are handed
    to ``DecoderUnet`` for image reconstruction.
    """
    def _encoder_hool_fn(self, module, input, output):
        # Forward hook: stash the tapped block's output for the decoder.
        self.feature_maps.append(output)
    def _register_encoder_hooks(self):
        # Attach the hook to every tapped encoder block (see list_blocks).
        for layer in self.list_blocks:
            self.hooks.append(layer.register_forward_hook(self._encoder_hool_fn))
    def __init__(
        self,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
        std: float = None,
        mean: float = None,
        SoftAdptAlgo:int = 0,
        SoftAdptBeta:float = 0.1
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)
        self.hooks = []          # handles returned by register_forward_hook
        self.feature_maps = []   # filled by the hooks during each forward pass
        # One tap per backbone stage; the last stage is tapped one block early.
        # NOTE(review): block4[-2] (not [-1]) -- confirm this is intentional.
        self.list_blocks = [
            self.student.backbone.block1[-1],
            self.student.backbone.block2[-1],
            self.student.backbone.block3[-1],
            self.student.backbone.block4[-2],
        ]
        self._register_encoder_hooks()
        self.decoder = DecoderUnet(1024)
    def forward(self, data, val_step:bool = True):
        """Run teacher + student, gather hooked skip features, decode, classify.

        Returns ``(pred, (student_features, teacher_features), reconstructed_image)``.
        """
        self.feature_maps = []  # reset; hooks repopulate during the passes below
        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            data = [d.float() for d in data]
            # Teacher sees only crops 1:3 -- presumably the global views; confirm.
            dataTeacher, dataStudent = data[1:3], data
        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)
        if not val_step:
            # Keep only the student activations matching the teacher's views.
            NumChunks = len(dataStudent)
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
        if not val_step:
            # Slice every hooked feature map the same way as the conv outputs
            # so the skip connections line up with the kept views.
            NumChunks = len(dataStudent)
            temp = []
            for f_maps in self.feature_maps:
                temp.append(torch.cat(f_maps.chunk(NumChunks)[1:3], dim=0))
            self.feature_maps = temp
            del temp
        self.feature_maps.append(cat_features)
        reconstructed_image = self.decoder(self.feature_maps)
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)
        return pred, (Student, TeacherDino), reconstructed_image
class RARP_NVB_DINO_MultiTask_MultiLabel(RARP_NVB_DINO_MultiTask):
    """Multi-label variant of the DINO multi-task model (BCE-with-logits head)."""

    def __init__(
        self,
        TypeLoss=TypeLossFunction.BCEWithLogits,
        momentum_teacher = 0.9995,
        lr = 0.0001,
        L1 = None,
        L2 = 0,
        std = None,
        mean = None,
        SoftAdptAlgo = 0,
        SoftAdptBeta = 0.1,
        Num_Lables = 2
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        # 1024 -> 128 -> 8 -> Num_Lables MLP head with in-place SiLU activations.
        head_layers = [
            torch.nn.Linear(1024, 128),
            torch.nn.Dropout(0.4),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, Num_Lables),
        ]
        self.clasiffier = torch.nn.Sequential(*head_layers)
        # Multi-label metrics replace the parent's binary metrics.
        self.train_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.val_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.test_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.f1ScoreTest = torchmetrics.F1Score('multilabel', num_labels=Num_Lables)
class RARP_NVB_DINO_MultiTask_v2(RARP_NVB_DINO_MultiTask):
    """Four-class (multiclass) variant of the DINO multi-task model.

    Fix: ``CrossEntropyLoss`` is now fed the raw logits instead of the
    softmax-ed probabilities. Passing probabilities applies softmax twice
    (CrossEntropyLoss includes LogSoftmax internally), which skews the loss
    and flattens the gradients.
    """

    def __init__(
        self,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher = 0.9995,
        lr = 0.0001,
        L1 = None,
        L2 = 0,
        std = None,
        mean = None
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean)
        # Multiclass (4-way) metrics replace the parent's binary metrics.
        self.train_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.val_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.test_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.f1ScoreTest = torchmetrics.F1Score('multiclass', num_classes=4)
        self.lossFN = torch.nn.CrossEntropyLoss()
        self.softAdapt = LossWeightedSoftAdapt(0.1) #NormalizedSoftAdapt(0.1)
        # 1024 -> 512 -> 256 -> 128 -> 8 -> 4 MLP classification head.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 512),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),
            torch.nn.Linear(512, 256),
            torch.nn.SiLU(True),
            torch.nn.Linear(256, 128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 4)
        )

    def _shared_step(self, batch, val_step:bool = False):
        """Multiclass version of the shared step (see parent for the binary one).

        Returns ``(loss, label, predicted_labels, (weighted_dino_loss,
        weighted_head_loss, weighted_reconstruction_loss, reconstructed_image))``.
        """
        img, label = batch
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        # Probabilities are only used for the metrics; the loss takes raw logits.
        predicted_labels = torch.softmax(prediction, dim=1)
        # Training repeats the reference image / labels once per teacher output.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label for _ in range(len(TeacherF))], dim=0) if not val_step else label
        # DINO loss (training only; zero placeholder keeps the return shape stable).
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        # Classifier loss -- FIX: CrossEntropyLoss expects unnormalized logits,
        # not the softmax-ed ``predicted_labels``.
        loss_HL = self.lossFN(prediction, label)
        # Image reconstruction loss.
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()
        if not val_step:
            if self.Lambda_L1 is not None:
                # L1 sparsity penalty on the classification head parameters.
                loss_l1 = 0
                for params in self.clasiffier.parameters():
                    loss_l1 += torch.sum(torch.abs(params))
                loss_HL += self.Lambda_L1 * loss_l1
            if self.Lambda_L2 > 0:
                # L2 (ridge) penalty on the classification head parameters.
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            # History presumably feeds the SoftAdapt re-weighting -- confirm.
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)
class RARP_NVB_RN50_VAN_V2 (RARP_NVB_Model):
# Define a hook function to capture the output
def _hook_fn(self, module, input, output):
self.last_conv_output = output
def __init__(
self,
PseudoEstimator:str = None,
threshold:float = 0.5,
InitWeight=None,
TypeLoss=TypeLossFunction.CrossEntropy,
PseudoLables:bool=True,
HParameter = {},
std: float = None,
mean: float = None,
**kwargs
) -> None:
super().__init__(InitWeight, TypeLoss, **kwargs)
self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
#self.RARP_RestNet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
self.RARP_RestNet50.model.fc = torch.nn.Identity()
for parms in self.RARP_RestNet50.model.parameters():
parms.requires_grad = False
self.RARP_RestNet50 = torch.nn.Sequential(*list(self.RARP_RestNet50.model.children())[:-2])
self.decoder = Decoder()
self.FeaturesLoss = FeatureAlignmentLoss()
self.ReconstructionLoss = torch.nn.MSELoss()
match (TypeError):
case TypeLossFunction.ContrastiveLoss:
self.lossFN = ContrastiveLoss()
case _:
pass
self.model = van.van_b2(pretrained = True, num_classes = 1)
self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04)
self.lr = getNVL(HParameter, "lr", 1.0E-4)
print(f"lr={self.lr}, L1={self.Lambda_L1}")
# Initialize a variable to store the output
self.last_conv_output = None
self.model.block4[-1].mlp.dwconv.register_forward_hook(self._hook_fn)
def forward(self, data):
dataTeacher, dataStudent, _ = data
dataStudent = dataStudent.float()
dataTeacher = dataTeacher.float()
feature_Teacher = self.RARP_RestNet50 (dataTeacher)
#feature_Student = self.model.forward_features(dataStudent)
pred = self.model(dataStudent)
feature_Student = self.last_conv_output
#feature_Teacher = torch.nn.functional.adaptive_avg_pool1d(feature_Teacher, feature_Student.size(-1)) #Re-size output vector to macth VAN
cat_features = (feature_Student + feature_Teacher) / 2
reconstructed_image = self.decoder(cat_features)
feature_Student = torch.nn.functional.adaptive_avg_pool2d(feature_Student, (1, 1)).flatten(1)
feature_Teacher = torch.nn.functional.adaptive_avg_pool2d(feature_Teacher, (1, 1)).flatten(1)
return pred, (feature_Student, feature_Teacher), reconstructed_image
def _shared_step(self, batch):
img, label = batch
label = label.float()
orignalImg = img[2].float()
prediction, features, new_image = self(img)
prediction = prediction.flatten()
predicted_labels = torch.sigmoid(prediction)
#cosine similarity
loss_cosine = self.FeaturesLoss(features[0], features[1])
#Clasificator
loss_hl = self.lossFN(prediction, label)
#Reconstruction
loss_img = self.ReconstructionLoss(new_image, orignalImg)
loss_img = loss_img.float()
loss = loss_cosine + loss_hl + loss_img
if self.Lambda_L1 is not None:
loss_l1 = 0
for params in self.model.parameters():
loss_l1 += torch.sum(torch.abs(params))
loss += self.Lambda_L1 * loss_l1
return loss, label, predicted_labels, (loss_cosine, loss_hl, loss_img, new_image)
def _denormalize(self, tensor:torch.Tensor):
# Move mean and std to the same device as the input tensor
mean = self.mean_IMG.to(tensor.device)
std = self.std_IMG.to(tensor.device)
return tensor * std + mean
def training_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("train_loss", loss, on_epoch=True)
self.train_acc.update(predicted_labels, true_labels)
self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
self.log("train_loss_cosSim", losses[0], on_epoch=True, on_step=False)
self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)
if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
grid = torchvision.utils.make_grid(imgReconstruction)
self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
return loss
def validation_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
self.val_acc.update(predicted_labels, true_labels)
self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
self.log("val_loss_cosSim", losses[0], on_epoch=True, on_step=False)
self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)
def test_step(self, batch, batch_idx):
    """Accumulate test accuracy / F1; also log a reconstruction grid when
    the normalization statistics are available."""
    _, targets, preds, extras = self._shared_step(batch)
    self.test_acc.update(preds, targets)
    self.f1ScoreTest.update(preds, targets)
    self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
    self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    if self.mean_IMG is None or self.std_IMG is None:
        return
    # De-normalize, rescale to [0, 1] and swap BGR -> RGB for display.
    recon = torch.clip(self._denormalize(extras[3]) / 255, 0, 1)
    recon = recon[:, [2, 1, 0], :, :]
    grid = torchvision.utils.make_grid(recon)
    self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
def configure_optimizers(self):
    """Adam over all parameters; L2 regularization via ``weight_decay``."""
    return torch.optim.Adam(
        self.parameters(),
        lr=self.lr,
        weight_decay=self.Lambda_L2,
    )
class RARP_NVB_ResNet50_VAN(RARP_NVB_Model):
def __init__(
self,
PseudoEstimator:str = None,
threshold:float = 0.5,
InitWeight=None,
TypeLoss=TypeLossFunction.CrossEntropy,
PseudoLables:bool=True,
HParameter = {},
**kwargs
) -> None:
super().__init__(InitWeight, TypeLoss, **kwargs)
self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
self.threshold = threshold
self.PseudoLables = PseudoLables
match (TypeError):
case TypeLossFunction.ContrastiveLoss:
self.lossFN = ContrastiveLoss()
case _:
pass
for parms in self.RARP_RestNet50.model.parameters():
parms.requires_grad = False
self.model = van.van_b2(pretrained = True, num_classes = 1)
self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04)
self.lr = getNVL(HParameter, "lr", 1.0E-4)
self.W_Apha = getNVL(HParameter, "Alpha", 0.60)
self.W_Beta = 1 - getNVL(HParameter, "Alpha", 0.60)
self.thao_KD = getNVL(HParameter, "Thao", 5)
print(f"A={self.W_Apha}; B={self.W_Beta}, T={self.thao_KD}, lr={self.lr}, L1={self.Lambda_L1}")
#self.Lambda_L1 = None
#self.scheduler = True
#self.lr = 1.74E-2
def forward(self, data):
#dataTeacher, dataStudent = data
if isinstance(data, tuple):
dataTeacher, dataStudent = data
elif isinstance(data, torch.Tensor):
dataTeacher, dataStudent = (data, data)
RN50Pred = torch.sigmoid(self.RARP_RestNet50(dataTeacher).flatten())
PseudoLabels = (RN50Pred > self.threshold) * 1.0
dataStudent = dataStudent.float()
pred = self.model(dataStudent)
return pred, PseudoLabels, RN50Pred
def _shared_step(self, batch):
img, label = batch
if self.InitWeight is not None:
self.lossFN.weight = self.InitWeight[label]
label = label.float()
prediction, PseudoLabels, predictionRN50 = self(img) #.flatten()
prediction = prediction.flatten()
predicted_labels = torch.sigmoid(prediction)
#
# L[B] => BCEWithLogits
# L[C] => ContrastiveLoss
#Loss L[1] = L[B_HL] + L[B_PL]
#Loss L[2] = L[C_HL]
#Loss L[3] = L[C_HL] + L[C_PL]
#Loss L[4] = L[C_y_hat]
#Loss L[5] = L[KLD_PL] + L[B_PL]
#Loss L[6] = L[MSE_PL] + L[B_PL]
#Loss L[7] = L[B_HL] + L[BCE_PL]
#Loss L[8] = FocalLoss*L[JSD]
#Nuveo
#thao_KD = 5.0
#w_alpha, w_beta = (0.60, 0.40)
#L[7]:
softTeacher = torch.sigmoid(predictionRN50/self.thao_KD)
softStudent = torch.sigmoid(prediction/self.thao_KD)
loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
loss_hl = self.lossFN(prediction, label)
loss = self.W_Apha * loss_hl + self.W_Beta * loss_sl #/ (self.thao_KD ** 2)
#loss = w_alpha * self.lossFN(prediction, label) + w_beta * self.lossFN(prediction, PseudoLabels) #L[1]
#loss = w_alpha * (torch.nn.KLDivLoss()(softStudent, softTeacher) * (thao_KD ** 2)) + w_beta * self.lossFN(prediction, PseudoLabels) #L[5]
#loss = w_alpha * torch.nn.functional.mse_loss(softStudent, softTeacher) + w_beta * self.lossFN(prediction, PseudoLabels) #L[6]
#loss = self.lossFN(predictionRN50, predicted_labels, label) #L[2]
#loss = self.lossFN(predictionRN50, predicted_labels, label) + self.lossFN(predictionRN50, predicted_labels, PseudoLabels) #L[3]
#y_hat = (PseudoLabels != label) * 1
#loss = self.lossFN(predictionRN50, predicted_labels, y_hat) #L[4]
if self.Lambda_L1 is not None:
loss_l1 = 0
for params in self.model.parameters():
loss_l1 += torch.sum(torch.abs(params))
loss += self.Lambda_L1 * loss_l1
return loss, (PseudoLabels if self.PseudoLables else label), predicted_labels, (loss_hl, loss_sl)
def training_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("train_loss", loss, on_epoch=True)
self.train_acc.update(predicted_labels, true_labels)
self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
self.log("train_lh", losses[0], on_step=False, on_epoch=True)
self.log("train_ld", losses[1], on_step=False, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
self.val_acc.update(predicted_labels, true_labels)
self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
self.log("val_lh", losses[0], on_step=False, on_epoch=True)
self.log("val_ld", losses[1], on_step=False, on_epoch=True)
def test_step(self, batch, batch_idx):
_, true_labels, predicted_labels, _ = self._shared_step(batch)
self.test_acc.update(predicted_labels, true_labels)
self.f1ScoreTest.update(predicted_labels, true_labels)
self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
if self.scheduler:
scheduler = CosineAnnealingLR(
optimizer,
max_epochs=150,
warmup_epochs=8,
warmup_start_lr=0.001,
eta_min=0.0001
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
#"monitor": "val_loss",
}
}
return optimizer
class RARP_NVB_SSL_RestNet50_Deep(RARP_NVB_ResNet50_VAN):
    """Distillation variant with a "deep" ResNet50 teacher and a ResNet50
    student whose head is a 2048 -> 128 -> 8 -> 1 MLP.

    The parent constructor builds a default teacher/student pair; both are
    replaced here with the deep-teacher / ResNet50-student combination.
    """
    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        InitWeight=None,
        TypeLoss=TypeLossFunction.CrossEntropy,
        PseudoLables: bool = True,
        **kwargs
    ) -> None:
        super().__init__(None, threshold, InitWeight, TypeLoss, PseudoLables, **kwargs)
        # Teacher: deep ResNet50 variant, from checkpoint when provided.
        if PseudoEstimator is None:
            teacher = RARP_NVB_ResNet50_Deep()
        else:
            teacher = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False)
        self.RARP_RestNet50 = teacher
        # Student: ImageNet-pretrained ResNet50 with a deeper MLP head.
        student = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        hidden_in = 2048  # ResNet50 feature width
        student.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=hidden_in, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        )
        self.model = student
        # Only the student is trained; freeze the teacher's backbone.
        for param in self.RARP_RestNet50.model.parameters():
            param.requires_grad = False
class RARP_NVB_MobileNetV2(RARP_NVB_Model):
    """MobileNetV2 backbone with its classifier replaced by a 1-logit head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        net = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
        in_features = net.classifier[1].in_features
        net.classifier[1] = torch.nn.Linear(in_features=in_features, out_features=1)
        self.model = net
class RARP_NVB_EfficientNetV2(RARP_NVB_Model):
    """EfficientNetV2-S backbone with its classifier replaced by a 1-logit head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        net = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        in_features = net.classifier[1].in_features
        net.classifier[1] = torch.nn.Linear(in_features=in_features, out_features=1)
        self.model = net
class RARP_NVB_Vit_b_16(RARP_NVB_Model):
    """ViT-B/16 backbone with its classification head replaced by a 1-logit head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        # FIX: accept and forward **kwargs to the base class, matching the
        # sibling backbones (MobileNetV2, EfficientNetV2); previously any
        # extra keyword argument raised TypeError.
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        tempFC_ft = self.model.heads[0].in_features
        self.model.heads[0] = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
class RARP_NVB_DenseNet169(RARP_NVB_Model):
    """DenseNet169 backbone with its classifier replaced by a 1-logit head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        # FIX: accept and forward **kwargs to the base class, matching the
        # sibling backbones (MobileNetV2, EfficientNetV2); previously any
        # extra keyword argument raised TypeError.
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.densenet169(weights=torchvision.models.DenseNet169_Weights.DEFAULT)
        inFeatures = self.model.classifier.in_features
        self.model.classifier = torch.nn.Linear(in_features=inFeatures, out_features=1)
class RARP_NVB_RestNet50_old(L.LightningModule):
    """Legacy baseline: ResNet50 binary classifier trained with weighted BCE.

    Args:
        InitialWeigth: positive-class weight for ``BCEWithLogitsLoss``
            (name keeps its historical spelling for compatibility).
    """
    def __init__(self, InitialWeigth=1) -> None:
        super().__init__()
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        tempFC_ft = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        # FIX: pos_weight must be a floating-point tensor; with the default
        # InitialWeigth=1, torch.tensor([1]) produced an int64 tensor.
        self.lossFN = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([float(InitialWeigth)]))
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')

    def forward(self, data):
        """Return raw (pre-sigmoid) logits of shape (batch, 1)."""
        return self.model(data.float())

    def training_step(self, batch, batch_idx):
        img, label = batch
        label = label.float()
        prediction = self(img)[:, 0]
        loss = self.lossFN(prediction, label)
        self.log("Train Loss", loss)
        self.log("Step Train Acc", self.train_acc(torch.sigmoid(prediction), label.int()))
        return loss

    def on_train_epoch_end(self):
        self.log("Train Acc", self.train_acc.compute())
        # FIX: reset the metric so its state does not accumulate across epochs.
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        img, label = batch
        label = label.float()
        prediction = self(img)[:, 0]
        loss = self.lossFN(prediction, label)
        self.log("Val Loss", loss)
        self.log("Step Val Acc", self.val_acc(torch.sigmoid(prediction), label.int()))
        return loss

    def on_validation_epoch_end(self):
        self.log("Val Acc", self.val_acc.compute())
        # FIX: reset the metric so its state does not accumulate across epochs.
        self.val_acc.reset()

    def configure_optimizers(self):
        # The optimizer is built once in __init__ and reused here.
        return [self.optimizer]