import math
from typing import Any, Union
import torch
import torchvision
import torchmetrics
import torchmetrics.classification
import lightning as L
from enum import Enum
from sklearn.preprocessing import LabelEncoder
import timm
import van
import numpy as np
from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt
from noah import NOAH
import piq
def js_divergence_sigmoid(p_logits, q_logits):
    """Element-wise Jensen-Shannon divergence between the Bernoulli
    distributions given by sigmoid(p_logits) and sigmoid(q_logits).

    JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m) with m = (p + q) / 2.
    Using cross-entropy H(a, b) = BCE(input=b, target=a), each KL term is
    KL(p || m) = H(p, m) - H(p, p). The previous version returned
    0.5 * (H(p, m) + H(q, m)) without subtracting the entropy terms, so it
    overestimated the divergence and was nonzero even for p == q.

    Args:
        p_logits: logits of the first distribution (any shape).
        q_logits: logits of the second distribution (broadcast-compatible).
    Returns:
        Tensor of per-element JS divergences; >= 0, 0 when p == q, and
        bounded above by log(2).
    """
    p_probs = torch.sigmoid(p_logits)
    q_probs = torch.sigmoid(q_logits)
    m = 0.5 * (p_probs + q_probs)
    # Clamp away from {0, 1} so the logs stay finite for saturated sigmoids.
    eps = 1e-7
    p_probs = p_probs.clamp(eps, 1.0 - eps)
    q_probs = q_probs.clamp(eps, 1.0 - eps)
    m = m.clamp(eps, 1.0 - eps)
    bce = torch.nn.functional.binary_cross_entropy
    h_p_m = bce(m, p_probs, reduction='none')  # H(p, m)
    h_q_m = bce(m, q_probs, reduction='none')  # H(q, m)
    h_p = bce(p_probs, p_probs, reduction='none')  # entropy H(p)
    h_q = bce(q_probs, q_probs, reduction='none')  # entropy H(q)
    js_div = 0.5 * ((h_p_m - h_p) + (h_q_m - h_q))
    return js_div
def getNVL(obj:dict, key, default):
    """Null-coalescing dict lookup (NVL as in SQL): return obj[key] unless
    the key is absent or its value is explicitly None, in which case return
    `default`."""
    value = obj.get(key)
    if value is None:
        return default
    return value
class Decoder (torch.nn.Module):
    """Convolutional upsampling decoder.

    Layout: a stem 3x3 conv at the input resolution, then one stage per
    entry of `hidden_channels` — each stage doubles H and W with a
    ConvTranspose2d (kernel 4, stride 2, padding 1) and refines with a 3x3
    conv — and a final upsampling stage projecting to `output_channels`.
    Every conv is followed by BatchNorm2d + ReLU, so outputs are >= 0
    (no sigmoid/tanh head).
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=(1024, 512, 256, 64), *args, **kwargs):
        """
        Args:
            input_channels: channel count of the incoming feature map.
            output_channels: channel count of the decoded output.
            num_blocks: number of hidden stages; must equal len(hidden_channels).
            hidden_channels: per-stage output channels. The default is now a
                tuple (was a mutable list default, a shared-state pitfall).
        """
        super().__init__(*args, **kwargs)
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."
        self.Activation = torch.nn.ReLU  # torch.nn.GELU
        self.input_channels = input_channels
        blocks = []
        inCh = input_channels
        # Stem: refine the input features without changing resolution.
        blocks.append(torch.nn.Conv2d(inCh, inCh, kernel_size=3, stride=1, padding=1, bias=False))
        blocks.append(torch.nn.BatchNorm2d(inCh))
        blocks.append(self.Activation())
        for outCh in hidden_channels:
            # Upsample x2, then refine at the new resolution.
            blocks.append(torch.nn.ConvTranspose2d(inCh, outCh, kernel_size=4, stride=2, padding=1, bias=False))
            blocks.append(torch.nn.BatchNorm2d(outCh))
            blocks.append(self.Activation())
            blocks.append(torch.nn.Conv2d(outCh, outCh, kernel_size=3, stride=1, padding=1, bias=False))
            blocks.append(torch.nn.BatchNorm2d(outCh))
            blocks.append(self.Activation())
            inCh = outCh
        # Output stage: final x2 upsample + refinement to output_channels.
        blocks.append(torch.nn.ConvTranspose2d(inCh, output_channels, kernel_size=4, stride=2, padding=1, bias=False))
        blocks.append(torch.nn.BatchNorm2d(output_channels))
        blocks.append(self.Activation())
        blocks.append(torch.nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=1, padding=1, bias=False))
        blocks.append(torch.nn.BatchNorm2d(output_channels))
        blocks.append(self.Activation())
        self.decoder = torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Decode feature map x; spatial size grows by 2**(num_blocks + 1)."""
        return self.decoder(x)
class DynamicDecoder(torch.nn.Module):
    """Stack of ConvTranspose2d upsampling blocks, each doubling H and W,
    ending in a linear (no activation) transposed-conv projection to
    `output_channels`.
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=(1024, 512, 256, 64), drop_out: Union[float, None] = None):
        """
        Args:
            input_channels: channels of the incoming feature map.
            output_channels: channels of the output image/map.
            num_blocks: number of hidden stages; must equal len(hidden_channels).
            hidden_channels: per-stage output channels. Default is now a
                tuple (was a mutable list default, a shared-state pitfall).
            drop_out: optional Dropout probability applied after each hidden
                stage; None disables dropout (annotation fixed to Optional).
        """
        super().__init__()
        # Ensure the number of hidden channels matches the number of blocks
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."
        layers = []
        in_channels = input_channels
        for out_channels in hidden_channels:
            # Each block doubles the spatial resolution.
            layers.append(torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(torch.nn.BatchNorm2d(out_channels))
            layers.append(torch.nn.ReLU(inplace=True))
            if drop_out is not None:
                layers.append(torch.nn.Dropout(drop_out))
            in_channels = out_channels
        # Final x2 upsample to the output channels; intentionally no
        # activation (append e.g. Sigmoid externally for [0, 1] pixels).
        layers.append(torch.nn.ConvTranspose2d(in_channels, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
        self.decoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Decode feature map x -> upsampled tensor."""
        return self.decoder(x)
class DecoderUnet(torch.nn.Module):
    """U-Net style decoder head.

    Expects a 5-tuple of encoder tensors (four skip levels, shallowest
    first, plus the bottleneck). Each decoder stage upsamples the running
    feature map, upsamples the matching skip connection, concatenates the
    two along the channel axis and refines the result with a double-conv
    block; the first three stages also apply 2D dropout.
    """

    def __init__(self, input_channels=2048, output_channels=3):
        super().__init__()
        self.dropout = torch.nn.Dropout2d(0.4)
        # Upsampling convs for the running decoder path.
        self.upConv_0 = torch.nn.ConvTranspose2d(input_channels, 512, kernel_size=2, stride=2)
        self.decoder_0 = self._conv_block(1024, 512)
        self.upConv_1 = torch.nn.ConvTranspose2d(512, 320, kernel_size=2, stride=2)
        self.decoder_1 = self._conv_block(640, 320)
        self.upConv_2 = torch.nn.ConvTranspose2d(320, 128, kernel_size=2, stride=2)
        self.decoder_2 = self._conv_block(256, 128)
        self.upConv_3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder_3 = self._conv_block(128, 64)
        self.upConv_4 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder_4 = self._conv_block(32, 16)
        # Channel-preserving upsamplers that bring each skip connection to
        # the decoder's resolution before concatenation.
        self.enc3_upsample = torch.nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)
        self.enc2_upsample = torch.nn.ConvTranspose2d(320, 320, kernel_size=2, stride=2)
        self.enc1_upsample = torch.nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.enc0_upsample = torch.nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.last_conv = torch.nn.Conv2d(16, output_channels, kernel_size=1)

    def _conv_block(self, cin, cout):
        """Double 3x3 conv with SiLU activations: cin -> cout -> cout."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(cin, cout, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(cout, cout, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        )

    def forward(self, x):
        enc0, enc1, enc2, enc3, bottleneck = x
        # (decoder upsampler, skip upsampler, refiner, skip tensor, dropout?)
        stages = (
            (self.upConv_0, self.enc3_upsample, self.decoder_0, enc3, True),
            (self.upConv_1, self.enc2_upsample, self.decoder_1, enc2, True),
            (self.upConv_2, self.enc1_upsample, self.decoder_2, enc1, True),
            (self.upConv_3, self.enc0_upsample, self.decoder_3, enc0, False),
        )
        out = bottleneck
        for up, skip_up, refine, skip, use_dropout in stages:
            out = torch.cat((up(out), skip_up(skip)), dim=1)
            out = refine(out)
            if use_dropout:
                out = self.dropout(out)
        out = self.decoder_4(self.upConv_4(out))
        return self.last_conv(out)
class ModelsList(Enum):
    """Identifiers for the backbone architectures used in this project.

    NOTE(review): "RestNet50" and "RestNet50_Droput" are misspellings of
    ResNet50/Dropout; the member names are kept unchanged because configs
    or checkpoints elsewhere may reference them by name.
    """
    RestNet50 = 1
    DenseNet169 = 2
    Efficientnet_b0 = 3
    MobileNetV2 = 4
    Inception3 = 5
    ResNeXt_50_32x4d = 6
    RestNet50_Droput = 7
class TypeLossFunction(Enum):
    """Selector for the classification loss used by the RARP_NVB_* models."""
    CrossEntropy = 0
    BCEWithLogits = 1
    HingeLoss = 2
    FocalLoss = 3
    ContrastiveLoss = 4
class FeatureAlignmentLoss(torch.nn.Module):
    """Cosine-alignment loss between student and teacher feature vectors.

    Both inputs are L2-normalized along the last dimension; the loss is the
    mean of (1 - cosine similarity): 0 for perfectly aligned features, up
    to 2 for opposite ones.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, F_student, F_teacher):
        student_unit = torch.nn.functional.normalize(F_student, p=2, dim=-1)
        teacher_unit = torch.nn.functional.normalize(F_teacher, p=2, dim=-1)
        cosine = (student_unit * teacher_unit).sum(dim=-1)
        return (1 - cosine).mean()
class CosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.00001,
        eta_min: float = 0.00001,
        last_epoch: int = -1,
    ):
        """
        Linear-warmup + cosine-annealing learning-rate schedule.

        The LR ramps linearly from `warmup_start_lr` to each group's base LR
        over `warmup_epochs`, then follows a cosine curve down toward
        `eta_min` at `max_epochs`. Port of pytorch-lightning-bolts'
        LinearWarmupCosineAnnealingLR:
        https://pytorch-lightning-bolts.readthedocs.io/en/stable/schedulers/warmup_cosine_annealing.html

        Args:
            optimizer (torch.optim.Optimizer): wrapped optimizer instance.
            warmup_epochs (int): number of epochs of linear warmup.
            max_epochs (int): epoch count used as the end of the cosine curve.
            warmup_start_lr (float): LR at epoch 0 of the linear warmup.
            eta_min (float): lower bound of the cosine curve.
            last_epoch (int): phase offset of the cosine curve (-1 = fresh start).
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch, verbose=False)
        return None
    def get_lr(self):
        """Compute the LR for every param group at the current epoch.

        Piecewise: fixed start value at epoch 0, linear ramp during warmup,
        exactly the base LRs at the warmup boundary, then a cosine decay
        expressed as a closed-form recurrence on the previous LR.
        """
        if self.last_epoch == 0:
            # Warmup origin: every group starts at warmup_start_lr.
            return [self.warmup_start_lr] * len(self.base_lrs)
        if self.last_epoch < self.warmup_epochs:
            # Linear warmup: add a constant per-epoch increment.
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        if self.last_epoch == self.warmup_epochs:
            # Warmup finished: hand over exactly the base learning rates.
            return self.base_lrs
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            # Restart point of the periodic cosine recurrence.
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # General case: advance the cosine curve one epoch via the ratio of
        # successive cosine factors applied to the previous LR.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)))
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    label == 0 marks a similar pair (the loss pulls the embeddings
    together); label == 1 marks a dissimilar pair (the loss pushes them at
    least `margin` apart).
    """

    def __init__(self, margin=1.0, distance: int = 0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        # Distance-type selector; only Euclidean distance is used below.
        self.TypeDistance = distance

    def forward(self, output1, output2, label):
        # Euclidean distance between the paired embeddings.
        dist = torch.nn.functional.pairwise_distance(output1, output2)
        similar_term = (1 - label) * dist.pow(2)
        dissimilar_term = label * torch.clamp(self.margin - dist, min=0.0).pow(2)
        return torch.mean(similar_term + dissimilar_term)
class GradCAM:
    """Grad-CAM heat-map generator for a model that outputs logits.

    Registers a forward hook (activations) and a full backward hook
    (gradients) on `target_layer`; generate_cam() backpropagates a
    BCE-with-logits loss and weights each activation channel by its mean
    gradient.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None   # grad w.r.t. target_layer's output
        self.activations = None  # target_layer's forward output
        self.hook()

    def hook(self):
        """Attach the activation/gradient capture hooks to target_layer."""
        def forward_hook(module, Input, Output):
            self.activations = Output
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
        self.target_layer.register_forward_hook(forward_hook)
        # Fix: register_backward_hook is deprecated (and unreliable on
        # modules composed of several ops); use the full backward hook.
        self.target_layer.register_full_backward_hook(backward_hook)

    def generate_cam(self, data: torch.Tensor, target_class):
        """Return an (H, W) numpy CAM normalized to [0, 1].

        Args:
            data: input batch; requires_grad is enabled in-place.
            target_class: float target tensor matching the flattened logits.
        """
        self.model.zero_grad()
        data.requires_grad = True
        output = self.model(data).flatten()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(output, target_class)
        loss.backward()
        # Channel importance = gradient averaged over batch and space.
        pooled_Grad = torch.mean(self.gradients, dim=[0, 2, 3])
        # Fix: scale a detached copy instead of mutating the live activation
        # tensor in-place (it is still part of the autograd graph).
        activations = self.activations.detach().clone()
        for i in range(pooled_Grad.size(0)):
            activations[:, i, :, :] *= pooled_Grad[i]
        cam = torch.mean(activations, dim=1).squeeze()
        cam = np.maximum(cam.numpy(), 0)
        # Fix: epsilon guards against division by zero for a constant map.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam
class UNet_RN18(torch.nn.Module):
    """U-Net style segmentation network with a pretrained ResNet-18 encoder.

    Encoder activations are captured by forward hooks on conv1 and
    layer1-layer4; the decoder upsamples with transposed convolutions and
    concatenates the matching skip connection at each level.

    NOTE(review): `feature_maps` is a per-instance list reset in forward(),
    so concurrent forward calls on the same instance would interleave —
    confirm single-threaded use. `in_channels` is currently unused: the
    ResNet stem is fixed at 3 input channels.
    """
    def _conv_block(self, in_ch, out_ch):
        # Double 3x3 conv with SiLU activations: in_ch -> out_ch -> out_ch.
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.SiLU(inplace=True),
        )
    def _hook_fn(self, module, input, output):
        # Record each hooked encoder layer's output during the forward pass.
        self.feature_maps.append(output)
    def _register_encoder_hooks(self):
        # Install hooks on conv1 + layer1..layer4 (list_blocks order).
        for layer in self.list_blocks:
            self.hooks.append(layer.register_forward_hook(self._hook_fn))
    def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):
        """
        Args:
            in_channels: nominal input channels (unused — see class note).
            out_channels: channels of the predicted mask/logit map.
        """
        super().__init__(*args, **kwargs)
        self.hooks = []
        self.feature_maps = []
        # ImageNet-pretrained encoder; classification head replaced by
        # Identity so the backbone can run end-to-end for the hooks.
        self.encoder_base = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        self.encoder_base.fc = torch.nn.Identity()
        #for parms in self.encoder_base.parameters():
        #    parms.requires_grad = False
        # Encoder layers whose outputs become the skips / bottleneck.
        self.list_blocks = [
            self.encoder_base.conv1,
            self.encoder_base.layer1,
            self.encoder_base.layer2,
            self.encoder_base.layer3,
            self.encoder_base.layer4
        ]
        self._register_encoder_hooks()
        self.dropout = torch.nn.Dropout2d(0.4)
        # Decoder path: each transposed conv doubles H/W.
        self.upConv_0 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upConv_3 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.upConv_extra = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        # Refiners take (upsampled decoder + skip) concatenations.
        self.decoder_0 = self._conv_block(512, 256) #0
        self.decoder_1 = self._conv_block(256, 128) #1
        self.decoder_2 = self._conv_block(128, 64) #2
        self.decoder_3 = self._conv_block(32 + 64, 32) #3
        self.decoder_extra = self._conv_block(16, 16)
        self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1)
    def forward(self, x):
        """Run the encoder (collecting skips via hooks), then decode."""
        self.feature_maps = []
        _ = self.encoder_base(x) # forward to encoder and call hooks
        # Hook firing order matches list_blocks: conv1, layer1..3, layer4.
        encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck = self.feature_maps
        decoder_l3 = self.upConv_0(btlneck)
        decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1)
        decoder_l3 = self.decoder_0(decoder_l3)
        decoder_l3 = self.dropout(decoder_l3)
        decoder_l2 = self.upConv_1(decoder_l3)
        decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1)
        decoder_l2 = self.decoder_1(decoder_l2)
        decoder_l2 = self.dropout(decoder_l2)
        decoder_l1 = self.upConv_2(decoder_l2)
        decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1)
        decoder_l1 = self.decoder_2(decoder_l1)
        decoder_l1 = self.dropout(decoder_l1)
        decoder_l0 = self.upConv_3(decoder_l1)
        decoder_l0 = torch.cat((decoder_l0, encoder_l0), dim=1)
        decoder_l0 = self.decoder_3(decoder_l0)
        decoder_last = self.upConv_extra(decoder_l0)
        decoder_last = self.decoder_extra(decoder_last)
        return self.last_conv(decoder_last)
class RARP_NVB_ROI_Mask_Unet(L.LightningModule):
    """Lightning wrapper around UNet_RN18 for binary ROI mask segmentation.

    Trains with BCE-with-logits on the predicted mask, tracks IoU for train
    and validation, optionally adds an L1 penalty on decoder-side weights,
    and alternates freezing/unfreezing of the ResNet encoder every epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = UNet_RN18(in_channels=3, out_channels=1)
        self.lr = 1E-4
        # Optional L1 penalty weight for decoder/upConv params (None = off).
        self.Lambda_L1 = None
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        self.train_IoU = torchmetrics.classification.BinaryJaccardIndex()
        self.val_IoU = torchmetrics.classification.BinaryJaccardIndex()

    def forward(self, data):
        data = data.float()
        pred = self.model(data)
        return pred

    def _shared_step(self, batch, val_step: bool = True):
        """Return (loss, target mask, sigmoid probabilities) for one batch.

        The L1 penalty is only added during training (val_step=False) and
        only over decoder-side parameters.
        """
        img, mask = batch
        mask = mask.float().unsqueeze(1)  # (N, H, W) -> (N, 1, H, W)
        prediction = self(img)
        loss = self.lossFN(prediction, mask)
        predicted_labels = torch.sigmoid(prediction)
        if not val_step and self.Lambda_L1 is not None:
            loss_l1 = 0
            for name, params in self.model.named_parameters():
                if "decoder" in name or "upConv" in name:
                    loss_l1 += torch.norm(params, p=1)
            loss += self.Lambda_L1 * loss_l1
        return loss, mask, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch, False)
        self.train_IoU.update(predicted_labels, true_labels)
        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc_IoU", self.train_IoU, on_epoch=True, on_step=False)
        return loss

    def on_after_backward(self):
        """Log the global L2 gradient norm and a vanishing-gradient flag.

        Bug fix: the previous version passed a *string* to self.log when a
        vanishing gradient was suspected; Lightning's log only accepts
        numeric/tensor values, so that call would raise. A numeric 0/1
        indicator is logged instead.
        """
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        self.log("grad_norm", total_norm)
        # 1.0 when a vanishing gradient is suspected, else 0.0.
        self.log("grad_warning", float(total_norm < 1e-8))

    def on_train_epoch_start(self):
        # Alternate epochs: even = encoder trainable, odd = frozen.
        for parms in self.model.encoder_base.parameters():
            parms.requires_grad = (self.current_epoch % 2 == 0)

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.val_IoU.update(predicted_labels, true_labels)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_acc_IoU", self.val_IoU, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        return [optimizer]
class RARP_NVB_ResNet50_CAM(L.LightningModule):
    """ResNet-50 binary classifier that also returns the last conv feature
    map (for CAM visualisation)."""

    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.resnet50()
        in_features = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=in_features, out_features=1)
        # Backbone without avgpool/fc: yields the spatial feature map.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        feature_map = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=feature_map, output_size=(1, 1))
        pooled = torch.flatten(pooled, 1)
        logits = self.model.fc(pooled)
        return logits, feature_map
class RARP_NVB_VAN_CAM(L.LightningModule): #TODO
    """VAN-B2 classifier intended to combine Grad-CAM with the forward pass.

    NOTE(review): this class looks unfinished (marked #TODO): forward()
    builds a new GradCAM (registering fresh hooks) on every call and then
    feeds generate_cam's output — a numpy array — back into the model,
    which expects an image batch. Verify the intended data flow before use.
    """
    def __init__(self) -> None:
        super().__init__()
        self.model = van.van_b2(pretrained = True)
        tempFC_ft = self.model.head.in_features
        self.model.head = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        # Backbone without the last two children, kept for CAM-style pooling.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])
    def forward(self, data, label:torch.Tensor):
        # A new GradCAM (and its hooks) is created on every forward call.
        cams = GradCAM(self.model, self.model.block4[-1])
        featureMap = cams.generate_cam(data, label)
        pred = self.model(featureMap)
        return pred, featureMap
class RARP_NVB_ResNet18_CAM(L.LightningModule):
    """ResNet-18 binary classifier that also returns the last conv feature
    map (for CAM visualisation).

    Bug fix: pooled features are now flattened with start_dim=1 so the
    batch dimension is preserved. The previous `torch.flatten(Cont_Net)`
    collapsed the batch axis too, which breaks the Linear head for any
    batch size > 1; this also makes the class consistent with
    RARP_NVB_ResNet50_CAM.
    """
    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.resnet18()
        tempFC_ft = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        # Backbone without avgpool/fc: yields the spatial feature map.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])
    def forward(self, data):
        featureMap = self.feature_map(data)
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1))
        Cont_Net = torch.flatten(Cont_Net, 1)  # keep the batch dimension
        pred = self.model.fc(Cont_Net)
        return pred, featureMap
class RARP_NVB_MobileNetV2_CAM(L.LightningModule):
    """MobileNetV2 binary classifier that also returns the final conv
    feature map (for CAM visualisation)."""

    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.mobilenet_v2()
        head_in = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=1)
        # Convolutional trunk only; classifier is applied manually below.
        self.feature_map = self.model.features

    def forward(self, data):
        feature_map = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
        logits = self.model.classifier(torch.flatten(pooled, 1))
        return logits, feature_map
class RARP_NVB_EfficientNetV2_CAM(L.LightningModule):
    """EfficientNetV2-S (ImageNet weights) binary classifier that also
    returns the final conv feature map (for CAM visualisation)."""

    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        head_in = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=1)
        # Convolutional trunk only; classifier is applied manually below.
        self.feature_map = self.model.features

    def forward(self, data):
        feature_map = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
        logits = self.model.classifier(torch.flatten(pooled, 1))
        return logits, feature_map
class RARP_NVB_Model_BCEWithLogitsLoss(L.LightningModule):
    """Base LightningModule for binary classifiers trained with
    BCEWithLogitsLoss.

    Subclasses are expected to assign `self.model`; this base wires up the
    loss, the accuracy/F1 metrics and the train/val/test loops.
    """

    def __init__(self, x=None) -> None:
        # `x` is unused; kept for signature compatibility with callers.
        super().__init__()
        self.model = None  # assigned by subclasses
        self.lossFN = torch.nn.BCEWithLogitsLoss() #pos_weight=torch.tensor([2.73])
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        return self.model(data.float())

    def _shared_step(self, batch):
        """Return (loss, float labels, sigmoid probabilities) for one batch."""
        img, label = batch
        label = label.float()
        logits = self.forward(img).flatten()
        loss = self.lossFN(logits, label)
        return loss, label, torch.sigmoid(logits)

    def training_step(self, batch, batch_idx):
        loss, targets, probs = self._shared_step(batch)
        self.log("train_loss", loss)
        self.train_acc.update(probs, targets)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, targets, probs = self._shared_step(batch)
        self.log("val_loss", loss)
        self.val_acc(probs, targets)
        self.f1Score(probs, targets)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, targets, probs = self._shared_step(batch)
        self.test_acc(probs, targets)
        self.f1ScoreTest(probs, targets)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        adam = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        return [adam]
class RARP_NVB_FOCAL_loss(torch.nn.Module):
    """Thin nn.Module wrapper around torchvision's sigmoid focal loss.

    Args:
        alpha: balancing factor for the positive class.
        gamma: focusing exponent that down-weights easy examples.
        reduction: 'none' | 'mean' | 'sum', forwarded to torchvision.
    """
    def __init__(self, alpha: float = 0.25, gamma: float = 2, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Delegates entirely to torchvision (expects raw logits in `input`).
        return torchvision.ops.focal_loss.sigmoid_focal_loss(input, target, self.alpha, self.gamma, self.reduction)
class FocalLoss(torch.nn.Module):
    """Multi-class focal loss computed on raw logits.

    Computes -alpha * (1 - p)^gamma * log(p) on the one-hot encoded
    targets. With reduction='mean' the average is taken over every element
    of the (batch, classes) matrix (not only the target entries), matching
    the original implementation.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # One-hot targets over the class dimension of `inputs`.
        one_hot = torch.nn.functional.one_hot(targets, num_classes=inputs.size(1)).float()
        probs = torch.nn.functional.softmax(inputs, dim=1)
        log_probs = torch.nn.functional.log_softmax(inputs, dim=1)
        # Modulating factor down-weights well-classified examples.
        modulating = (1 - probs) ** self.gamma
        loss = -self.alpha * modulating * one_hot * log_probs
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss
#TODO
class RARP_NVB_MultiClassModel(L.LightningModule):
    """Multi-class classifier wrapper using focal loss plus accuracy/F1
    metrics.

    Notes:
        - `InitWeight` is stored but not otherwise used by this class.
        - `self.val_loop` gates the L1 penalty so it is only added during
          training steps (validation/test set it to True).
    """
    def __init__(self,
                InitWeight = torch.tensor([1,1]),
                schedulerLR:bool=False,
                lr:float = 1e-4,
                Model:torch.nn.Module = None,
                Num_Classes:int = 2,
                L1:float = 1.31E-04,
                L2:float = 0
                ) -> None:
        """
        Args:
            InitWeight: class-weight tensor (currently unused here).
            schedulerLR: if True, attach a cosine LR schedule.
            lr: Adam learning rate.
            Model: the classification network to train.
            Num_Classes: number of classes for the metrics.
            L1: L1 regularization weight (None disables the penalty).
            L2: weight decay passed to Adam.
        """
        super().__init__()
        self.model = Model
        self.lossFN = FocalLoss() #torch.nn.CrossEntropyLoss()
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.num_classes = Num_Classes
        self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.test_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.f1ScoreTest = torchmetrics.F1Score("multiclass", num_classes=Num_Classes)
        # True while inside validation/test; disables the L1 penalty.
        self.val_loop = False
    def forward(self, data):
        data = data.float()
        pred = self.model(data)
        return pred
    def _shared_step(self, batch):
        """Return (loss, labels, softmax probabilities) for one batch."""
        img, label = batch
        prediction = self(img)
        predicted_labels = torch.softmax(prediction, dim=1)
        loss = self.lossFN(prediction, label)
        if self.Lambda_L1 is not None and not self.val_loop:
            # L1 penalty over all model parameters (training only).
            loss_l1 = 0
            for params in self.model.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        return loss, label, predicted_labels
    def training_step(self, batch, batch_idx):
        self.val_loop = False
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        self.val_loop = True
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
    def test_step(self, batch, batch_idx):
        self.val_loop = True
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1, verbose=True, factor=0.1)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return [optimizer]
class RARP_NVB_Model(L.LightningModule):
    """Base LightningModule for the single-logit RARP-NVB classifiers.

    Subclasses assign `self.model`. Provides the loss function, binary
    accuracy/F1 metrics, optional per-batch loss re-weighting via
    `InitWeight`, an optional L1 penalty and an optional cosine LR schedule.

    NOTE(review): `_shared_step` takes `prediction[:, 0]` (a single logit)
    and applies sigmoid, which matches the BCEWithLogits branch; the
    default CrossEntropy branch paired with float labels of that shape
    looks inconsistent — confirm which loss is actually used in training.
    """
    def __init__(self,
                InitWeight = torch.tensor([1,1]),
                typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy,
                schedulerLR:bool=True,
                lr:float = 1e-4
                ) -> None:
        """
        Args:
            InitWeight: per-class weights indexed by the integer labels to
                re-weight the loss each batch (None disables re-weighting).
            typeLossFN: CrossEntropy (label smoothing 0.5) or BCEWithLogits.
            schedulerLR: if True, attach a cosine LR schedule.
            lr: Adam learning rate.
        """
        super().__init__()
        self.model = None
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        # Optional L1 penalty weight (None = disabled).
        self.Lambda_L1 = None #1.31E-04 #
        self.Lambda_L2 = 0
        print (f"LR= {self.lr}, L1= {self.Lambda_L1}")
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        #self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
    def forward(self, data):
        data = data.float()
        pred = self.model(data)
        return pred
    def _shared_step(self, batch):
        """Return (loss, float labels, sigmoid probabilities) for one batch."""
        img, label = batch
        if self.InitWeight is not None:
            # Re-weight the loss per sample by indexing weights with labels.
            self.lossFN.weight = self.InitWeight[label]
        label = label.float()
        prediction = self(img)[:,0] #.flatten()
        loss = self.lossFN(prediction, label)
        if self.Lambda_L1 is not None:
            # L1 penalty over every parameter; note it is also added in
            # validation/test steps (no training-only gate here).
            loss_l1 = 0
            for params in self.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        predicted_labels = torch.sigmoid(prediction)
        return loss, label, predicted_labels
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        #self.f1Score.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        # hp_metric feeds TensorBoard's hyper-parameter view.
        self.log("hp_metric", self.val_acc, on_step=False, on_epoch=True)
        #self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)
    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1, verbose=True, factor=0.1)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return optimizer
class RARP_Ensemble(L.LightningModule):
    """Stacking ensemble: each base model produces one output, and a small
    MLP (with a final Sigmoid) combines them into one probability.

    NOTE(review): `ListModels` is a plain Python list, so the base models
    are NOT registered as submodules — they are not moved by .to(device)
    and their parameters are excluded from self.parameters() (only the
    classifier head is optimized). Given the commented-out freeze loop this
    may be intentional, but confirm before relying on it.
    """
    def __init__(self, ListModels,
                InitWeight = torch.tensor([1,1]),
                typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy,
                schedulerLR:bool=False,
                lr:float = 1e-4) -> None:
        """
        Args:
            ListModels: base models; each must map an image batch to one
                output column.
            InitWeight: per-class weights indexed by the labels to
                re-weight the loss each batch (None disables it).
            typeLossFN: CrossEntropy (label smoothing 0.5) or BCE (sum).
            schedulerLR: if True, attach a cosine LR schedule.
            lr: Adam learning rate.
        """
        super().__init__()
        self.ListModels = ListModels
        #for m in self.ListModels:
        #    m.freeze()
        # One input feature per base model.
        input_p = len(self.ListModels)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_p, out_features=128),
            torch.nn.SiLU(),
            #torch.nn.Dropout(0.1), #0.5
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid()
        )
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCELoss(reduction="sum")
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
    def forward(self, data):
        data = data.float()
        # Stack each base model's output as one feature column.
        p = [m(data) for m in self.ListModels]
        p = torch.cat(p, dim=1)
        x = self.classifier(p)
        return x
    def _shared_step(self, batch):
        """Return (loss, float labels, predicted probabilities) for a batch.

        The classifier ends in Sigmoid, so predictions are already
        probabilities — no extra sigmoid here.
        """
        img, label = batch
        if self.InitWeight is not None:
            # Re-weight the loss per sample by indexing weights with labels.
            self.lossFN.weight = self.InitWeight[label]
        label = label.float()
        prediction = self(img)[:,0] #.flatten()
        loss = self.lossFN(prediction, label)
        predicted_labels = prediction
        return loss, label, predicted_labels
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.f1Score.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)
        return loss
    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        if self.scheduler:
            #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1, verbose=True, factor=0.1)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, 1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return [optimizer]
class RARP_NVB_ResNet18(RARP_NVB_Model):
    """ResNet-18 (ImageNet weights) with a single-logit binary head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=1)
        self.model = backbone
class RARP_NVB_DaVit(RARP_NVB_Model):
    """DaViT-Small (timm, ImageNet-1k weights) with a single-logit head."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        # num_classes=1 makes timm build a binary (single-logit) head.
        self.model = timm.create_model("davit_small.msft_in1k", pretrained=True, num_classes=1)
class RARP_NVB_EfficientNetV2_Deep(RARP_NVB_Model):
    """EfficientNetV2-S (ImageNet weights) with a deeper MLP head
    (backbone features -> 128 -> 8 -> 1)."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        head_in = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=128)
        # Extend the existing classifier Sequential with the extra layers.
        for extra_layer in (
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        ):
            self.model.classifier.append(extra_layer)
class RARP_NVB_ResNet50_Deep(RARP_NVB_Model):
    """ResNet-50 (ImageNet weights) with a dropout + MLP binary head
    (backbone features -> 128 -> 8 -> 1)."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        head_in = self.model.fc.in_features
        self.model.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=head_in, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1)
        )

    def forward(self, img):
        """Cast to float32 and run backbone + head."""
        return self.model(img.float())
class RARP_NVB_ResNet50_Deep_OPTuna(RARP_NVB_Model):
    """ResNet-50 whose MLP head layout is parameterized for Optuna searches.

    Note: the `dropuot` parameter name (sic) is kept for backward
    compatibility with existing call sites / Optuna studies.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, output_dims=[128, 8], dropuot=0.2, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        width_in = self.model.fc.in_features
        head_layers = []
        # One Linear + SiLU + Dropout block per requested hidden width.
        for width_out in output_dims:
            head_layers += [
                torch.nn.Linear(width_in, width_out),
                torch.nn.SiLU(True),
                torch.nn.Dropout(dropuot),
            ]
            width_in = width_out
        head_layers.append(torch.nn.Linear(width_in, 1))
        self.model.fc = torch.nn.Sequential(*head_layers)

    def forward(self, img):
        """Cast to float32 and run backbone + head."""
        return self.model(img.float())
class RARP_NVB_ResNet50_V2(RARP_NVB_Model):
    """ResNet-50 fused with extra tabular inputs after an 8-dim image bottleneck."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        backbone_width = self.model.fc.in_features
        # Image branch: 2048 -> 128 -> 8 features.
        self.model.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=backbone_width, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8)
        )
        # Fusion layer over image features concatenated with the extra inputs.
        self.model.fc2 = torch.nn.Linear(8 + InputNeurons, 1)

    def forward(self, data):
        """data is an (image, extra_features) pair; returns the raw logit."""
        img, extra = data
        image_feats = torch.nn.functional.silu(self.model(img.float()), True)
        fused = torch.concat((image_feats, extra), dim=1)
        return self.model.fc2(fused)
class RARP_NVB_ResNet50_V3(RARP_NVB_Model):
    """Two-branch model: ResNet-50 image embedding + linear embedding of extra inputs."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        backbone_width = self.model.fc.in_features
        # Both branches project to 128 features before fusion.
        self.model.fc = torch.nn.Linear(in_features=backbone_width, out_features=128)
        self.model.extraFC = torch.nn.Linear(InputNeurons, 128)
        self.model.fc2 = torch.nn.Linear(256, 1)

    def forward(self, data):
        """data is an (image, extra_features) pair; returns the raw logit."""
        img, extra = data
        img_embed = torch.nn.functional.silu(self.model(img.float()), True)
        extra_embed = torch.nn.functional.silu(self.model.extraFC(extra), True)
        return self.model.fc2(torch.concat((img_embed, extra_embed), dim=1))
class RARP_NVB_ResNet50_V1(RARP_NVB_Model):
    """ResNet-50 trunk whose pooled features are concatenated with extra
    tabular inputs before a single linear classifier.

    NOTE(review): `self.Decoder` is actually the convolutional *encoder*
    (ResNet minus avgpool/fc); the attribute name is kept so existing
    checkpoints keep loading.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons:int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # fc consumes pooled image features (2048) plus the extra inputs.
        fused_width = self.model.fc.in_features + InputNeurons
        self.model.fc = torch.nn.Linear(in_features=fused_width, out_features=1)
        # All ResNet children except the final avgpool and fc.
        self.Decoder = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        """data is an (image, extra_features) pair; returns the raw logit."""
        img, extra = data
        feature_map = self.Decoder(img.float())
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=feature_map, output_size=(1, 1))
        pooled = torch.flatten(pooled, 1)
        return self.model.fc(torch.concat((pooled, extra), dim=1))
class RARP_NVB_VAN(RARP_NVB_Model):
    """VAN-B2 backbone with its head swapped for a single-logit linear layer."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = van.van_b2(pretrained = True)
        head_width = self.model.head.in_features
        self.model.head = torch.nn.Linear(in_features=head_width, out_features=1)
class RARP_NVB_ResNet50(RARP_NVB_Model):
    """Plain ResNet-50 with the fc layer replaced by a single-logit linear layer."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        head_width = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=head_width, out_features=1)
class RARP_NVB_MLP(torch.nn.Module):
    """DINO-style projection head: bottleneck MLP -> L2 normalization -> weight-normalized linear.

    Args:
        in_dim: input feature width from the backbone.
        out_dim: number of output "prototypes".
        hidden_dim: width of the hidden layers.
        bottleneck: width of the layer whose output is L2-normalized.
        n_layers: total linear layers in the MLP part (1 = single projection).
        norm_last_layer: if True, freeze the weight-norm magnitude of the
            final layer so it only learns direction (DINO trick).
    """

    def __init__(self, in_dim, out_dim, hidden_dim=512, bottleneck=256, n_layers=3, norm_last_layer=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.activationFN = torch.nn.GELU()
        if n_layers == 1:
            # Degenerate case: one projection straight to the bottleneck.
            self.mlp = torch.nn.Linear(in_dim, bottleneck)
        else:
            layers = [torch.nn.Linear(in_dim, hidden_dim)]
            layers.append(self.activationFN)
            layers.append(torch.nn.Dropout(0.30))
            for _ in range(n_layers - 2):
                layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
                layers.append(self.activationFN)
            layers.append(torch.nn.Linear(hidden_dim, bottleneck))
            self.mlp = torch.nn.Sequential(*layers)
        # NOTE(review): _init_weights runs before last_layer exists, so the final
        # layer keeps torch's default init (its weight_g is then forced to 1) —
        # confirm this ordering is intentional.
        self.apply(self._init_weights)
        self.last_layer = torch.nn.utils.weight_norm(
            torch.nn.Linear(bottleneck, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            # Freeze the magnitude; only the direction of the weights trains.
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        # Normal(0, 0.02) init for linear weights; zero biases.
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project x, L2-normalize along the last dim, apply the prototype layer."""
        x = self.mlp(x)
        x = torch.nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
class RARP_NVB_DINO_Wrapper(torch.nn.Module):
    """Backbone + projection head that accepts either a tensor or a list of crops."""

    def __init__(self, backbone:torch.nn.Module, new_head:torch.nn.Module, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.backbone = backbone
        self.head = new_head

    def forward(self, x):
        """Run every crop through backbone+head in one batched pass.

        Returns a tuple with one output tensor per crop (a 1-tuple when x is
        a plain tensor).
        """
        if isinstance(x, list):
            batched, crop_count = torch.cat(x, dim=0), len(x)
        else:
            batched, crop_count = x, 1
        projected = self.head(self.backbone(batched))
        return projected.chunk(crop_count)
class RARP_NVB_DINO_Loss(torch.nn.Module):
    """DINO cross-entropy between teacher and student crop distributions.

    Teacher logits are centered (EMA-updated buffer) and sharpened with a low
    temperature; matching-view pairs are skipped whenever more than one
    teacher view is present.
    """

    def __init__(self, out_dim:int, teacher_Thao:float = 0.04, student_Thao:float = 0.1, center_momentum:float = 0.9, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.S_Thao = student_Thao
        self.T_Thao = teacher_Thao
        self.C_Momentum = center_momentum
        # Running estimate of the teacher's mean logits.
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, s_Output, t_Output, validation_step:bool = False):
        """Average cross-entropy over all (teacher view, student view) pairs."""
        student_logp = [
            torch.nn.functional.log_softmax(s / self.S_Thao, dim=-1)
            for s in s_Output
        ]
        teacher_p = [
            torch.nn.functional.softmax((t - self.center) / self.T_Thao, dim=-1).detach()
            for t in t_Output
        ]
        skip_diagonal = len(teacher_p) > 1
        total_loss = 0
        n_loss_terms = 0
        for t_ix, t in enumerate(teacher_p):
            for s_ix, s in enumerate(student_logp):
                if skip_diagonal and t_ix == s_ix:
                    continue
                total_loss += torch.sum(-t * s, dim=-1).mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        if not validation_step:
            # Track teacher statistics only while training.
            self.update_center(t_Output)
        return total_loss

    @torch.no_grad()
    def update_center(self, t_output):
        """EMA update of the center buffer from the raw teacher logits."""
        batch_mean = torch.cat(t_output).mean(dim=0, keepdim=True)
        self.center = self.center * self.C_Momentum + batch_mean * (1 - self.C_Momentum)
class RARP_NVB_DINO_RestNet50_Deep(L.LightningModule):
    """Semi-supervised trainer combining DINO self-distillation with
    pseudo-label knowledge distillation on ResNet-50.

    Three networks cooperate:
      * ``teacher_Labels``   - frozen pretrained classifier producing pseudo-labels.
      * ``teacher_Features`` - frozen DINO feature teacher, EMA-tracked copy of the student.
      * ``student``          - trainable backbone + DINO head; its projections also
        feed a small binary classification head (``clasiffier``).
    """

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
    ) -> None:
        super().__init__()
        self.lr = lr
        self.Lambda_L1 = L1  # optional L1 penalty on the student weights (None disables)
        self.Lambda_L2 = L2  # AdamW weight decay
        self.threshold = threshold  # sigmoid cutoff turning teacher scores into pseudo-labels
        self.momentum_teacher = momentum_teacher  # EMA coefficient for the feature teacher
        self.out_dim = 512   # DINO projection width
        self.in_dim = 2048   # ResNet-50 penultimate feature width
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        # Label teacher: restored from checkpoint when provided, otherwise fresh.
        self.teacher_Labels = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False) if PseudoEstimator is not None else RARP_NVB_ResNet50_Deep()
        self.student = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT) #torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        self.teacher_Features = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # Strip classification layers: both nets emit raw 2048-d features.
        self.student.fc = torch.nn.Identity()
        self.teacher_Features.fc = torch.nn.Identity()
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        # Both teachers are frozen; the feature teacher only moves via the EMA hook.
        for parms in self.teacher_Labels.model.parameters():
            parms.requires_grad = False
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, 0.04, 0.1, momentum_teacher)
        self.lossFN_KD = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        #self.lossFH_KLDiv = torch.nn.KLDivLoss(reduction="batchmean")
        # Binary head on top of the student's DINO projection.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(self.out_dim, 128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1)
        )

    def forward(self, data, val_step:bool = True):
        """Run all three networks.

        Validation: ``data`` is one tensor reused for every role. Training:
        ``data`` is a crop list - data[0] feeds the label teacher, data[1:3]
        (presumably the two global crops - confirm against the dataloader)
        feed the DINO teacher, and every crop feeds the student.

        Returns ``(logits, pseudo_labels, teacher_logits), (teacher_proj, student_proj)``.
        """
        if val_step:
            data = data.float()
            dataClassificator, dataTeacher, dataStudent = data, data, data
        else:
            data = [d.float() for d in data]
            dataClassificator, dataTeacher, dataStudent = data[0], data[1:3], data
        TeacherDino = self.teacher_Features(dataTeacher)
        TeacherLabels = self.teacher_Labels(dataClassificator)
        Student = self.student(dataStudent)
        # here all the student outputs are evaluated
        #if isinstance(dataStudent, list):
        #    #index = np.random.randint(0, len(dataStudent))
        #    temp = self.student(dataStudent)
        #    CatS_Classifier = torch.cat(temp, dim=0)
        #    meanS_Classifier = torch.zeros(self.in_dim)
        #    for dataS in temp:
        #        meanS_Classifier += dataS
        #    #S_Classifier = self.student(dataStudent[index])
        #    S_Classifier = meanS_Classifier / len(dataStudent)
        #else:
        #    S_Classifier = Student
        Cont_Net = torch.cat(Student, dim=0)
        pred = self.clasiffier(Cont_Net)
        if not val_step:
            # Repeat the teacher logits once per student crop so shapes align.
            TeacherLabels = [self.teacher_Labels(dataClassificator) for _ in range(len(dataStudent))]
            TeacherLabels = torch.cat(TeacherLabels, dim=0)
        TeacherLabelsPred = torch.sigmoid(TeacherLabels.flatten())
        PseudoLabels = (TeacherLabelsPred > self.threshold) * 1.0
        return (pred.flatten(), PseudoLabels, TeacherLabels.flatten()), (TeacherDino, Student)

    def _shared_step(self, batch, val_step:bool = False):
        """Shared train/val/test logic: KD loss against pseudo- and true labels,
        plus the DINO loss and optional L1 penalty (training only)."""
        img, label = batch
        if not val_step:
            # Repeat labels once per student crop to match the batched outputs.
            label = torch.cat([label for _ in range(len(img))], dim=0)
        label = label.float()
        KD_Prediction, DINO_Loss = self(img, val_step)
        TeacherF, StudentF = DINO_Loss
        prediction, PseudoLabels, teacherOutputs = KD_Prediction
        predicted_labels = torch.sigmoid(prediction)
        ##verstion 1
        # Weighted sum: pseudo-label term (alpha) + ground-truth term (beta).
        W_Alpha, W_Beta = (1, 0.5)#(1, 0.5)
        loss = W_Alpha * self.lossFN_KD(prediction, PseudoLabels) + W_Beta * self.lossFN_KD(prediction, label)
        #version 2
        #thao_KD = 1#5.0
        #W_Alpha, W_Beta = (0.6, 0.4)
        #softTeacher = torch.sigmoid(teacherOutputs/thao_KD)
        #softStudent = torch.sigmoid(prediction/thao_KD)
        #loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
        #loss_hl = self.lossFN_KD(prediction, label)
        #loss = W_Alpha * loss_hl + W_Beta * loss_sl
        #loss = W_Alpha * self.lossFN_KD(prediction, label) + W_Beta * (self.lossFH_KLDiv(softStudent, softTeacher) * (thao_KD ** 2))
        loss += (self.lossFN_DINO(StudentF, TeacherF) if not val_step else 0)
        if not val_step:
            self.logger.experiment.add_histogram ("Teacher", TeacherF[0])
            self.logger.experiment.add_histogram ("Student", StudentF[1])
            if self.Lambda_L1 is not None:
                loss_l1 = 0
                for params in self.student.parameters(): # L1 over the whole student
                    loss_l1 += torch.sum(torch.abs(params))
                loss += self.Lambda_L1 * loss_l1
        #return loss, PseudoLabels, predicted_labels
        return loss, label, predicted_labels

    def training_step(self, batch, batch_idx):
        """One optimization step; logs loss and accuracy."""
        loss, true_labels, predicted_labels = self._shared_step(batch, False)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """EMA-update the frozen feature teacher from the student after every batch."""
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
        self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center)

    def validation_step(self, batch, batch_idx):
        """Validation loss/accuracy (single view, no DINO term)."""
        loss, true_labels, predicted_labels = self._shared_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Test accuracy and F1."""
        _, true_labels, predicted_labels = self._shared_step(batch, True)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """AdamW over the student only (teachers are frozen / EMA-driven)."""
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        return [optimizer]
class RARP_NVB_DINO_VAN(RARP_NVB_DINO_RestNet50_Deep):
    """Variant of the DINO/KD trainer using VAN-B2 for student and feature teacher."""

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher: float = 0.9995,
        lr: float = 0.0001,
        L1: float = None,
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)
        # VAN-B2 emits 512-d features (vs 2048 for ResNet-50).
        self.in_dim = 512
        student_net = van.van_b2(pretrained = True, num_classes = -1)
        teacher_net = van.van_b2(pretrained = True, num_classes = -1)
        self.student = RARP_NVB_DINO_Wrapper(
            student_net,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_net,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        # The feature teacher only moves via the EMA hook, never via gradients.
        for frozen_param in self.teacher_Features.parameters():
            frozen_param.requires_grad = False
class RARP_NVB_DINO_ViT(RARP_NVB_DINO_RestNet50_Deep):
    """Variant of the DINO/KD trainer using ViT-B/16 for student and feature teacher."""

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher: float = 0.9995,
        lr: float = 0.0001,
        L1: float = None,
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)
        # ViT-B/16 embedding width.
        self.in_dim = 768
        student_net = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        teacher_net = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        # Drop the classification heads; both nets emit raw embeddings.
        student_net.heads = torch.nn.Identity()
        teacher_net.heads = torch.nn.Identity()
        self.student = RARP_NVB_DINO_Wrapper(
            student_net,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_net,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        # Feature teacher is frozen; it follows the student via EMA only.
        for frozen_param in self.teacher_Features.parameters():
            frozen_param.requires_grad = False
class RARP_NVB_Classification_Head(torch.nn.Module):
    """Configurable MLP classification head.

    Args:
        in_features: input feature width.
        out_features: number of output logits.
        layer: hidden layer widths; empty means a single linear layer.
        activation_fn: activation module inserted after each hidden layer
            (a fresh ReLU per instance when omitted).
    """

    # Fix: the defaults were a mutable list ([]) and a single torch.nn.ReLU()
    # instance shared by every head created with defaults; both replaced with
    # safe, backward-compatible defaults.
    def __init__(self, in_features:int, out_features:int, layer=(), activation_fn:torch.nn.Module = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.activation = activation_fn if activation_fn is not None else torch.nn.ReLU()
        if len(layer) == 0:
            self.head = torch.nn.Linear(in_features, out_features)
        else:
            temp_head = []
            next_input = in_features
            for num in layer:
                temp_head.append(torch.nn.Linear(next_input, num))
                temp_head.append(self.activation)
                temp_head.append(torch.nn.Dropout(0.4))
                next_input = num
            # Lighter dropout right before the output projection.
            temp_head[-1] = torch.nn.Dropout(0.2)
            temp_head.append(torch.nn.Linear(next_input, out_features))
            self.head = torch.nn.Sequential(*temp_head)

    def forward(self, x):
        """Apply the head to x."""
        return self.head(x)
class RARP_Encoder_DINO(L.LightningModule):
    """Self-supervised DINO pre-training of a VAN-B1 encoder.

    Student and teacher share the same architecture; the teacher is frozen
    and follows the student through an EMA update after every training batch.
    """

    def __init__(self,
                 momentum_teacher:float = 0.9995,
                 lr:float = 1e-4,
                 Teacher_T:float = 0.04,
                 Student_T:float = 0.1,
                 max_epochs:int = 100,
                 total_steps:int = None
                 ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.momentum_teacher = momentum_teacher
        self.out_dim = 65536  # DINO prototype count
        self.in_dim = 512     # VAN-B1 feature width
        self.student = van.van_b1(num_classes = 0)
        self.teacher = van.van_b1(num_classes = 0)
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True)
        )
        self.teacher = RARP_NVB_DINO_Wrapper(
            self.teacher,
            RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True)
        )
        # Teacher starts as an exact copy of the student and is never trained directly.
        self.teacher.load_state_dict(self.student.state_dict())
        for parms in self.teacher.parameters():
            parms.requires_grad = False
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)

    def forward(self, data, val_step=False):
        """Validation: one tensor for both nets. Training: ``data`` is the crop
        list; the teacher sees data[0:3] (presumably the global crops - confirm
        against the dataloader), the student sees every crop."""
        if val_step:
            dataTeacher, dataStudent = data.float(), data.float()
        else:
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[0:3], data
        teacher_features = self.teacher(dataTeacher)
        student_features = self.student(dataStudent)
        return teacher_features, student_features

    def _shared_step(self, batch, val_step=False):
        """Compute the DINO loss; labels in the batch are ignored."""
        img, _ = batch
        t, s = self(img, val_step)
        loss_Dino = self.lossFN_DINO(s, t, validation_step=val_step)
        return loss_Dino

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch, False)
        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """EMA-update the teacher from the student after every batch.

        A cosine momentum schedule was previously sketched here; the constant
        ``self.momentum_teacher`` is used instead.
        """
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)

    def on_after_backward(self):
        """Log the global gradient L2 norm and flag suspected vanishing gradients."""
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        self.log("grad_norm", total_norm)
        if total_norm < 1e-8:
            # Fix: Lightning's self.log only accepts numeric/tensor values -
            # logging a string raised at runtime. Emit a numeric flag instead.
            self.log("grad_warning", 1.0)

    def configure_optimizers(self):
        """AdamW over the student only (the teacher is EMA-driven).

        A CosineAnnealingLR schedule over ``max_epochs`` was previously
        sketched here and can be reinstated via a lr_scheduler dict.
        """
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr)
        return [optimizer]
class RARP_NVB_DINO_MultiTask(L.LightningModule):
    """Multi-task DINO trainer: self-distillation + image reconstruction +
    binary classification, with SoftAdapt-balanced loss weights.

    Student/teacher are ConvNeXt-Small backbones wrapped with DINO projection
    heads. Forward hooks capture the last conv feature maps of both networks;
    their channel-wise concatenation feeds a decoder (reconstruction) and,
    pooled, a small binary classifier.
    """

    # Define a hook function to capture the output
    def _hook_fn_Student(self, module, input, output):
        # Keeps the student's last conv feature map for decoder/classifier use.
        self.last_conv_output_S = output

    def _hook_fn_Teacher(self, module, input, output):
        # Keeps the teacher's last conv feature map.
        self.last_conv_output_T = output

    def __init__(
        self,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
        std: float = None,
        mean: float = None,
        SoftAdptAlgo:int = 0,
        SoftAdptBeta:float = 0.1,
        Teacher_T:float = 0.04,
        Student_T:float = 0.1,
        intermittent:bool = False
    ) -> None:
        super().__init__()
        self.intermittent_train = intermittent  # alternate backbone vs heads training per epoch
        # NOTE(review): despite the float annotations, std/mean must be
        # 3-element per-channel sequences (view(3, 1, 1)) - confirm callers.
        self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
        self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
        self.lr = lr
        self.Lambda_L1 = L1  # optional L1 penalty on the classifier (None disables)
        self.Lambda_L2 = L2  # optional explicit L2 penalty on the classifier
        self.Teacher_t = Teacher_T
        self.Studet_T = Student_T
        self.momentum_teacher = momentum_teacher  # EMA coefficient for the feature teacher
        self.out_dim = 1024  # DINO projection width
        self.in_dim = 512
        # Loss weights [w_DINO, w_reconstruction, w_binary], refreshed by SoftAdapt.
        self.weights = torch.tensor([1,1,1])
        self.softAdapt = NormalizedSoftAdapt(SoftAdptBeta) if SoftAdptAlgo == 1 else LossWeightedSoftAdapt(SoftAdptBeta)
        # Per-component loss history accumulated between SoftAdapt refreshes.
        self.loss_history = {
            'loss_DINO': [],
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        #self.student = van.van_b2(pretrained = True, num_classes = 0)
        #self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0)
        self.student = torchvision.models.convnext_small(weights=torchvision.models.ConvNeXt_Small_Weights.DEFAULT)
        self.student.classifier[-1] = torch.nn.Identity()
        self.teacher_Features = torchvision.models.convnext_small(weights=torchvision.models.ConvNeXt_Small_Weights.DEFAULT)
        self.teacher_Features.classifier[-1] = torch.nn.Identity()
        # NOTE(review): DynamicDecoder is defined elsewhere in the project.
        self.decoder = DynamicDecoder(input_channels=1024)
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        # Teacher starts as an exact copy of the student and only moves via EMA.
        self.teacher_Features.load_state_dict(self.student.state_dict())
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.ReconstructionLoss = torch.nn.MSELoss()
        self.last_conv_output_T = None
        self.last_conv_output_S = None
        # NOTE(review): torchvision's ConvNeXt does not define a `block4`
        # attribute - confirm the backbones are patched elsewhere (the VAN
        # variants do expose block4).
        self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher)
        self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
        # Binary head over the pooled concatenated student/teacher feature maps.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 128),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.4),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(8, 1)
        )
        print(f"lr={self.lr}, L1={self.Lambda_L1}")

    def _denormalize(self, tensor:torch.Tensor):
        """Undo per-channel normalization for image logging."""
        # Move mean and std to the same device as the input tensor
        mean = self.mean_IMG.to(tensor.device)
        std = self.std_IMG.to(tensor.device)
        return tensor * std + mean

    def _calc_L1(self, params):
        """Scaled L1 norm over the given parameters."""
        l1 = 0
        for p in params:
            l1 += torch.sum(torch.abs(p))
        return self.Lambda_L1 * l1

    def _calc_weights(self, log_weights:bool = True):
        """Refresh the three loss weights from the accumulated history via
        SoftAdapt, then reset the history window.

        Each series is trimmed to an odd length ([:-1] when even) -
        presumably a SoftAdapt input requirement; confirm with the library.
        """
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]),
            torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )
        if log_weights:
            self.log("W_loss_img", self.weights[1], on_epoch=True, on_step=False)
            self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[2], on_epoch=True, on_step=False)
        self.loss_history = {
            'loss_DINO': [],
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }

    def forward(self, data, val_step:bool = True):
        """Validation: one tensor for both nets. Training: ``data`` is the crop
        list; the teacher sees data[1:3] (presumably the global crops - confirm
        against the dataloader), the student every crop.

        Returns (classifier logits, (student_proj, teacher_proj), reconstructed image).
        """
        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[1:3], data
        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)
        if not val_step:
            # Keep only the student feature maps of the two global views so
            # they align with the teacher's hook output.
            NumChunks = len(dataStudent)
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
        reconstructed_image = self.decoder(cat_features)
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)
        return pred, (Student, TeacherDino), reconstructed_image

    def _shared_step(self, batch, val_step:bool = False):
        """Shared logic: weighted DINO + reconstruction + binary losses.

        Returns (total loss, labels, sigmoid probs, (weighted parts, reconstruction)).
        """
        img, label = batch
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        # Single-logit heads produce (N, 1); flatten for the binary loss.
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        elif isinstance(self.clasiffier, (NOAH, RARP_NVB_Classification_Head)):
            prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)
        # During training, repeat the full-view image/labels once per teacher view.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        #DINO Loss
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        #Clasificator
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()
        if not val_step:
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            # Feed the SoftAdapt history (training only).
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)

    def on_train_epoch_start(self):
        """Every other epoch: refresh SoftAdapt weights and (optionally)
        alternate which sub-networks are trainable."""
        if self.current_epoch % 2 == 0 and self.current_epoch != 0:
            self._calc_weights()
        if self.intermittent_train and self.current_epoch != 0:
            # Even epochs train the backbone; odd epochs train decoder + classifier.
            par_epoch = (self.current_epoch % 2 == 0)
            for parms in self.student.backbone.parameters():
                parms.requires_grad = par_epoch
            for parms in self.decoder.parameters():
                parms.requires_grad = not par_epoch
            for parms in self.clasiffier.parameters():
                parms.requires_grad = not par_epoch

    def training_step(self, batch, batch_idx):
        """One optimization step; logs total and per-component losses, and
        periodically a grid of reconstructed images."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, False)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("train_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)
        if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            # Denormalize and swap BGR->RGB for TensorBoard.
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """EMA-update the frozen feature teacher from the student."""
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
        #self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center)

    def on_after_backward(self):
        """Log the global gradient L2 norm and flag suspected vanishing gradients."""
        total_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        self.log("grad_norm", total_norm)
        if total_norm < 1e-8:
            # NOTE(review): Lightning's self.log expects a numeric value;
            # passing a string raises when this branch triggers - confirm and
            # replace with a numeric flag.
            self.log("grad_warning", "Vanishing gradient suspected!")

    def validation_step(self, batch, batch_idx):
        """Validation losses/accuracy (single view, no DINO term)."""
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("val_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)

    def test_step(self, batch, batch_idx):
        """Test metrics plus a side-by-side original/reconstruction image grid."""
        _, true_labels, predicted_labels, losses = self._shared_step(batch, True)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
        if self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            imgOrig = torch.clip(self._denormalize(batch[0])/255, 0, 1)
            imgOrig = imgOrig[:, [2, 1, 0], :, :]
            imgReconstruction = torch.cat((imgOrig, imgReconstruction), dim=0)
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)

    def configure_optimizers(self):
        """AdamW over every trainable parameter (frozen teacher excluded by requires_grad)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) #, weight_decay=self.Lambda_L2
        return [optimizer]
class RARP_Hybrid_TS_LR(torch.nn.Module):
    """Frozen multi-task model applied separately to the left and right image halves.

    Each half is either cropped (optionally stretched back to 224x224) or
    zero-masked in place, passed through the pretrained base model, and the
    two raw predictions are concatenated along the last dimension.
    """

    def __init__(
        self,
        base_TS_Model:str = "",
        std: float = None,
        mean: float = None,
        stretch: bool = False,
        masked: bool = False
    ):
        super().__init__()
        self.mean = mean
        self.std = std
        self.stretch = stretch
        self.masked = masked
        self.mid_length = 0  # half-width, refreshed on every forward pass
        self.labels = ["L", "R"]
        # Frozen, pretrained teacher-student multi-task model.
        self.baseModel = RARP_NVB_DINO_MultiTask.load_from_checkpoint(base_TS_Model)
        self.baseModel.eval()

    def _mask_LR(self, image:torch.Tensor, Left:bool= True):
        """Keep one half of the image and zero-fill the other half."""
        kept_half = image[:, :, :, :self.mid_length] if Left else image[:, :, :, self.mid_length:]
        zero_pad = torch.zeros_like(kept_half)  # Avg. color
        ordered = [kept_half, zero_pad] if Left else [zero_pad, kept_half]
        return torch.cat(ordered, dim=-1)

    def _crop_LR(self, image:torch.Tensor, Left:bool = True):
        """Crop one half; optionally stretch it back to 224x224 with bicubic."""
        half = image[:, :, :, :self.mid_length] if Left else image[:, :, :, self.mid_length:]
        if not self.stretch:
            return half
        return torch.nn.functional.interpolate(
            half,
            size=(224, 224),
            mode='bicubic',
            align_corners=False
        )

    def forward(self, x):
        """Return the concatenated (left, right) raw predictions of the base model."""
        _, _, _, w = x.shape  # [B, C, H, W]
        self.mid_length = w // 2
        split = self._mask_LR if self.masked else self._crop_LR
        LR_Img = {
            "L": split(x, True),
            "R": split(x, False)
        }
        pred = []
        for label in self.labels:
            with torch.no_grad():
                raw_pred, _, _ = self.baseModel(LR_Img[label])
            pred.append(raw_pred)
        return torch.cat(pred, dim=-1)
#Ablation Models
"""T-S multi-task model without the reconstruction branch (V3R1_A1).
Returns:
    LightningModule
"""
class RARP_NVB_DINO_MultiTask_A1(RARP_NVB_DINO_MultiTask):
    """Ablation A1: multi-task trainer without the reconstruction branch.

    The decoder becomes Identity, image statistics are dropped so no
    reconstruction images are denormalized/logged, and SoftAdapt balances
    only the DINO and binary-classification losses.
    """

    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
        self.decoder = torch.nn.Identity()
        del self.ReconstructionLoss
        # Disable the parent's image denormalization/logging paths.
        self.mean_IMG = None
        self.std_IMG = None
        self.weights = torch.tensor([1,1])  # [w_DINO, w_binary]
        self.loss_history = {
            'loss_DINO': [],
            'loss_Binary': [],
        }

    def _calc_weights(self, log_weights:bool = True):
        """Refresh the two SoftAdapt weights from the accumulated history,
        then reset the history window (series trimmed to odd length)."""
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )
        if log_weights:
            self.log("W_loss_img", 0, on_epoch=True, on_step=False)
            self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False)
        self.loss_history = {
            'loss_DINO': [],
            'loss_Binary': [],
        }

    def _shared_step(self, batch, val_step:bool = False):
        """Weighted DINO + binary loss (no reconstruction term).

        Returns (loss, labels, sigmoid probs, (w*DINO, w*binary, 0, reconstruction)).
        """
        img, label = batch
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        if isinstance(self.clasiffier, torch.nn.Sequential):
            if self.clasiffier[-1].out_features == 1:
                prediction = prediction.flatten()
        # Consistency fix: the parent also flattens RARP_NVB_Classification_Head
        # outputs; this override previously only handled NOAH.
        elif isinstance(self.clasiffier, (NOAH, RARP_NVB_Classification_Head)):
            prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)
        # During training, repeat labels once per teacher view.
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        #DINO Loss
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        #Clasificator
        loss_HL = self.lossFN(prediction, label)
        if not val_step:
            if self.Lambda_L1 is not None:
                loss_HL += self._calc_L1(self.clasiffier.parameters())
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_HL
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[1] * loss_HL, 0, new_image)
"""T-S Multi-task model With out SoftAdadapt, Fix loss wegth 0.333_ (V3R1_A2)
Returns:
LightningModule
"""
class RARP_NVB_DINO_MultiTask_A2(RARP_NVB_DINO_MultiTask):
def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
del self.softAdapt
self.weights = [1/3, 1/3, 1/3]
def _calc_weights(self):
self.weights = [1/3, 1/3, 1/3]
self.loss_history = {
'loss_DINO': [],
'loss_Reconstruction': [],
'loss_Binary': [],
}
"""S Multi-task model No Dino base encoder VAN_b2, (V3R1_A3_1)
Returns:
LightningModule
"""
class RARP_NVB_DINO_MultiTask_A3(RARP_NVB_DINO_MultiTask):
def _hook_fn_Student(self, module, input, output):
self.last_conv_output_S = output
self.last_conv_output_T = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
if not self.val_phace:
self.last_conv_output_T = self.last_conv_output_T[:16] # Fixed bach of 8
def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
self.student = RARP_NVB_DINO_Wrapper(
van.van_b2(pretrained = True, num_classes = 0),
torch.nn.Identity()
)
self.teacher_Features = RARP_NVB_DINO_Wrapper(
torch.nn.Identity(),
torch.nn.Identity()
)
self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
self.val_phace = True
self.weights = torch.tensor([1,1])
self.loss_history = {
'loss_Reconstruction': [],
'loss_Binary': [],
}
def _calc_weights(self, log_weights:bool = True):
self.weights = self.softAdapt.get_component_weights(
torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_DINO"]),
torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
verbose=False
)
if log_weights:
self.log("W_loss_img", self.weights[0], on_epoch=True, on_step=False)
self.log("W_loss_DINO", 0, on_epoch=True, on_step=False)
self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False)
self.loss_history = {
'loss_Reconstruction': [],
'loss_Binary': [],
}
def _shared_step(self, batch, val_step:bool = False):
self.val_phace = val_step
img, label = batch
prediction, features, new_image = self(img, val_step)
_, TeacherF = features
if isinstance(self.clasiffier, torch.nn.Sequential):
if self.clasiffier[-1].out_features == 1:
prediction = prediction.flatten()
elif isinstance(self.clasiffier, NOAH):
prediction = prediction.flatten()
predicted_labels = torch.sigmoid(prediction)
orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
#DINO Loss
#loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
#Clasificator
loss_HL = self.lossFN(prediction, label)
#Reconstruction
loss_img = self.ReconstructionLoss(new_image, orignalImg)
loss_img = loss_img.float()
if not val_step:
if self.Lambda_L1 is not None:
loss_HL += self._calc_L1(self.clasiffier.parameters())
if self.Lambda_L2 > 0:
l2_reg = 0.0
for param in self.clasiffier.parameters():
l2_reg += torch.norm(param, 2) ** 2
loss_HL += self.Lambda_L2 * l2_reg
self.loss_history["loss_Reconstruction"].append(loss_img.item())
self.loss_history["loss_Binary"].append(loss_HL.item())
loss = self.weights[0] * loss_img + self.weights[1] * loss_HL
return loss, label, predicted_labels, (0, self.weights[1] * loss_HL, self.weights[0] * loss_img, new_image)
"""S Multi-task model No Dino base encoder RN50, (V3R1_A3_2)
Returns:
LightningModule
"""
class RARP_NVB_DINO_MultiTask_A3_RN50(RARP_NVB_DINO_MultiTask):
def _hook_fn_Student(self, module, input, output):
self.last_conv_output_S = output
self.last_conv_output_T = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
if not self.val_phace:
self.last_conv_output_T = self.last_conv_output_T[:16] # Fixed bach of 8
self.last_conv_output_T = self.last_conv_output_T[:, :0, :, :]
def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
self.student = RARP_NVB_DINO_Wrapper(
torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT),
torch.nn.Identity()
)
self.teacher_Features = RARP_NVB_DINO_Wrapper(
torch.nn.Identity(),
torch.nn.Identity()
)
self.student.backbone.layer4.register_forward_hook(self._hook_fn_Student)
self.val_phace = True
self.weights = torch.tensor([1,1])
self.loss_history = {
'loss_Reconstruction': [],
'loss_Binary': [],
}
self.decoder = DynamicDecoder(input_channels=2048)
self.clasiffier = torch.nn.Sequential(
torch.nn.Linear(2048, 128),
torch.nn.SiLU(True),
torch.nn.Dropout(0.4),
torch.nn.Linear(128, 8),
torch.nn.SiLU(True),
torch.nn.Dropout(0.2),
torch.nn.Linear(8, 1)
)
def _calc_weights(self, log_weights:bool = True):
self.weights = self.softAdapt.get_component_weights(
torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_DINO"]),
torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
verbose=False
)
if log_weights:
self.log("W_loss_img", self.weights[0], on_epoch=True, on_step=False)
self.log("W_loss_DINO", 0, on_epoch=True, on_step=False)
self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False)
self.loss_history = {
'loss_Reconstruction': [],
'loss_Binary': [],
}
def _shared_step(self, batch, val_step:bool = False):
self.val_phace = val_step
img, label = batch
prediction, features, new_image = self(img, val_step)
_, TeacherF = features
if isinstance(self.clasiffier, torch.nn.Sequential):
if self.clasiffier[-1].out_features == 1:
prediction = prediction.flatten()
elif isinstance(self.clasiffier, NOAH):
prediction = prediction.flatten()
predicted_labels = torch.sigmoid(prediction)
orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
#DINO Loss
#loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
#Clasificator
loss_HL = self.lossFN(prediction, label)
#Reconstruction
loss_img = self.ReconstructionLoss(new_image, orignalImg)
loss_img = loss_img.float()
if not val_step:
if self.Lambda_L1 is not None:
loss_HL += self._calc_L1(self.clasiffier.parameters())
if self.Lambda_L2 > 0:
l2_reg = 0.0
for param in self.clasiffier.parameters():
l2_reg += torch.norm(param, 2) ** 2
loss_HL += self.Lambda_L2 * l2_reg
self.loss_history["loss_Reconstruction"].append(loss_img.item())
self.loss_history["loss_Binary"].append(loss_HL.item())
loss = self.weights[0] * loss_img + self.weights[1] * loss_HL
return loss, label, predicted_labels, (0, self.weights[1] * loss_HL, self.weights[0] * loss_img, new_image)
"""T-S Multi-task model, classification head layer change, (V3R1_A4)
Returns:
LightningModule
"""
class RARP_NVB_DINO_MultiTask_A4(RARP_NVB_DINO_MultiTask):
def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
#layeres = [128, 8] #L3 original
#layeres = [256, 128, 8] #L4
#layeres = [8] #L2
layeres = [] #L1
self.clasiffier = RARP_NVB_Classification_Head(1024, 1, layeres, torch.nn.SiLU(True))
#end Ablation Models
class RARP_CLIP_loss(torch.nn.Module):
    """Symmetric InfoNCE (CLIP-style) loss between two embedding batches."""

    def __init__(self, temperature):
        super().__init__()
        # Temperature scales the similarity logits before the softmax.
        self.temp = temperature

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor):
        # Pairwise similarity matrix between student and teacher embeddings.
        similarity = z_s @ z_t.t() / self.temp
        # Matching pairs sit on the diagonal, so target i is simply index i.
        targets = torch.arange(similarity.size(0), device=similarity.device)
        ce = torch.nn.functional.cross_entropy
        # Average the student->teacher and teacher->student directions.
        return 0.5 * (ce(similarity, targets) + ce(similarity.t(), targets))
class RARP_CLIP(L.LightningModule):
    """CLIP-style contrastive distillation from a frozen VAN-B2 teacher.

    The student embeds an image, projects it into the shared space, and is
    trained so its embedding aligns with the frozen teacher's embedding of
    the same image via a symmetric InfoNCE loss.
    """

    def __init__(
        self,
        student_backbone: str = "",
        teacher_backbone: str = "",
        proj_dim: int = 256,
        embeddings: int = 512,
        temperature: float = 0.07,
        lr: float = 1e-4,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Student encoder selection (only VAN-B1 is wired up so far).
        if student_backbone == "van_b1":
            self.student = van.van_b1(pretrained=False, num_classes=0)
            self.student_dim = 512
        else:
            raise Exception(f"{student_backbone} Not Implemented")
        # Teacher: VAN-B2, loaded from a checkpoint when a path is given,
        # otherwise with its default pretrained weights.
        from_checkpoint = len(teacher_backbone) > 0
        self.teacher = van.van_b2(pretrained=not from_checkpoint, num_classes=0)
        if from_checkpoint:
            self.teacher.load_state_dict(torch.load(teacher_backbone))
        self.teacher_dim = 512
        # The teacher is frozen; only the student and its projector train.
        for p in self.teacher.parameters():
            p.requires_grad = False
        # Projection head: student feature -> shared embedding space.
        self.proj_s = torch.nn.Sequential(
            torch.nn.Linear(self.student_dim, proj_dim),
            torch.nn.LayerNorm(proj_dim),
            torch.nn.GELU(),
            torch.nn.Linear(proj_dim, embeddings)
        )
        self.loss_fn = RARP_CLIP_loss(temperature)

    def forward(self, data):
        # L2-normalized student and teacher embeddings of the same batch.
        z_s = torch.nn.functional.normalize(self.proj_s(self.student(data)), dim=-1)
        z_t = torch.nn.functional.normalize(self.teacher(data), dim=-1)
        return z_s, z_t

    def _shared_step(self, batch):
        img, _ = batch
        return self.loss_fn(*self(img))

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train/clip_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Only the trainable parts: student encoder + student projector.
        trainable = [*self.student.parameters(), *self.proj_s.parameters()]
        return torch.optim.AdamW(trainable, lr=self.hparams.lr)
class DecoderBlock(torch.nn.Module):
    """Upsampling decoder stage: expand conv -> PixelShuffle(2) -> refine conv.

    Doubles the spatial resolution while mapping in_ch -> out_ch channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Expand to out_ch*4 channels so PixelShuffle(2) can trade them for a
        # 2x larger feature map with out_ch channels.
        self.conv_expand = torch.nn.Conv2d(in_ch, out_ch * 4, 3, padding=1)
        self.pixel_shuffle = torch.nn.PixelShuffle(2)
        self.bn1 = torch.nn.BatchNorm2d(out_ch)
        self.act1 = torch.nn.GELU()
        # Refinement stage at the new resolution.
        self.conv_refine = torch.nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(out_ch)
        self.act2 = torch.nn.GELU()

    def forward(self, x):
        upsampled = self.act1(self.bn1(self.pixel_shuffle(self.conv_expand(x))))
        return self.act2(self.bn2(self.conv_refine(upsampled)))
class DynamicDecoder_PixelShuffle(torch.nn.Module):
    """Configurable PixelShuffle-based image decoder.

    Each DecoderBlock doubles the spatial size; a final transposed convolution
    doubles it once more and maps to ``output_channels``. With the defaults
    the decoder performs 5 successive 2x upsamplings.
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=(1024, 512, 256, 64), drop_out: float = None):
        super().__init__()
        # BUGFIX: hidden_channels default is now a tuple to avoid the shared
        # mutable-default-argument pitfall; any sequence still works.
        # Ensure the number of hidden channels matches the number of blocks.
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."
        layers = []
        in_channels = input_channels
        # One 2x-upsampling DecoderBlock per hidden-channel entry, with
        # optional dropout in between.
        for out_channels in hidden_channels:
            layers.append(DecoderBlock(in_channels, out_channels))
            if drop_out is not None:
                layers.append(torch.nn.Dropout(drop_out))
            in_channels = out_channels
        # Final 2x upsample to the target channel count. No output activation:
        # callers clamp/denormalize downstream.
        layers.append(torch.nn.ConvTranspose2d(in_channels, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
        # Combine all layers into a Sequential module.
        self.decoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)
class RARP_NVB_DINO_MultiTask_Pretrain(RARP_NVB_DINO_MultiTask):
    """DINO multi-task model whose VAN-B2 backbones can start from a
    self-supervised pre-trained checkpoint instead of ImageNet weights."""
    def __init__(
        self,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher = 0.9995,
        lr = 0.0001,
        L1 = None,
        L2 = 0,
        std = None,
        mean = None,
        SoftAdptAlgo = 0,
        SoftAdptBeta = 0.1,
        Teacher_T = 0.04,
        Student_T = 0.1,
        intermittent = False,
        pre_train_pth:str = "",
        # NOTE(review): mutable default ({}); harmless here since HParameter
        # is unused below (only the commented-out lines read it), but prefer
        # None as the default.
        HParameter = {},
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
        #self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04)
        #self.lr = getNVL(HParameter, "lr", 1.0E-4)
        # Fresh (un-pretrained) VAN-B2 backbones for both student and teacher.
        self.student = van.van_b2(pretrained = False, num_classes = 0)
        self.teacher_Features = van.van_b2(pretrained = False, num_classes = 0)
        # Wrap each backbone with a DINO projection MLP head.
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        # Optionally load pre-trained weights into the student backbone only,
        # then copy the full student state into the teacher.
        if len(pre_train_pth) > 0:
            self.student.backbone.load_state_dict(torch.load(pre_train_pth))
        self.teacher_Features.load_state_dict(self.student.state_dict())
        # Teacher is updated by EMA elsewhere, never by gradients.
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        # Capture the last-stage feature maps of both networks on forward.
        self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher)
        self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
        self.decoder = DynamicDecoder(input_channels=1024)
class RARP_MAE(L.LightningModule):
def __init__(
self,
backbone:str,
mask_ratio:float = 0.75,
patch_size: int = 16,
#img_size: int = 224,
lr: float = 1e-4,
hiden_channels = [512, 256, 128, 64, 32],
img_mean_std = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
activation_fn = torch.nn.ReLU(True)
):
super().__init__()
self.save_hyperparameters(ignore=["activation_fn", "img_mean_std"])
if img_mean_std is not None:
mean, std = img_mean_std
self.register_buffer("mean_IMG", torch.tensor(mean).view(3, 1, 1))
self.register_buffer("std_IMG", torch.tensor(std).view(3, 1, 1))
else:
self.mean_IMG = None
self.std_IMG = None
self.loss_fn = torch.nn.L1Loss()
match backbone:
case "resnet":
model = torchvision.models.resnet18()
model.fc = torch.nn.Identity()
self.encoder = model
self.encoder_out_dim = 512
case "van":
model = van.van_b1()
model.head = torch.nn.Identity()
self.encoder = model
self.encoder_out_dim = 512
case "levit":
self.encoder = timm.create_model("levit_192.fb_dist_in1k", pretrained=False, num_classes=0)
self.encoder_out_dim = 384
hiden_channels = [384, 256, 128, 64, 32]
case "van_l1_loss":
model = van.van_b1()
model.head = torch.nn.Identity()
self.encoder = model
self.encoder_out_dim = 512
self.loss_fn = torch.nn.L1Loss()
case "van_2":
model = van.van_b2(num_classes=0)
self.encoder = model
self.encoder_out_dim = 512
in_channel = self.encoder_out_dim
layers = [
torch.nn.Linear(in_channel, hiden_channels[0]*7*7),
torch.nn.Unflatten(1, (hiden_channels[0], 7, 7)),
]
for out_channel in hiden_channels[1:]:
layers.append(torch.nn.ConvTranspose2d(in_channel, out_channel, kernel_size=3, stride=2, padding=1, output_padding=1))
layers.append(torch.nn.BatchNorm2d(out_channel))
layers.append(activation_fn)
in_channel = out_channel
#Last layer
layers.append(torch.nn.ConvTranspose2d(in_channel, 3, kernel_size=3, stride=2, padding=1, output_padding=1))
self.decoder = torch.nn.Sequential(*layers)
def _denormalize(self, tensor:torch.Tensor):
return tensor * self.std_IMG + self.mean_IMG
def _mask_patches(self, imgs):
B, C, H, W = imgs.shape
ph, pw = self.hparams.patch_size, self.hparams.patch_size
gh, gw = H // ph, W // pw
mask:torch.Tensor = (torch.rand(B, gh * gw, device=imgs.device) >= self.hparams.mask_ratio)
mask = mask.reshape(B, 1, gh, gw) # [B,1,gh,gw]
# expand mask to full image
mask = mask.repeat_interleave(ph, dim=2) # [B,1,gh*ph,gw]
mask = mask.repeat_interleave(pw, dim=3) # [B,1,gh*ph,gw*pw] == [B,1,H,W]
mask = mask.expand(-1, C, -1, -1)
imgs_masked = imgs * mask
return imgs_masked
def forward(self, data, val_step=False):
if not val_step:
data = [d.float() for d in data] #[0] original Img, [1] augmented Img
else:
data = [data.float(), data.float()]
imgs_masked = self._mask_patches(data[1])
latent = self.encoder(imgs_masked) # [B, 512]
reconstructed_image = self.decoder(latent) # [B, 3, 224, 224]
return reconstructed_image, data[0]
def _shared_step(self, batch, val_step=False):
img, _ = batch
res, true_img = self(img, val_step)
loss_MSE = self.loss_fn(res, true_img)
return loss_MSE, res, true_img
def training_step(self, batch, batch_idx):
loss, img, _ = self._shared_step(batch, False)
self.log("train_loss", loss, on_epoch=True, sync_dist=True)
if batch_idx % 500 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
imgReconstruction = torch.clip(self._denormalize(img), 0, 1)
grid = torchvision.utils.make_grid(imgReconstruction)
self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
return loss
def validation_step(self, batch, batch_idx):
loss, img, target_img = self._shared_step(batch, True)
self.log("val_loss", loss, on_epoch=True, prog_bar=True)
if self.mean_IMG is not None and self.std_IMG is not None:
img = torch.clip(self._denormalize(img), 0, 1)
target_img = torch.clip(self._denormalize(target_img), 0, 1)
ssim = piq.ssim(img, target_img, data_range=1.0)
psnr = piq.psnr(img, target_img, data_range=1.0)
self.log("val_ssim", ssim, on_epoch=True)
self.log("val_psnr", psnr, on_epoch=True)
if batch_idx % 100 == 0:
grid = torchvision.utils.make_grid(img)
self.logger.experiment.add_image('val_reconstructed_images', grid, self.global_step)
def on_after_backward(self):
total_norm = 0.0
for p in self.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
self.log("grad_norm", total_norm)
if total_norm < 1e-8:
self.log("grad_warning", "Vanishing gradient suspected!")
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
return [optimizer]
class RARP_NVB_DINO_MultiTask_LeViT(RARP_NVB_DINO_MultiTask):
    """DINO multi-task variant with LeViT-384 student/teacher backbones."""
    def _hook_fn_Student(self, module, input, output):
        # LeViT stages emit token sequences (B, N, C); reshape to a square
        # (B, C, H, W) map so the conv decoder can consume it.
        # assumes N is a perfect square -- TODO confirm for the input size used
        B, N, C = output.shape
        H = W = int (N**0.5)
        self.last_conv_output_S = output.transpose(1, 2)
        self.last_conv_output_S = self.last_conv_output_S.contiguous().view(B, C, H, W)
    def _hook_fn_Teacher(self, module, input, output):
        # Same token-to-map reshape for the teacher features.
        B, N, C = output.shape
        H = W = int (N**0.5)
        self.last_conv_output_T = output.transpose(1, 2)
        self.last_conv_output_T = self.last_conv_output_T.contiguous().view(B, C, H, W)
    def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent)
        # LeViT-384 embedding width / DINO projection dimension.
        self.in_dim = 768
        self.out_dim = 2048
        self.student = timm.create_model("levit_384.fb_dist_in1k", pretrained=True, num_classes=0)
        self.teacher_Features = timm.create_model("levit_384.fb_dist_in1k", pretrained=True, num_classes=0)
        # Decoder/classifier sized for the concatenated student+teacher maps.
        # NOTE(review): 1024 assumes each hooked stage yields 512 channels --
        # confirm the channel count of levit_384 stage -2.
        self.decoder = DynamicDecoder(input_channels = 1024)
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=3)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=3)
        )
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher)
        # Teacher starts as an exact copy of the student and is frozen
        # (updated by EMA elsewhere, never by gradients).
        self.teacher_Features.load_state_dict(self.student.state_dict())
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        # Capture the second-to-last stage outputs of both networks.
        self.teacher_Features.backbone.stages[-2].register_forward_hook(self._hook_fn_Teacher)
        self.student.backbone.stages[-2].register_forward_hook(self._hook_fn_Student)
        self.clasiffier = RARP_NVB_Classification_Head(1024, 1, [128, 8], torch.nn.SiLU(True))
class RARP_NVB_DINO_MultiTask_Unet(RARP_NVB_DINO_MultiTask):
    """U-Net-style variant: intermediate encoder maps feed skip connections
    into a DecoderUnet for the reconstruction branch."""
    def _encoder_hool_fn(self, module, input, output):
        # Collect every hooked block's output during a forward pass.
        self.feature_maps.append(output)
    def _register_encoder_hooks(self, block_list:list):
        # Attach the collector hook to each listed encoder block.
        for layer in block_list:
            self.hooks.append(layer.register_forward_hook(self._encoder_hool_fn))
    def __init__(
        self,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher:float = 0.9995,
        lr:float = 1e-4,
        L1:float = None,
        L2:float = 0,
        std: float = None,
        mean: float = None,
        SoftAdptAlgo:int = 0,
        SoftAdptBeta:float = 0.1
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)
        self.hooks = []
        self.feature_maps = []
        # Student encoder blocks whose outputs become U-Net skip connections.
        self.list_blocks = [
            self.student.backbone.block1[-1],
            self.student.backbone.block2[-1],
            self.student.backbone.block3[-1],
            self.student.backbone.block4[-2],
        ]
        # NOTE(review): forward() below indexes feature_maps[i + num_blocks],
        # which needs 2*num_blocks captured maps; with the teacher hooks below
        # commented out, only the student appends -- confirm the wrapper runs
        # the student backbone twice per training step, otherwise this raises
        # IndexError.
        #self.list_blocks_T = [
        #    self.teacher_Features.backbone.block1[-1],
        #    self.teacher_Features.backbone.block2[-1],
        #    self.teacher_Features.backbone.block3[-1],
        #    self.teacher_Features.backbone.block4[-2],
        #]
        self._register_encoder_hooks(self.list_blocks)
        #self._register_encoder_hooks(self.list_blocks_T)
        self.decoder = DecoderUnet(1024)
    def forward(self, data, val_step:bool = True):
        # Fresh capture buffer per forward pass.
        self.feature_maps = []
        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            # Training: crop list; the teacher only sees the 2 global views.
            data = [d.float() for d in data]
            dataTeacher, dataStudent = data[1:3], data
        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)
        _temp = []
        num_blocks = len(self.list_blocks)
        NumChunks = len(dataStudent)
        # Fuse each captured pair of maps (elementwise product); during
        # training only the two global-view chunks of the second capture are
        # kept so shapes align with the first capture.
        for i in range(num_blocks):
            if not val_step:
                S_GlogalViews = torch.cat(self.feature_maps[i + num_blocks].chunk(NumChunks)[1:3], dim=0)
            else:
                S_GlogalViews = self.feature_maps[i + num_blocks]
            _temp.append(self.feature_maps[i] * S_GlogalViews)
        self.feature_maps = _temp
        if not val_step:
            # Keep only the student's global-view activations for the bottleneck.
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
        # Bottleneck: channel-concat of student and teacher last-stage maps.
        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
        self.feature_maps.append(cat_features)
        reconstructed_image = self.decoder(self.feature_maps)
        # Classification head runs on the pooled bottleneck features.
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)
        return pred, (Student, TeacherDino), reconstructed_image
class RARP_NVB_DINO_MultiTask_MultiLabel(RARP_NVB_DINO_MultiTask):
    """Multi-label variant: one independent sigmoid output per label."""

    def __init__(
        self,
        TypeLoss=TypeLossFunction.BCEWithLogits,
        momentum_teacher=0.9995,
        lr=0.0001,
        L1=None,
        L2=0,
        std=None,
        mean=None,
        SoftAdptAlgo=0,
        SoftAdptBeta=0.1,
        Num_Lables=2
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta)
        # BCE-with-logits handles each label independently.
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        # Classification head: 1024 -> 128 -> 8 -> Num_Lables, with dropout.
        head = []
        for in_f, out_f, p in ((1024, 128, 0.4), (128, 8, 0.2)):
            head += [torch.nn.Linear(in_f, out_f), torch.nn.Dropout(p), torch.nn.SiLU(True)]
        head.append(torch.nn.Linear(8, Num_Lables))
        self.clasiffier = torch.nn.Sequential(*head)
        # Multilabel metrics replace the parent's binary/multiclass ones.
        self.train_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.val_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.test_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables)
        self.f1ScoreTest = torchmetrics.F1Score('multilabel', num_labels=Num_Lables)
class RARP_NVB_DINO_MultiTask_v2(RARP_NVB_DINO_MultiTask):
    """Four-class (multiclass) variant of the DINO multi-task model."""

    def __init__(
        self,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher=0.9995,
        lr=0.0001,
        L1=None,
        L2=0,
        std=None,
        mean=None
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean)
        # Multiclass metrics over 4 classes.
        self.train_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.val_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.test_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.f1ScoreTest = torchmetrics.F1Score('multiclass', num_classes=4)
        self.lossFN = torch.nn.CrossEntropyLoss()
        self.softAdapt = LossWeightedSoftAdapt(0.1) #NormalizedSoftAdapt(0.1)
        # Classification head: 1024 -> 512 -> 256 -> 128 -> 8 -> 4 logits.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 512),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),
            torch.nn.Linear(512, 256),
            #torch.nn.Dropout(0.4),
            torch.nn.SiLU(True),
            torch.nn.Linear(256, 128),
            #torch.nn.Dropout(0.4),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            #torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 4)
        )

    def _shared_step(self, batch, val_step: bool = False):
        """Forward pass plus DINO, reconstruction, and CE classification losses."""
        img, label = batch
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        # Probabilities are only for metrics/return; the loss uses raw logits.
        predicted_labels = torch.softmax(prediction, dim=1)
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        # CE targets stay integer class indices (no .float()).
        label = torch.cat([label for _ in range(len(TeacherF))], dim=0) if not val_step else label
        # DINO loss (training only).
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        # Classification loss.
        # BUGFIX: CrossEntropyLoss expects unnormalized logits and applies
        # log_softmax internally; previously the softmaxed probabilities were
        # passed in, effectively computing softmax twice.
        loss_HL = self.lossFN(prediction, label)
        # Reconstruction loss.
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()
        if not val_step:
            # Optional L1/L2 regularization on the classifier head only.
            if self.Lambda_L1 is not None:
                loss_l1 = 0
                for params in self.clasiffier.parameters():
                    loss_l1 += torch.sum(torch.abs(params))
                loss_HL += self.Lambda_L1 * loss_l1
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)
class RARP_NVB_RN50_VAN_V2 (RARP_NVB_Model):
# Define a hook function to capture the output
def _hook_fn(self, module, input, output):
self.last_conv_output = output
def __init__(
self,
PseudoEstimator:str = None,
threshold:float = 0.5,
InitWeight=None,
TypeLoss=TypeLossFunction.CrossEntropy,
PseudoLables:bool=True,
HParameter = {},
std: float = None,
mean: float = None,
**kwargs
) -> None:
super().__init__(InitWeight, TypeLoss, **kwargs)
self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
#self.RARP_RestNet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
self.RARP_RestNet50.model.fc = torch.nn.Identity()
for parms in self.RARP_RestNet50.model.parameters():
parms.requires_grad = False
self.RARP_RestNet50 = torch.nn.Sequential(*list(self.RARP_RestNet50.model.children())[:-2])
self.decoder = Decoder()
self.FeaturesLoss = FeatureAlignmentLoss()
self.ReconstructionLoss = torch.nn.MSELoss()
match (TypeError):
case TypeLossFunction.ContrastiveLoss:
self.lossFN = ContrastiveLoss()
case _:
pass
self.model = van.van_b2(pretrained = True, num_classes = 1)
self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04)
self.lr = getNVL(HParameter, "lr", 1.0E-4)
print(f"lr={self.lr}, L1={self.Lambda_L1}")
# Initialize a variable to store the output
self.last_conv_output = None
self.model.block4[-1].mlp.dwconv.register_forward_hook(self._hook_fn)
def forward(self, data):
dataTeacher, dataStudent, _ = data
dataStudent = dataStudent.float()
dataTeacher = dataTeacher.float()
feature_Teacher = self.RARP_RestNet50 (dataTeacher)
#feature_Student = self.model.forward_features(dataStudent)
pred = self.model(dataStudent)
feature_Student = self.last_conv_output
#feature_Teacher = torch.nn.functional.adaptive_avg_pool1d(feature_Teacher, feature_Student.size(-1)) #Re-size output vector to macth VAN
cat_features = (feature_Student + feature_Teacher) / 2
reconstructed_image = self.decoder(cat_features)
feature_Student = torch.nn.functional.adaptive_avg_pool2d(feature_Student, (1, 1)).flatten(1)
feature_Teacher = torch.nn.functional.adaptive_avg_pool2d(feature_Teacher, (1, 1)).flatten(1)
return pred, (feature_Student, feature_Teacher), reconstructed_image
def _shared_step(self, batch):
img, label = batch
label = label.float()
orignalImg = img[2].float()
prediction, features, new_image = self(img)
prediction = prediction.flatten()
predicted_labels = torch.sigmoid(prediction)
#cosine similarity
loss_cosine = self.FeaturesLoss(features[0], features[1])
#Clasificator
loss_hl = self.lossFN(prediction, label)
#Reconstruction
loss_img = self.ReconstructionLoss(new_image, orignalImg)
loss_img = loss_img.float()
loss = loss_cosine + loss_hl + loss_img
if self.Lambda_L1 is not None:
loss_l1 = 0
for params in self.model.parameters():
loss_l1 += torch.sum(torch.abs(params))
loss += self.Lambda_L1 * loss_l1
return loss, label, predicted_labels, (loss_cosine, loss_hl, loss_img, new_image)
def _denormalize(self, tensor:torch.Tensor):
# Move mean and std to the same device as the input tensor
mean = self.mean_IMG.to(tensor.device)
std = self.std_IMG.to(tensor.device)
return tensor * std + mean
def training_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("train_loss", loss, on_epoch=True)
self.train_acc.update(predicted_labels, true_labels)
self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
self.log("train_loss_cosSim", losses[0], on_epoch=True, on_step=False)
self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)
if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
grid = torchvision.utils.make_grid(imgReconstruction)
self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
return loss
def validation_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
self.val_acc.update(predicted_labels, true_labels)
self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
self.log("val_loss_cosSim", losses[0], on_epoch=True, on_step=False)
self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)
def test_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.test_acc.update(predicted_labels, true_labels)
self.f1ScoreTest.update(predicted_labels, true_labels)
self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
if self.mean_IMG is not None and self.std_IMG is not None:
imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
grid = torchvision.utils.make_grid(imgReconstruction)
self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
return optimizer
class RARP_NVB_ResNet50_VAN(RARP_NVB_Model):
def __init__(
self,
PseudoEstimator:str = None,
threshold:float = 0.5,
InitWeight=None,
TypeLoss=TypeLossFunction.CrossEntropy,
PseudoLables:bool=True,
HParameter = {},
**kwargs
) -> None:
super().__init__(InitWeight, TypeLoss, **kwargs)
self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
self.threshold = threshold
self.PseudoLables = PseudoLables
match (TypeError):
case TypeLossFunction.ContrastiveLoss:
self.lossFN = ContrastiveLoss()
case _:
pass
for parms in self.RARP_RestNet50.model.parameters():
parms.requires_grad = False
self.model = van.van_b2(pretrained = True, num_classes = 1)
self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04)
self.lr = getNVL(HParameter, "lr", 1.0E-4)
self.W_Apha = getNVL(HParameter, "Alpha", 0.60)
self.W_Beta = 1 - getNVL(HParameter, "Alpha", 0.60)
self.thao_KD = getNVL(HParameter, "Thao", 5)
print(f"A={self.W_Apha}; B={self.W_Beta}, T={self.thao_KD}, lr={self.lr}, L1={self.Lambda_L1}")
#self.Lambda_L1 = None
#self.scheduler = True
#self.lr = 1.74E-2
def forward(self, data):
#dataTeacher, dataStudent = data
if isinstance(data, tuple):
dataTeacher, dataStudent = data
elif isinstance(data, torch.Tensor):
dataTeacher, dataStudent = (data, data)
RN50Pred = torch.sigmoid(self.RARP_RestNet50(dataTeacher).flatten())
PseudoLabels = (RN50Pred > self.threshold) * 1.0
dataStudent = dataStudent.float()
pred = self.model(dataStudent)
return pred, PseudoLabels, RN50Pred
def _shared_step(self, batch):
img, label = batch
if self.InitWeight is not None:
self.lossFN.weight = self.InitWeight[label]
label = label.float()
prediction, PseudoLabels, predictionRN50 = self(img) #.flatten()
prediction = prediction.flatten()
predicted_labels = torch.sigmoid(prediction)
#
# L[B] => BCEWithLogits
# L[C] => ContrastiveLoss
#Loss L[1] = L[B_HL] + L[B_PL]
#Loss L[2] = L[C_HL]
#Loss L[3] = L[C_HL] + L[C_PL]
#Loss L[4] = L[C_y_hat]
#Loss L[5] = L[KLD_PL] + L[B_PL]
#Loss L[6] = L[MSE_PL] + L[B_PL]
#Loss L[7] = L[B_HL] + L[BCE_PL]
#Loss L[8] = FocalLoss*L[JSD]
#Nuveo
#thao_KD = 5.0
#w_alpha, w_beta = (0.60, 0.40)
#L[7]:
softTeacher = torch.sigmoid(predictionRN50/self.thao_KD)
softStudent = torch.sigmoid(prediction/self.thao_KD)
loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
loss_hl = self.lossFN(prediction, label)
loss = self.W_Apha * loss_hl + self.W_Beta * loss_sl #/ (self.thao_KD ** 2)
#loss = w_alpha * self.lossFN(prediction, label) + w_beta * self.lossFN(prediction, PseudoLabels) #L[1]
#loss = w_alpha * (torch.nn.KLDivLoss()(softStudent, softTeacher) * (thao_KD ** 2)) + w_beta * self.lossFN(prediction, PseudoLabels) #L[5]
#loss = w_alpha * torch.nn.functional.mse_loss(softStudent, softTeacher) + w_beta * self.lossFN(prediction, PseudoLabels) #L[6]
#loss = self.lossFN(predictionRN50, predicted_labels, label) #L[2]
#loss = self.lossFN(predictionRN50, predicted_labels, label) + self.lossFN(predictionRN50, predicted_labels, PseudoLabels) #L[3]
#y_hat = (PseudoLabels != label) * 1
#loss = self.lossFN(predictionRN50, predicted_labels, y_hat) #L[4]
if self.Lambda_L1 is not None:
loss_l1 = 0
for params in self.model.parameters():
loss_l1 += torch.sum(torch.abs(params))
loss += self.Lambda_L1 * loss_l1
return loss, (PseudoLabels if self.PseudoLables else label), predicted_labels, (loss_hl, loss_sl)
def training_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("train_loss", loss, on_epoch=True)
self.train_acc.update(predicted_labels, true_labels)
self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
self.log("train_lh", losses[0], on_step=False, on_epoch=True)
self.log("train_ld", losses[1], on_step=False, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
self.val_acc.update(predicted_labels, true_labels)
self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
self.log("val_lh", losses[0], on_step=False, on_epoch=True)
self.log("val_ld", losses[1], on_step=False, on_epoch=True)
def test_step(self, batch, batch_idx):
_, true_labels, predicted_labels, _ = self._shared_step(batch)
self.test_acc.update(predicted_labels, true_labels)
self.f1ScoreTest.update(predicted_labels, true_labels)
self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
if self.scheduler:
scheduler = CosineAnnealingLR(
optimizer,
max_epochs=150,
warmup_epochs=8,
warmup_start_lr=0.001,
eta_min=0.0001
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
#"monitor": "val_loss",
}
}
return optimizer
class RARP_NVB_SSL_RestNet50_Deep(RARP_NVB_ResNet50_VAN):
    """Semi-supervised variant: a ResNet50 student with a deep MLP head,
    distilled from a frozen ``RARP_NVB_ResNet50_Deep`` teacher instead of the
    default ResNet50 teacher/VAN student pair of the parent class."""

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        InitWeight=None,
        TypeLoss=TypeLossFunction.CrossEntropy,
        PseudoLables: bool = True,
        **kwargs
    ) -> None:
        # Build the parent without a teacher checkpoint; the deep teacher and
        # the ResNet50 student are swapped in below.
        super().__init__(None, threshold, InitWeight, TypeLoss, PseudoLables, **kwargs)
        if PseudoEstimator is not None:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False)
        else:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep()
        # Student: ImageNet-pretrained ResNet50 whose fc layer is replaced by
        # a 2048 -> 128 -> 8 -> 1 MLP with dropout and SiLU activations.
        backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        backbone.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=2048, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        )
        self.model = backbone
        # Freeze the teacher: only the student is trained.
        for frozen_param in self.RARP_RestNet50.model.parameters():
            frozen_param.requires_grad = False
class RARP_NVB_MobileNetV2(RARP_NVB_Model):
    """MobileNetV2 backbone (ImageNet weights) with a single-logit head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        net = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
        # Replace the final classifier layer with a 1-output linear head.
        head_in = net.classifier[1].in_features
        net.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=1)
        self.model = net
class RARP_NVB_EfficientNetV2(RARP_NVB_Model):
    """EfficientNetV2-S backbone (ImageNet weights) with a single-logit head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        net = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        # Replace the final classifier layer with a 1-output linear head.
        head_in = net.classifier[1].in_features
        net.classifier[1] = torch.nn.Linear(in_features=head_in, out_features=1)
        self.model = net
class RARP_NVB_Vit_b_16(RARP_NVB_Model):
    """ViT-B/16 backbone (ImageNet weights) with a single-logit head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        # CONSISTENCY FIX: accept and forward **kwargs like the sibling
        # backbones (MobileNetV2 / EfficientNetV2); backward compatible.
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        # Replace the classification head with a 1-output linear layer.
        tempFC_ft = self.model.heads[0].in_features
        self.model.heads[0] = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
class RARP_NVB_DenseNet169(RARP_NVB_Model):
    """DenseNet-169 backbone (ImageNet weights) with a single-logit head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        # CONSISTENCY FIX: accept and forward **kwargs like the sibling
        # backbones (MobileNetV2 / EfficientNetV2); backward compatible.
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.densenet169(weights=torchvision.models.DenseNet169_Weights.DEFAULT)
        # Replace the classifier with a 1-output linear layer.
        inFeatures = self.model.classifier.in_features
        self.model.classifier = torch.nn.Linear(in_features=inFeatures, out_features=1)
class RARP_NVB_RestNet50_old(L.LightningModule):
    """Legacy ResNet50 binary classifier, kept for reference.

    Differs from the newer modules in this file: the optimizer is built
    eagerly in ``__init__``, metrics are logged under human-readable names,
    and there is no shared-step helper or test loop.
    """

    def __init__(self, InitialWeigth = 1) -> None:
        # InitialWeigth: positive-class weight passed to BCEWithLogitsLoss.
        # (Parameter name is misspelled; kept for checkpoint compatibility.)
        super().__init__()
        # ImageNet-pretrained ResNet50 with the fc layer replaced by a
        # single-logit head.
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        tempFC_ft = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        # NOTE(review): optimizer created here instead of in
        # configure_optimizers — any parameters added after __init__ would be
        # missed by it.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        self.lossFN = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([InitialWeigth]))
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')

    def forward(self, data):
        """Cast the batch to float and return the raw logits of shape (N, 1)."""
        data = data.float()
        pred = self.model(data)
        return pred

    def training_step(self, batch, batch_idx):
        # Weighted BCE on the single logit; step accuracy is computed from
        # the sigmoid probability against the integer label.
        img, label = batch
        label = label.float()
        prediction = self(img)[:,0]
        loss = self.lossFN(prediction, label)
        self.log("Train Loss", loss)
        self.log("Step Train Acc", self.train_acc(torch.sigmoid(prediction), label.int()))
        return loss

    def on_train_epoch_end(self):
        # Epoch-level training accuracy. NOTE(review): the metric is not reset
        # explicitly here — confirm the framework's automatic reset applies.
        self.log("Train Acc", self.train_acc.compute())

    def validation_step(self, batch, batch_idx):
        # Mirror of training_step for the validation split.
        img, label = batch
        label = label.float()
        prediction = self(img)[:,0]
        loss = self.lossFN(prediction, label)
        self.log("Val Loss", loss)
        self.log("Step Val Acc", self.val_acc(torch.sigmoid(prediction), label.int())) #train_acc
        return loss

    def on_validation_epoch_end(self):
        # Epoch-level validation accuracy (same reset caveat as training).
        self.log("Val Acc", self.val_acc.compute())

    def configure_optimizers(self):
        # Returns the Adam optimizer built in __init__ (list form).
        return [self.optimizer]