import math
from typing import Any, Union
import torch
import torchvision
import torchmetrics
import lightning as L
from enum import Enum
import timm
import van
import numpy as np
from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt
def js_divergence_sigmoid(p_logits, q_logits, eps=1e-7):
    """Element-wise Jensen-Shannon divergence between two Bernoulli distributions
    parameterised by logits.

    JS(p || q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = (p + q) / 2.

    Fix: the previous implementation summed the cross-entropies H(p, m) and
    H(q, m) via ``binary_cross_entropy``, which omits the -H(p)/-H(q) entropy
    terms of the KL divergences — it returned a non-zero value even for
    identical distributions. This version computes the KL terms directly.

    Args:
        p_logits: logits of the first Bernoulli distribution.
        q_logits: logits of the second Bernoulli distribution (same shape).
        eps: clamp bound keeping probabilities away from 0/1 for log stability.

    Returns:
        Tensor of per-element JS divergences (same shape as the inputs),
        in nats, bounded by log(2).
    """
    p = torch.sigmoid(p_logits).clamp(eps, 1 - eps)
    q = torch.sigmoid(q_logits).clamp(eps, 1 - eps)
    m = 0.5 * (p + q)
    # KL between Bernoulli(p) and Bernoulli(m): both outcome terms.
    kl_pm = p * torch.log(p / m) + (1 - p) * torch.log((1 - p) / (1 - m))
    kl_qm = q * torch.log(q / m) + (1 - q) * torch.log((1 - q) / (1 - m))
    return 0.5 * (kl_pm + kl_qm)
def getNVL(obj: dict, key, default):
    """Return ``obj[key]`` unless the key is missing or maps to None, else ``default``.

    SQL NVL-style lookup; a stored falsy value such as 0 or "" is still returned.
    """
    value = obj.get(key)  # single lookup instead of calling .get twice
    return default if value is None else value
class Decoder(torch.nn.Module):
    """Convolutional decoder that upsamples a feature map 2x per block plus a
    final 2x stage.

    Each block is ConvTranspose2d(stride 2) -> BN -> ReLU followed by a 3x3
    refinement Conv2d -> BN -> ReLU. With ``num_blocks`` hidden blocks the
    spatial size grows by ``2 ** (num_blocks + 1)``.
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=None, *args, **kwargs):
        """
        Args:
            input_channels: channels of the incoming feature map.
            output_channels: channels of the reconstruction (e.g. 3 for RGB).
            num_blocks: number of upsampling blocks; must equal len(hidden_channels).
            hidden_channels: per-block output channels; defaults to
                [1024, 512, 256, 64] when None.
        """
        super().__init__(*args, **kwargs)
        # Fix: avoid a shared mutable default argument; None selects the
        # historical default list.
        if hidden_channels is None:
            hidden_channels = [1024, 512, 256, 64]
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."
        self.Activation = torch.nn.ReLU  # torch.nn.GELU
        self.input_channels = input_channels
        blocks = []
        inCh = input_channels
        # Stem: refine the input features at the original resolution.
        blocks.append(torch.nn.Conv2d(inCh, inCh, kernel_size=3, stride=1, padding=1, bias=False))
        blocks.append(torch.nn.BatchNorm2d(inCh))
        blocks.append(self.Activation())
        for outCh in hidden_channels:
            # 2x spatial upsampling.
            blocks.append(torch.nn.ConvTranspose2d(inCh, outCh, kernel_size=4, stride=2, padding=1, bias=False))
            blocks.append(torch.nn.BatchNorm2d(outCh))
            blocks.append(self.Activation())
            # Refinement at the new resolution.
            blocks.append(torch.nn.Conv2d(outCh, outCh, kernel_size=3, stride=1, padding=1, bias=False))
            blocks.append(torch.nn.BatchNorm2d(outCh))
            blocks.append(self.Activation())
            inCh = outCh
        # Final 2x upsampling to the requested output channel count.
        blocks.append(torch.nn.ConvTranspose2d(inCh, output_channels, kernel_size=4, stride=2, padding=1, bias=False))
        blocks.append(torch.nn.BatchNorm2d(output_channels))
        blocks.append(self.Activation())
        blocks.append(torch.nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=1, padding=1, bias=False))
        blocks.append(torch.nn.BatchNorm2d(output_channels))
        blocks.append(self.Activation())
        self.decoder = torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Run the full decoding stack."""
        return self.decoder(x)
class DynamicDecoder(torch.nn.Module):
    """Decoder of stacked ConvTranspose2d blocks ending in a Sigmoid.

    Each block doubles the spatial size; the final layer maps to
    ``output_channels`` and squashes pixel values into [0, 1].
    """

    def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=None):
        super(DynamicDecoder, self).__init__()
        # Fix: avoid a shared mutable default argument; None selects the
        # historical default list.
        if hidden_channels is None:
            hidden_channels = [1024, 512, 256, 64]
        # Ensure the number of hidden channels matches the number of blocks
        assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks."
        layers = []
        in_channels = input_channels
        # One ConvTranspose block (2x upsample) per hidden width.
        for out_channels in hidden_channels:
            layers.append(torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(torch.nn.BatchNorm2d(out_channels))
            layers.append(torch.nn.ReLU(inplace=True))
            in_channels = out_channels
        # Final layer to get the output image
        layers.append(torch.nn.ConvTranspose2d(in_channels, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
        layers.append(torch.nn.Sigmoid())  # To get pixel values between 0 and 1
        # Combine all layers into a Sequential module
        self.decoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)
class ModelsList(Enum):
    """Identifiers of the supported backbone architectures.

    NOTE(review): 'RestNet50' / 'RestNet50_Droput' look like typos for
    'ResNet50' / 'ResNet50_Dropout'; kept as-is because the member names are
    part of the public interface.
    """
    RestNet50 = 1
    DenseNet169 = 2
    Efficientnet_b0 = 3
    MobileNetV2 = 4
    Inception3 = 5
    ResNeXt_50_32x4d = 6
    RestNet50_Droput = 7
class TypeLossFunction(Enum):
    """Selector for the classification loss used by the Lightning models.

    Only CrossEntropy and BCEWithLogits are dispatched on in this file; the
    remaining members are available for callers.
    """
    CrossEntropy = 0
    BCEWithLogits = 1
    HingeLoss = 2
    FocalLoss = 3
    ContrastiveLoss = 4
class FeatureAlignmentLoss(torch.nn.Module):
    """Cosine-distance alignment loss: mean of 1 - cos_sim(student, teacher)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, F_student, F_teacher):
        # L2-normalise along the feature axis; the dot product of unit vectors
        # is then exactly the cosine similarity.
        student_unit = torch.nn.functional.normalize(F_student, p=2, dim=-1)
        teacher_unit = torch.nn.functional.normalize(F_teacher, p=2, dim=-1)
        similarity = (student_unit * teacher_unit).sum(dim=-1)
        return (1 - similarity).mean()
class CosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Cosine-annealing learning-rate schedule with a linear warmup phase."""
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.00001,
        eta_min: float = 0.00001,
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer (torch.optim.Optimizer):
                Wrapped optimizer instance.
            warmup_epochs (int):
                Number of epochs over which linear warmup is applied.
            max_epochs (int):
                Total number of training epochs; end point of the cosine curve.
            warmup_start_lr (float):
                Learning rate at epoch 0 of the linear warmup.
            eta_min (float):
                Lower bound of the cosine curve.
            last_epoch (int):
                Phase offset of the cosine curve.

        The learning rate follows a cosine curve up to max_epochs; from epoch 0
        to warmup_epochs it increases linearly (warmup).
        https://pytorch-lightning-bolts.readthedocs.io/en/stable/schedulers/warmup_cosine_annealing.html
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        # Base-class __init__ calls step() once, which invokes get_lr() below.
        super().__init__(optimizer, last_epoch, verbose=False)
        return None
    def get_lr(self):
        # Epoch 0: start every parameter group at the warmup floor.
        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        # Linear warmup: add a constant per-epoch increment toward base_lr.
        if self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Warmup just finished: land exactly on the base learning rates.
        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        # Restart point of the periodic cosine curve.
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Regular cosine step, expressed recursively from the previous epoch's lr.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)))
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over embedding pairs.

    Similar pairs (label 0) are pulled together via the squared distance;
    dissimilar pairs (label 1) are pushed at least ``margin`` apart.
    """

    def __init__(self, margin=1.0, distance: int = 0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.TypeDistance = distance

    def forward(self, output1, output2, label):
        # Euclidean distance between the two embeddings.
        dist = torch.nn.functional.pairwise_distance(output1, output2)
        pull = (1 - label) * dist.pow(2)
        push = label * torch.clamp(self.margin - dist, min=0.0).pow(2)
        return (pull + push).mean()
class GradCAM:
    """Minimal Grad-CAM: hooks a target layer to capture its activations and
    gradients, then weights activation channels by the pooled gradients of a
    BCE-with-logits loss."""
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # filled by the backward hook
        self.activations = None  # filled by the forward hook
        self.hook()
    def hook(self):
        """Register the forward/backward capture hooks on the target layer."""
        def forward_hook(module, Input, Output):
            self.activations = Output
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
        self.target_layer.register_forward_hook(forward_hook)
        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook; the legacy variant can report wrong
        # gradients for multi-input modules — confirm before upgrading torch.
        self.target_layer.register_backward_hook(backward_hook)
    def generate_cam(self, data:torch.Tensor, target_class):
        """Return a min-max normalised CAM (numpy array) for ``data`` with
        respect to ``target_class``."""
        self.model.zero_grad()
        data.requires_grad= True
        output = self.model(data).flatten()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(output, target_class)
        loss.backward()
        # Channel weights: gradients averaged over batch and spatial dims.
        pooled_Grad = torch.mean(self.gradients, dim=[0, 2, 3])
        # NOTE(review): scales the hooked activations in place; safe only
        # because no further backward pass uses this graph.
        for i in range (pooled_Grad.size(0)):
            self.activations[:, i, :, :] *= pooled_Grad[i]
        cam = torch.mean(self.activations, dim=1).squeeze()
        cam = np.maximum(cam.detach().numpy(), 0)
        # NOTE(review): divides by zero when the CAM is constant — consider eps.
        cam = (cam - cam.min()) / (cam.max() - cam.min())
        return cam
class RARP_NVB_ResNet50_CAM(L.LightningModule):
    """ResNet-50 binary classifier that also exposes the last conv feature map
    (for CAM-style visualisation)."""

    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.resnet50()
        in_features = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=in_features, out_features=1)
        # Backbone without avgpool/fc, kept as the spatial feature extractor.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        features = self.feature_map(data)
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=features, output_size=(1, 1))
        pooled = torch.flatten(pooled, 1)
        return self.model.fc(pooled), features
class RARP_NVB_VAN_CAM(L.LightningModule): #TODO
    """VAN-B2 classifier intended to combine Grad-CAM with prediction.

    NOTE(review): ``GradCAM.generate_cam`` returns a normalised numpy CAM,
    which is then fed straight back into ``self.model`` — a network expecting
    an image tensor. This forward path looks unfinished (hence the #TODO);
    confirm the intended data flow before using this class.
    """
    def __init__(self) -> None:
        super().__init__()
        # Pretrained VAN-B2 backbone with a single-logit binary head.
        self.model = van.van_b2(pretrained = True)
        tempFC_ft = self.model.head.in_features
        self.model.head = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        # Backbone without the last two children; unused by forward() below.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])
    def forward(self, data, label:torch.Tensor):
        # NOTE(review): a fresh GradCAM (and its hooks) is created on every
        # call, so hooks accumulate on the module — consider creating it once.
        cams = GradCAM(self.model, self.model.block4[-1])
        featureMap = cams.generate_cam(data, label)
        pred = self.model(featureMap)
        return pred, featureMap
class RARP_NVB_ResNet18_CAM(L.LightningModule):
    """ResNet-18 binary classifier that also returns the last conv feature map
    (for CAM-style visualisation)."""

    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.resnet18()
        tempFC_ft = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
        # Backbone without avgpool/fc, used to expose the spatial feature map.
        self.feature_map = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        featureMap = self.feature_map(data)
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(input=featureMap, output_size=(1, 1))
        # Fix: flatten from dim 1 so the batch dimension is preserved.
        # torch.flatten(x) collapsed the batch too, breaking any batch size > 1
        # and diverging from the sibling *_CAM models.
        Cont_Net = torch.flatten(Cont_Net, 1)
        pred = self.model.fc(Cont_Net)
        return pred, featureMap
class RARP_NVB_MobileNetV2_CAM(L.LightningModule):
    """MobileNetV2 binary classifier exposing the final conv feature map (for CAM)."""

    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.mobilenet_v2()
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=in_features, out_features=1)
        self.feature_map = self.model.features

    def forward(self, data):
        features = self.feature_map(data)
        pooled = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(features, (1, 1)), 1)
        return self.model.classifier(pooled), features
class RARP_NVB_EfficientNetV2_CAM(L.LightningModule):
    """EfficientNetV2-S binary classifier exposing the final conv feature map (for CAM)."""

    def __init__(self) -> None:
        super().__init__()
        self.model = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features=in_features, out_features=1)
        self.feature_map = self.model.features

    def forward(self, data):
        features = self.feature_map(data)
        pooled = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(features, (1, 1)), 1)
        return self.model.classifier(pooled), features
class RARP_NVB_Model_BCEWithLogitsLoss(L.LightningModule):
    """Binary-classification Lightning base; subclasses must assign ``self.model``.

    Trains with BCE-with-logits, tracks accuracy on every split and F1 on
    validation/test.
    """

    def __init__(self, x=None) -> None:
        super().__init__()
        self.model = None  # set by subclasses
        self.lossFN = torch.nn.BCEWithLogitsLoss()  # pos_weight=torch.tensor([2.73])
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

    def forward(self, data):
        return self.model(data.float())

    def _shared_step(self, batch):
        """Common forward/loss; returns (loss, float targets, sigmoid probabilities)."""
        images, targets = batch
        targets = targets.float()
        logits = self.forward(images).flatten()
        return self.lossFN(logits, targets), targets, torch.sigmoid(logits)

    def training_step(self, batch, batch_idx):
        loss, targets, probabilities = self._shared_step(batch)
        self.log("train_loss", loss)
        self.train_acc.update(probabilities, targets)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, targets, probabilities = self._shared_step(batch)
        self.log("val_loss", loss)
        self.val_acc(probabilities, targets)
        self.f1Score(probabilities, targets)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, targets, probabilities = self._shared_step(batch)
        self.test_acc(probabilities, targets)
        self.f1ScoreTest(probabilities, targets)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        return [torch.optim.Adam(self.model.parameters(), lr=1e-4)]
class RARP_NVB_FOCAL_loss(torch.nn.Module):
    """Thin wrapper around torchvision's sigmoid focal loss that stores the
    hyper-parameters as module state."""

    def __init__(self, alpha: float = 0.25, gamma: float = 2, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torchvision.ops.focal_loss.sigmoid_focal_loss(
            input, target, self.alpha, self.gamma, self.reduction
        )
class FocalLoss(torch.nn.Module):
    """Multi-class focal loss on raw logits.

    The (1 - p)^gamma modulator down-weights well-classified examples; with
    gamma=0 and alpha=1 this reduces to standard cross-entropy (spread over
    the one-hot grid).
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        one_hot = torch.nn.functional.one_hot(targets, num_classes=inputs.size(1)).float()
        probs = torch.nn.functional.softmax(inputs, dim=1)
        log_probs = torch.nn.functional.log_softmax(inputs, dim=1)
        modulator = (1 - probs) ** self.gamma
        loss = -self.alpha * modulator * one_hot * log_probs
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss
#TODO
class RARP_NVB_MultiClassModel(L.LightningModule):
    """Multi-class Lightning classifier around an injected backbone ``Model``.

    Trains with focal loss, optional L1 regularisation (training steps only,
    gated by the ``val_loop`` flag) and an optional cosine LR schedule.
    """
    def __init__(self,
        InitWeight = torch.tensor([1,1]),
        schedulerLR:bool=False,
        lr:float = 1e-4,
        Model:torch.nn.Module = None,
        Num_Classes:int = 2,
        L1:float = 1.31E-04,
        L2:float = 0
        ) -> None:
        super().__init__()
        self.model = Model
        self.lossFN = FocalLoss() #torch.nn.CrossEntropyLoss()
        # NOTE(review): InitWeight is stored but never used in this class.
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = L1  # L1 penalty weight (None disables the penalty)
        self.Lambda_L2 = L2  # passed to Adam as weight_decay
        self.num_classes = Num_Classes
        self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.test_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes)
        self.f1ScoreTest = torchmetrics.F1Score("multiclass", num_classes=Num_Classes)
        # Toggled by the step methods so L1 is only added during training.
        self.val_loop = False
    def forward(self, data):
        """Cast to float and run the injected backbone."""
        data = data.float()
        pred = self.model(data)
        return pred
    def _shared_step(self, batch):
        """Common forward/loss; returns (loss, labels, softmax probabilities)."""
        img, label = batch
        prediction = self(img)
        predicted_labels = torch.softmax(prediction, dim=1)
        loss = self.lossFN(prediction, label)
        # L1 penalty over backbone parameters, training only (see val_loop).
        if self.Lambda_L1 is not None and not self.val_loop:
            loss_l1 = 0
            for params in self.model.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        return loss, label, predicted_labels
    def training_step(self, batch, batch_idx):
        self.val_loop = False
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        self.val_loop = True
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
    def test_step(self, batch, batch_idx):
        self.val_loop = True
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def configure_optimizers(self):
        """Adam (+ optional cosine schedule); L2 is applied as weight decay."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1, verbose=True, factor=0.1)
            # NOTE(review): the `verbose` argument is deprecated in recent PyTorch.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return [optimizer]
class RARP_NVB_Model(L.LightningModule):
    """Binary-classification Lightning base; subclasses must assign ``self.model``.

    The head is expected to emit one logit per sample (column 0 is used).
    Logs accuracy (train/val/test) and F1 (test); supports an optional L1
    penalty and a cosine LR schedule.
    """
    def __init__(self,
        InitWeight = torch.tensor([1,1]),
        typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy,
        schedulerLR:bool=True,
        lr:float = 1e-4
        ) -> None:
        super().__init__()
        self.model = None  # assigned by subclasses
        # NOTE(review): with a single-logit head (see _shared_step taking [:,0]),
        # CrossEntropyLoss over a 1-D prediction is suspicious — BCEWithLogitsLoss
        # is the pairing that matches the sigmoid used for predicted labels.
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.Lambda_L1 = None #1.31E-04 #
        self.Lambda_L2 = 0
        print (f"LR= {self.lr}, L1= {self.Lambda_L1}")
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        #self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
    def forward(self, data):
        """Cast to float and run the wrapped backbone."""
        data = data.float()
        pred = self.model(data)
        return pred
    def _shared_step(self, batch):
        """Common forward/loss; returns (loss, float labels, sigmoid probabilities)."""
        img, label = batch
        if self.InitWeight is not None:
            # Per-sample weighting by indexing the weight table with the labels.
            # NOTE(review): reassigning .weight on a loss module each step is
            # unconventional — confirm it interacts correctly with the chosen loss.
            self.lossFN.weight = self.InitWeight[label]
        label = label.float()
        prediction = self(img)[:,0] #.flatten()
        loss = self.lossFN(prediction, label)
        if self.Lambda_L1 is not None:
            # L1 regularisation over all parameters.
            loss_l1 = 0
            for params in self.parameters():
                loss_l1 += torch.sum(torch.abs(params))
            loss += self.Lambda_L1 * loss_l1
        predicted_labels = torch.sigmoid(prediction)
        return loss, label, predicted_labels
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        #self.f1Score.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        # hp_metric feeds TensorBoard's hyper-parameter view.
        self.log("hp_metric", self.val_acc, on_step=False, on_epoch=True)
        #self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)
    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def configure_optimizers(self):
        """Adam (+ optional cosine schedule); L2 applied as weight decay."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        if self.scheduler:
            #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1, verbose=True, factor=0.1)
            # NOTE(review): the `verbose` argument is deprecated in recent PyTorch.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, eta_min=1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return optimizer
class RARP_Ensemble(L.LightningModule):
    """Stacking ensemble: each base model emits one score per image; a small
    Sigmoid-terminated MLP fuses the scores into a single probability.

    Fix: ``ListModels`` is wrapped in ``torch.nn.ModuleList`` so the base
    models are registered as submodules. A plain Python list is invisible to
    PyTorch — the models would not follow ``.to(device)``, would be missing
    from checkpoints, and their parameters would not appear in
    ``self.parameters()`` (and hence in the optimizer).
    """
    def __init__(self, ListModels,
        InitWeight = torch.tensor([1,1]),
        typeLossFN:TypeLossFunction = TypeLossFunction.CrossEntropy,
        schedulerLR:bool=False,
        lr:float = 1e-4) -> None:
        super().__init__()
        # Register the base models so device moves / checkpoints include them.
        self.ListModels = torch.nn.ModuleList(ListModels)
        #for m in self.ListModels:
        #    m.freeze()
        input_p = len(self.ListModels)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_p, out_features=128),
            torch.nn.SiLU(),
            #torch.nn.Dropout(0.1), #0.5
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid()
        )
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if typeLossFN == TypeLossFunction.CrossEntropy else torch.nn.BCELoss(reduction="sum")
        self.InitWeight = InitWeight
        self.scheduler = schedulerLR
        self.lr = lr
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1Score = torchmetrics.F1Score('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
    def forward(self, data):
        """Score ``data`` with every base model, then fuse the scores."""
        data = data.float()
        p = [m(data) for m in self.ListModels]
        p = torch.cat(p, dim=1)
        x = self.classifier(p)
        return x
    def _shared_step(self, batch):
        """Common forward/loss; the classifier output is already a probability."""
        img, label = batch
        if self.InitWeight is not None:
            # Per-sample weight lookup by label index.
            self.lossFN.weight = self.InitWeight[label]
        label = label.float()
        prediction = self(img)[:,0] #.flatten()
        loss = self.lossFN(prediction, label)
        predicted_labels = prediction
        return loss, label, predicted_labels
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.f1Score.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_f1", self.f1Score, on_epoch=True, on_step=False, prog_bar=True)
        return loss
    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def configure_optimizers(self):
        """Adam over all registered parameters (now including the base models)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        if self.scheduler:
            #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1, verbose=True, factor=0.1)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 6, 1e-8, verbose=True)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    #"monitor": "val_loss",
                }
            }
        else:
            return [optimizer]
class RARP_NVB_ResNet18(RARP_NVB_Model):
    """Binary RARP-NVB classifier on an ImageNet-pretrained ResNet-18 backbone."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        # Replace the 1000-way ImageNet head with a single-logit binary head.
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=1)
        self.model = backbone
class RARP_NVB_DaVit(RARP_NVB_Model):
    """Binary RARP-NVB classifier on a pretrained DaViT-small (timm) backbone."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        # timm builds the single-logit head directly when num_classes=1.
        self.model = timm.create_model("davit_small.msft_in1k", pretrained=True, num_classes=1)
class RARP_NVB_EfficientNetV2_Deep(RARP_NVB_Model):
    """EfficientNetV2-S with a deeper ...->128->8->1 classifier head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        head = self.model.classifier
        # Swap the stock linear layer for a 128-wide one, then extend the head.
        head[1] = torch.nn.Linear(in_features=head[1].in_features, out_features=128)
        for extra in (torch.nn.SiLU(True), torch.nn.Linear(128, 8), torch.nn.SiLU(True), torch.nn.Linear(8, 1)):
            head.append(extra)
class RARP_NVB_ResNet50_Deep(RARP_NVB_Model):
    """ResNet-50 with a deeper 2048->128->8->1 MLP head for binary classification."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        head_in = self.model.fc.in_features
        head_layers = [
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=head_in, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        ]
        self.model.fc = torch.nn.Sequential(*head_layers)

    def forward(self, img):
        """Cast to float and run the backbone + MLP head."""
        return self.model(img.float())
class RARP_NVB_ResNet50_Deep_OPTuna(RARP_NVB_Model):
    """ResNet-50 whose MLP-head widths and dropout are parameterised for Optuna
    search.

    NOTE: the misspelled ``dropuot`` parameter name is kept for backward
    compatibility with existing callers and checkpoints.
    """

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, output_dims=[128, 8], dropuot=0.2, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        width = self.model.fc.in_features
        head = []
        # One Linear -> SiLU -> Dropout stage per requested hidden width.
        for hidden in output_dims:
            head += [
                torch.nn.Linear(width, hidden),
                torch.nn.SiLU(True),
                torch.nn.Dropout(dropuot),
            ]
            width = hidden
        head.append(torch.nn.Linear(width, 1))
        self.model.fc = torch.nn.Sequential(*head)

    def forward(self, img):
        """Cast to float and run the backbone + searched head."""
        return self.model(img.float())
class RARP_NVB_ResNet50_V2(RARP_NVB_Model):
    """ResNet-50 producing 8 image features that are concatenated with extra
    tabular inputs and fused by one final linear layer."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons: int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        backbone_out = self.model.fc.in_features
        self.model.fc = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=backbone_out, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8)
        )
        # Fusion layer: 8 image features + the extra inputs -> single logit.
        self.model.fc2 = torch.nn.Linear(8 + InputNeurons, 1)

    def forward(self, data):
        img, extra = data
        image_feats = torch.nn.functional.silu(self.model(img.float()), True)
        fused = torch.concat((image_feats, extra), dim=1)
        return self.model.fc2(fused)
class RARP_NVB_ResNet50_V3(RARP_NVB_Model):
    """Two-branch model: ResNet-50 image features and an FC over the extra
    inputs, both mapped to 128-d, concatenated, and classified jointly."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons: int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        backbone_out = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(in_features=backbone_out, out_features=128)
        self.model.extraFC = torch.nn.Linear(InputNeurons, 128)
        self.model.fc2 = torch.nn.Linear(256, 1)

    def forward(self, data):
        img, extra = data
        image_branch = torch.nn.functional.silu(self.model(img.float()), True)
        extra_branch = torch.nn.functional.silu(self.model.extraFC(extra), True)
        fused = torch.concat((image_branch, extra_branch), dim=1)
        return self.model.fc2(fused)
class RARP_NVB_ResNet50_V1(RARP_NVB_Model):
    """ResNet-50 whose pooled conv features are concatenated with the extra
    inputs before the final fully-connected layer."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, InputNeurons: int = 8, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        fused_in = self.model.fc.in_features + InputNeurons
        self.model.fc = torch.nn.Linear(in_features=fused_in, out_features=1)
        # NOTE: despite the attribute name (kept for compatibility), this is
        # the convolutional *encoder*: all layers up to avgpool/fc.
        self.Decoder = torch.nn.Sequential(*list(self.model.children())[:-2])

    def forward(self, data):
        img, extra = data
        feature_map = self.Decoder(img.float())
        pooled = torch.nn.functional.adaptive_avg_pool2d(input=feature_map, output_size=(1, 1))
        pooled = torch.flatten(pooled, 1)
        fused = torch.concat((pooled, extra), dim=1)
        return self.model.fc(fused)
class RARP_NVB_VAN(RARP_NVB_Model):
    """Binary RARP-NVB classifier on a pretrained VAN-B2 backbone."""
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = van.van_b2(pretrained = True)
        tempFC_ft = self.model.head.in_features
        # Replace the classification head with a single-logit binary output.
        self.model.head = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
class RARP_NVB_ResNet50(RARP_NVB_Model):
    """Plain binary classifier: pretrained ResNet-50 with a single-logit head."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # Replace the 1000-way ImageNet head with a single-logit binary head.
        backbone.fc = torch.nn.Linear(in_features=backbone.fc.in_features, out_features=1)
        self.model = backbone
class RARP_NVB_MLP(torch.nn.Module):
    """DINO-style projection head: MLP -> L2-normalise -> weight-normed linear.

    Ordering matters in __init__: ``self.apply(self._init_weights)`` runs
    *before* ``last_layer`` is created, so the weight-normed output layer keeps
    its default initialisation (its weight_g is then fixed to 1).
    """
    def __init__(self, in_dim, out_dim, hidden_dim=512, bottleneck=256, n_layers=3, norm_last_layer=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.activationFN = torch.nn.GELU()
        if n_layers == 1:
            # Degenerate case: a single projection, no hidden layers.
            self.mlp = torch.nn.Linear(in_dim, bottleneck)
        else:
            layers = [torch.nn.Linear(in_dim, hidden_dim)]
            layers.append(torch.nn.Dropout(0.30))
            layers.append(self.activationFN)
            for _ in range(n_layers - 2):
                layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
                layers.append(self.activationFN)
            layers.append(torch.nn.Linear(hidden_dim, bottleneck))
            self.mlp = torch.nn.Sequential(*layers)
        self.apply(self._init_weights)
        # NOTE(review): torch.nn.utils.weight_norm is deprecated in newer
        # PyTorch (torch.nn.utils.parametrizations.weight_norm); kept because
        # the weight_g attribute below relies on the legacy implementation.
        self.last_layer = torch.nn.utils.weight_norm(
            torch.nn.Linear(bottleneck, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            # Freeze the norm of the output weights (DINO's norm_last_layer option).
            self.last_layer.weight_g.requires_grad = False
    def _init_weights(self, m):
        """Normal(0, 0.02) init for linear-layer weights; zero biases."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
    def forward(self, x):
        x = self.mlp(x)
        # Project onto the unit hypersphere before the prototype layer.
        x = torch.nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
class RARP_NVB_DINO_Wrapper(torch.nn.Module):
    """Backbone + projection head; a list of multi-crop inputs is batched
    through together and the logits are split back into per-crop chunks."""

    def __init__(self, backbone, new_head, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.backbone = backbone
        self.head = new_head

    def forward(self, x):
        # Concatenate the crop list into one batch so the backbone runs once.
        is_multi_crop = isinstance(x, list)
        crops = torch.cat(x, dim=0) if is_multi_crop else x
        n_crops = len(x) if is_multi_crop else 1
        logits = self.head(self.backbone(crops))
        return logits.chunk(n_crops)
class RARP_NVB_DINO_Loss(torch.nn.Module):
    """DINO loss: cross-entropy between sharpened teacher and student crop
    outputs, with an EMA-centred teacher to avoid collapse."""

    def __init__(self, out_dim:int, teacher_Thao:float = 0.04, student_Thao:float = 0.1, center_momentum:float = 0.9, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.S_Thao = student_Thao       # student temperature
        self.T_Thao = teacher_Thao       # teacher temperature (sharper)
        self.C_Momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, s_Output, t_Output):
        student_log_probs = [
            torch.nn.functional.log_softmax(s / self.S_Thao, dim=-1) for s in s_Output
        ]
        teacher_probs = [
            torch.nn.functional.softmax((t - self.center) / self.T_Thao, dim=-1).detach()
            for t in t_Output
        ]
        total, terms = 0, 0
        for t_ix, t_prob in enumerate(teacher_probs):
            for s_ix, s_logp in enumerate(student_log_probs):
                # Skip the same-view pair whenever more than one teacher view exists.
                if (t_ix == s_ix) and (len(teacher_probs) > 1):
                    continue
                total += torch.sum(-t_prob * s_logp, dim=-1).mean()
                terms += 1
        self.update_center(t_Output)
        return total / terms

    @torch.no_grad()
    def update_center(self, t_output):
        """EMA update of the teacher-output center over all crops."""
        batch_center = torch.cat(t_output).mean(dim=0, keepdim=True)
        self.center = self.center * self.C_Momentum + batch_center * (1 - self.C_Momentum)
class RARP_NVB_DINO_RestNet50_Deep(L.LightningModule):
    """Semi-supervised binary classifier combining DINO self-distillation with
    pseudo-label knowledge distillation.

    Three networks cooperate:
      * ``teacher_Labels``: a frozen ``RARP_NVB_ResNet50_Deep`` (optionally
        restored from ``PseudoEstimator``) that emits pseudo-labels.
      * ``teacher_Features``: a frozen ResNet-50 DINO teacher; its weights move
        only via the EMA update in ``on_train_batch_end``.
      * ``student``: the trainable ResNet-50 (wrapped with a DINO projection
        MLP) whose projected embeddings also feed the ``clasiffier`` head.
    """
    def __init__(
        self,
        PseudoEstimator: str = None,  # checkpoint path of the pseudo-label teacher (fresh model if None)
        threshold: float = 0.5,  # sigmoid cut-off that binarizes teacher outputs into pseudo-labels
        TypeLoss=TypeLossFunction.CrossEntropy,  # CE (label-smoothed) vs BCE-with-logits for the KD loss
        momentum_teacher: float = 0.9995,  # EMA coefficient for the DINO teacher update
        lr: float = 1e-4,
        L1: float = None,  # optional L1 penalty weight on student parameters (disabled if None)
        L2: float = 0,  # weight decay handed to AdamW
    ) -> None:
        super().__init__()
        self.lr = lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.threshold = threshold
        self.momentum_teacher = momentum_teacher
        # DINO projection dimensions: 2048-dim backbone features -> 512-dim space.
        self.out_dim = 512
        self.in_dim = 2048
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        # Frozen pseudo-label teacher (strict=False tolerates missing/extra keys).
        self.teacher_Labels = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False) if PseudoEstimator is not None else RARP_NVB_ResNet50_Deep()
        self.student = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT) #torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        self.teacher_Features = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # Drop the ImageNet classification heads; the backbones output raw features.
        self.student.fc = torch.nn.Identity()
        self.teacher_Features.fc = torch.nn.Identity()
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        # Both teachers are frozen: gradients flow only through the student.
        for parms in self.teacher_Labels.model.parameters():
            parms.requires_grad = False
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, 0.04, 0.1, momentum_teacher)
        self.lossFN_KD = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        #self.lossFH_KLDiv = torch.nn.KLDivLoss(reduction="batchmean")
        # Small MLP mapping projected embeddings to a single binary logit.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(self.out_dim, 128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1)
        )
    def forward(self, data, val_step: bool = True):
        """Run the teacher/student forward passes.

        data: a single tensor in validation/test mode; in training mode a list
        of augmented views (index 0: classifier view; 1:3: global views for the
        DINO teacher; all views go to the student).
        Returns ``((pred, PseudoLabels, teacher_logits), (TeacherDino, Student))``.
        """
        if val_step:
            data = data.float()
            dataClassificator, dataTeacher, dataStudent = data, data, data
        else:
            data = [d.float() for d in data]
            # Teacher sees only the two global views; the student sees every view.
            dataClassificator, dataTeacher, dataStudent = data[0], data[1:3], data
        TeacherDino = self.teacher_Features(dataTeacher)
        TeacherLabels = self.teacher_Labels(dataClassificator)
        Student = self.student(dataStudent)
        # here all of the student's outputs are evaluated
        #if isinstance(dataStudent, list):
        #    #index = np.random.randint(0, len(dataStudent))
        #    temp = self.student(dataStudent)
        #    CatS_Classifier = torch.cat(temp, dim=0)
        #    meanS_Classifier = torch.zeros(self.in_dim)
        #    for dataS in temp:
        #        meanS_Classifier += dataS
        #    #S_Classifier = self.student(dataStudent[index])
        #    S_Classifier = meanS_Classifier / len(dataStudent)
        #else:
        #    S_Classifier = Student
        # Concatenate the per-view student embeddings into one batch
        # (the wrapper presumably returns a list of tensors - TODO confirm).
        Cont_Net = torch.cat(Student, dim=0)
        pred = self.clasiffier(Cont_Net)
        if not val_step:
            # Re-evaluate the pseudo-label teacher once per student view so the
            # pseudo-label batch lines up with the concatenated student batch.
            TeacherLabels = [self.teacher_Labels(dataClassificator) for _ in range(len(dataStudent))]
            TeacherLabels = torch.cat(TeacherLabels, dim=0)
        TeacherLabelsPred = torch.sigmoid(TeacherLabels.flatten())
        # Hard pseudo-labels from the thresholded teacher probabilities.
        PseudoLabels = (TeacherLabelsPred > self.threshold) * 1.0
        return (pred.flatten(), PseudoLabels, TeacherLabels.flatten()), (TeacherDino, Student)
    def _shared_step(self, batch, val_step: bool = False):
        """Compute the KD + supervised loss (plus the DINO term in training)."""
        img, label = batch
        if not val_step:
            # Replicate labels so they align with the concatenated student views.
            label = torch.cat([label for _ in range(len(img))], dim=0)
        label = label.float()
        KD_Prediction, DINO_Loss = self(img, val_step)
        TeacherF, StudentF = DINO_Loss
        prediction, PseudoLabels, teacherOutputs = KD_Prediction
        predicted_labels = torch.sigmoid(prediction)
        ## version 1: fixed-weight mix of pseudo-label and ground-truth losses.
        W_Alpha, W_Beta = (1, 0.5)#(1, 0.5)
        loss = W_Alpha * self.lossFN_KD(prediction, PseudoLabels) + W_Beta * self.lossFN_KD(prediction, label)
        #version 2 (kept for reference): temperature-scaled soft distillation.
        #thao_KD = 1#5.0
        #W_Alpha, W_Beta = (0.6, 0.4)
        #softTeacher = torch.sigmoid(teacherOutputs/thao_KD)
        #softStudent = torch.sigmoid(prediction/thao_KD)
        #loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
        #loss_hl = self.lossFN_KD(prediction, label)
        #loss = W_Alpha * loss_hl + W_Beta * loss_sl
        #loss = W_Alpha * self.lossFN_KD(prediction, label) + W_Beta * (self.lossFH_KLDiv(softStudent, softTeacher) * (thao_KD ** 2))
        # The DINO self-distillation term is only added during training.
        loss += (self.lossFN_DINO(StudentF, TeacherF) if not val_step else 0)
        if not val_step:
            self.logger.experiment.add_histogram ("Teacher", TeacherF[0])
            self.logger.experiment.add_histogram ("Student", StudentF[1])
            # Optional L1 regularization over all student parameters.
            if self.Lambda_L1 is not None:
                loss_l1 = 0
                for params in self.student.parameters(): # here
                    loss_l1 += torch.sum(torch.abs(params))
                loss += self.Lambda_L1 * loss_l1
        #return loss, PseudoLabels, predicted_labels
        return loss, label, predicted_labels
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch, False)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # EMA update of the frozen DINO teacher from the student after each batch.
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
        self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center)
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels = self._shared_step(batch, True)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
    def configure_optimizers(self):
        # NOTE(review): only the student's parameters are optimized - the
        # clasiffier head is excluded here; confirm this is intentional.
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
        return [optimizer]
class RARP_NVB_DINO_VAN(RARP_NVB_DINO_RestNet50_Deep):
    """DINO student/teacher variant that swaps the ResNet-50 backbones of the
    parent class for VAN-B2 backbones (512-dim features)."""

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher: float = 0.9995,
        lr: float = 0.0001,
        L1: float = None,
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)
        # VAN-B2 emits 512-dim features (the parent assumed 2048 for ResNet-50).
        self.in_dim = 512
        # Build both backbones first, then wrap each with its projection MLP.
        student_backbone = van.van_b2(pretrained = True, num_classes = -1)
        teacher_backbone = van.van_b2(pretrained = True, num_classes = -1)
        self.student = RARP_NVB_DINO_Wrapper(
            student_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        # The teacher never receives gradients; it is moved only by EMA.
        for frozen_param in self.teacher_Features.parameters():
            frozen_param.requires_grad = False
class RARP_NVB_DINO_ViT(RARP_NVB_DINO_RestNet50_Deep):
    """DINO student/teacher variant backed by ViT-B/16 (768-dim features)."""

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher: float = 0.9995,
        lr: float = 0.0001,
        L1: float = None,
        L2: float = 0
    ) -> None:
        super().__init__(PseudoEstimator, threshold, TypeLoss, momentum_teacher, lr, L1, L2)
        # ViT-B/16 produces 768-dim tokens (the parent assumed 2048).
        self.in_dim = 768
        # Two independently constructed pretrained backbones; classification
        # heads removed so they output raw features.
        student_backbone = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        teacher_backbone = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        student_backbone.heads = torch.nn.Identity()
        teacher_backbone.heads = torch.nn.Identity()
        self.student = RARP_NVB_DINO_Wrapper(
            student_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            teacher_backbone,
            RARP_NVB_MLP(self.in_dim, self.out_dim, 1024)
        )
        # The teacher never receives gradients; it is moved only by EMA.
        for frozen_param in self.teacher_Features.parameters():
            frozen_param.requires_grad = False
class RARP_NVB_DINO_MultiTask(L.LightningModule):
    """Multi-task DINO model built on VAN-B2 student/teacher backbones.

    Jointly optimizes three losses, re-weighted online with SoftAdapt:
      1. DINO self-distillation between student and EMA teacher projections.
      2. MSE reconstruction of the input image, decoded from the
         channel-concatenated student+teacher conv feature maps.
      3. Binary classification from the pooled fused features.
    """
    # Define a hook function to capture the output
    def _hook_fn_Student(self, module, input, output):
        # Stores the most recent conv feature map of the student backbone.
        self.last_conv_output_S = output
    def _hook_fn_Teacher(self, module, input, output):
        # Stores the most recent conv feature map of the teacher backbone.
        self.last_conv_output_T = output
    def __init__(
        self,
        TypeLoss=TypeLossFunction.CrossEntropy,  # CE (label-smoothed) vs BCE-with-logits
        momentum_teacher: float = 0.9995,  # EMA coefficient for the teacher update
        lr: float = 1e-4,
        L1: float = None,  # optional L1 penalty weight on the classifier head
        L2: float = 0,  # optional explicit L2 penalty on the classifier head
        std: float = None,  # per-channel std for de-normalizing logged images (appears to expect 3 values - TODO confirm)
        mean: float = None,  # per-channel mean for de-normalizing logged images
        SoftAdptAlgo: int = 0,  # 1 -> NormalizedSoftAdapt, anything else -> LossWeightedSoftAdapt
        SoftAdptBeta: float = 0.1
    ) -> None:
        super().__init__()
        self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
        self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
        self.lr = lr
        self.Lambda_L1 = L1
        self.Lambda_L2 = L2
        self.momentum_teacher = momentum_teacher
        self.out_dim = 1024
        self.in_dim = 512
        # Loss weights (DINO, reconstruction, binary); updated by SoftAdapt.
        self.weights = torch.tensor([1,1,1])
        self.softAdapt = NormalizedSoftAdapt(SoftAdptBeta) if SoftAdptAlgo == 1 else LossWeightedSoftAdapt(SoftAdptBeta)
        # Per-step loss values accumulated between SoftAdapt re-weightings.
        self.loss_history = {
            'loss_DINO': [],
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        self.student = van.van_b2(pretrained = True, num_classes = 0)
        self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0)
        # Decoder reconstructs the image from the fused (1024-channel) features.
        self.decoder = DynamicDecoder(input_channels=1024) #Decoder(input_channels=1024, hidden_channels=[512, 256, 128, 64])
        self.student = RARP_NVB_DINO_Wrapper(
            self.student,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        self.teacher_Features = RARP_NVB_DINO_Wrapper(
            self.teacher_Features,
            RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2)
        )
        # Teacher starts as an exact copy of the student and is frozen; it is
        # moved only by the EMA update in on_train_batch_end.
        self.teacher_Features.load_state_dict(self.student.state_dict())
        for parms in self.teacher_Features.parameters():
            parms.requires_grad = False
        self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, 0.04, 0.1, momentum_teacher)
        self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss()
        self.ReconstructionLoss = torch.nn.MSELoss()
        self.last_conv_output_T = None
        self.last_conv_output_S = None
        # Hook the last block of each backbone to capture spatial feature maps.
        self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher)
        self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student)
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 128),
            torch.nn.Dropout(0.4),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1)
        )
        print(f"lr={self.lr}, L1={self.Lambda_L1}")
    def _denormalize(self, tensor: torch.Tensor):
        """Undo the normalization transform so images can be logged/displayed."""
        # Move mean and std to the same device as the input tensor
        mean = self.mean_IMG.to(tensor.device)
        std = self.std_IMG.to(tensor.device)
        return tensor * std + mean
    def _calc_weights(self, log_weights: bool = True):
        """Recompute SoftAdapt loss weights from the accumulated histories,
        optionally log them, then reset the histories."""
        # Drop the last sample when a history has even length - presumably an
        # input-size requirement of softadapt's get_component_weights; confirm.
        self.weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]),
            torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_Reconstruction"]),
            torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]),
            verbose=False
        )
        if log_weights:
            self.log("W_loss_img", self.weights[1], on_epoch=True, on_step=False)
            self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_GT", self.weights[2], on_epoch=True, on_step=False)
        self.loss_history = {
            'loss_DINO': [],
            'loss_Reconstruction': [],
            'loss_Binary': [],
        }
    def forward(self, data, val_step: bool = True):
        """Run both backbones and the reconstruction/classification heads.

        data: a single tensor in validation/test mode; in training mode a list
        of augmented views (indices 1:3 are the global views for the teacher).
        Returns (pred_logits, (Student, TeacherDino), reconstructed_image).
        """
        if val_step:
            data = data.float()
            dataTeacher, dataStudent = data, data
        else:
            data = [d.float() for d in data]
            # Teacher consumes the two global views; the student consumes all.
            dataTeacher, dataStudent = data[1:3], data
        TeacherDino = self.teacher_Features(dataTeacher)
        Student = self.student(dataStudent)
        if not val_step:
            # The student hook saw all views concatenated; keep only the two
            # global-view chunks so they align with the teacher's feature map.
            NumChunks = len(dataStudent)
            S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3]
            self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0)
        # Fuse student and teacher maps along the channel axis (512+512=1024).
        cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1)
        reconstructed_image = self.decoder(cat_features)
        Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1)
        pred = self.clasiffier(Cont_Net)
        return pred, (Student, TeacherDino), reconstructed_image
    def _shared_step(self, batch, val_step: bool = False):
        """Compute the SoftAdapt-weighted multi-task loss.

        Returns (loss, label, predicted_labels,
                 (weighted DINO, weighted GT, weighted reconstruction, image)).
        """
        img, label = batch
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)
        # In training, replicate target/image per teacher view so they line up
        # with the concatenated global-view features.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float()
        #DINO Loss (training only; zero tensor in validation/test).
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        #Classifier loss on the single binary logit.
        loss_HL = self.lossFN(prediction, label)
        #Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()
        if not val_step:
            # Optional L1/L2 regularization on the classifier head only.
            if self.Lambda_L1 is not None:
                loss_l1 = 0
                for params in self.clasiffier.parameters(): # here
                    loss_l1 += torch.sum(torch.abs(params))
                loss_HL += self.Lambda_L1 * loss_l1
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            # Record the component losses for the next SoftAdapt re-weighting.
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
        # SoftAdapt-weighted sum of the three components.
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)
    def on_train_epoch_start(self):
        # Refresh the SoftAdapt weights every other epoch (never at epoch 0).
        if self.current_epoch % 2 == 0 and self.current_epoch != 0:
            self._calc_weights()
    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, False)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("train_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)
        # Periodically log a grid of reconstructed images (requires mean/std).
        if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            # Swap channel order - presumably BGR -> RGB for display; confirm.
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
        return loss
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # EMA update of the frozen teacher from the student after every batch.
        with torch.no_grad():
            for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()):
                teacher_ps.data.mul_(self.momentum_teacher)
                teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data)
        #self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center)
    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels, losses = self._shared_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
        self.log("val_loss_DINO", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)
    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels, losses = self._shared_step(batch, True)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
        if self.mean_IMG is not None and self.std_IMG is not None:
            imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
            imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
            grid = torchvision.utils.make_grid(imgReconstruction)
            self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
    def configure_optimizers(self):
        # All parameters are handed to the optimizer; the frozen teacher is
        # effectively excluded because its requires_grad is False.
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) #, weight_decay=self.Lambda_L2
        return [optimizer]
class RARP_NVB_DINO_MultiTask_v2(RARP_NVB_DINO_MultiTask):
    """4-class variant of RARP_NVB_DINO_MultiTask with a deeper head.

    Fix: ``CrossEntropyLoss`` must receive raw logits (it applies log-softmax
    internally). The previous code passed ``softmax(prediction)``, which
    double-applies softmax and flattens the gradients; softmax probabilities
    are now used only for the accuracy/F1 metrics.
    """

    def __init__(
        self,
        TypeLoss=TypeLossFunction.CrossEntropy,
        momentum_teacher = 0.9995,
        lr = 0.0001,
        L1 = None,
        L2 = 0,
        std = None,
        mean = None
    ):
        super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean)
        # Replace the parent's binary metrics with 4-class versions.
        self.train_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.val_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.test_acc = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.f1ScoreTest = torchmetrics.F1Score('multiclass', num_classes=4)
        self.lossFN = torch.nn.CrossEntropyLoss()
        self.softAdapt = LossWeightedSoftAdapt(0.1) #NormalizedSoftAdapt(0.1)
        # Deeper 4-way classification head on the 1024-dim fused features.
        self.clasiffier = torch.nn.Sequential(
            torch.nn.Linear(1024, 512),
            torch.nn.Dropout(0.2),
            torch.nn.SiLU(True),
            torch.nn.Linear(512, 256),
            torch.nn.SiLU(True),
            torch.nn.Linear(256, 128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 4)
        )

    def _shared_step(self, batch, val_step: bool = False):
        """Compute the SoftAdapt-weighted multi-task loss (4-class version).

        Returns (loss, label, predicted_probabilities,
                 (weighted DINO, weighted GT, weighted reconstruction, image)).
        """
        img, label = batch
        prediction, features, new_image = self(img, val_step)
        StudentF, TeacherF = features
        # Probabilities are used only for metrics; the loss consumes raw logits.
        predicted_labels = torch.softmax(prediction, dim=1)
        # In training, replicate target/image per teacher view so they line up
        # with the concatenated global-view features.
        orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float()
        label = torch.cat([label for _ in range(len(TeacherF))], dim=0) if not val_step else label
        # DINO loss (training only; zero tensor in validation/test).
        loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32)
        # Classifier: FIX - feed the unnormalized logits to CrossEntropyLoss.
        loss_HL = self.lossFN(prediction, label)
        # Reconstruction
        loss_img = self.ReconstructionLoss(new_image, orignalImg)
        loss_img = loss_img.float()
        if not val_step:
            # Optional L1/L2 regularization on the classifier head only.
            if self.Lambda_L1 is not None:
                loss_l1 = 0
                for params in self.clasiffier.parameters():
                    loss_l1 += torch.sum(torch.abs(params))
                loss_HL += self.Lambda_L1 * loss_l1
            if self.Lambda_L2 > 0:
                l2_reg = 0.0
                for param in self.clasiffier.parameters():
                    l2_reg += torch.norm(param, 2) ** 2
                loss_HL += self.Lambda_L2 * l2_reg
            # Record the component losses for the next SoftAdapt re-weighting.
            self.loss_history["loss_DINO"].append(loss_Dino.item())
            self.loss_history["loss_Reconstruction"].append(loss_img.item())
            self.loss_history["loss_Binary"].append(loss_HL.item())
        # SoftAdapt-weighted sum of the three components.
        loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL
        return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image)
class RARP_NVB_RN50_VAN_V2 (RARP_NVB_Model):
# Define a hook function to capture the output
def _hook_fn(self, module, input, output):
self.last_conv_output = output
def __init__(
self,
PseudoEstimator:str = None,
threshold:float = 0.5,
InitWeight=None,
TypeLoss=TypeLossFunction.CrossEntropy,
PseudoLables:bool=True,
HParameter = {},
std: float = None,
mean: float = None,
**kwargs
) -> None:
super().__init__(InitWeight, TypeLoss, **kwargs)
self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None
self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None
#self.RARP_RestNet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
self.RARP_RestNet50.model.fc = torch.nn.Identity()
for parms in self.RARP_RestNet50.model.parameters():
parms.requires_grad = False
self.RARP_RestNet50 = torch.nn.Sequential(*list(self.RARP_RestNet50.model.children())[:-2])
self.decoder = Decoder()
self.FeaturesLoss = FeatureAlignmentLoss()
self.ReconstructionLoss = torch.nn.MSELoss()
match (TypeError):
case TypeLossFunction.ContrastiveLoss:
self.lossFN = ContrastiveLoss()
case _:
pass
self.model = van.van_b2(pretrained = True, num_classes = 1)
self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04)
self.lr = getNVL(HParameter, "lr", 1.0E-4)
print(f"lr={self.lr}, L1={self.Lambda_L1}")
# Initialize a variable to store the output
self.last_conv_output = None
self.model.block4[-1].mlp.dwconv.register_forward_hook(self._hook_fn)
def forward(self, data):
dataTeacher, dataStudent, _ = data
dataStudent = dataStudent.float()
dataTeacher = dataTeacher.float()
feature_Teacher = self.RARP_RestNet50 (dataTeacher)
#feature_Student = self.model.forward_features(dataStudent)
pred = self.model(dataStudent)
feature_Student = self.last_conv_output
#feature_Teacher = torch.nn.functional.adaptive_avg_pool1d(feature_Teacher, feature_Student.size(-1)) #Re-size output vector to macth VAN
cat_features = (feature_Student + feature_Teacher) / 2
reconstructed_image = self.decoder(cat_features)
feature_Student = torch.nn.functional.adaptive_avg_pool2d(feature_Student, (1, 1)).flatten(1)
feature_Teacher = torch.nn.functional.adaptive_avg_pool2d(feature_Teacher, (1, 1)).flatten(1)
return pred, (feature_Student, feature_Teacher), reconstructed_image
def _shared_step(self, batch):
img, label = batch
label = label.float()
orignalImg = img[2].float()
prediction, features, new_image = self(img)
prediction = prediction.flatten()
predicted_labels = torch.sigmoid(prediction)
#cosine similarity
loss_cosine = self.FeaturesLoss(features[0], features[1])
#Clasificator
loss_hl = self.lossFN(prediction, label)
#Reconstruction
loss_img = self.ReconstructionLoss(new_image, orignalImg)
loss_img = loss_img.float()
loss = loss_cosine + loss_hl + loss_img
if self.Lambda_L1 is not None:
loss_l1 = 0
for params in self.model.parameters():
loss_l1 += torch.sum(torch.abs(params))
loss += self.Lambda_L1 * loss_l1
return loss, label, predicted_labels, (loss_cosine, loss_hl, loss_img, new_image)
def _denormalize(self, tensor:torch.Tensor):
# Move mean and std to the same device as the input tensor
mean = self.mean_IMG.to(tensor.device)
std = self.std_IMG.to(tensor.device)
return tensor * std + mean
def training_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("train_loss", loss, on_epoch=True)
self.train_acc.update(predicted_labels, true_labels)
self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
self.log("train_loss_img", losses[2], on_epoch=True, on_step=False)
self.log("train_loss_cosSim", losses[0], on_epoch=True, on_step=False)
self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False)
if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None:
imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
grid = torchvision.utils.make_grid(imgReconstruction)
self.logger.experiment.add_image('reconstructed_images', grid, self.global_step)
return loss
def validation_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
self.val_acc.update(predicted_labels, true_labels)
self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
self.log("val_loss_img", losses[2], on_epoch=True, on_step=False)
self.log("val_loss_cosSim", losses[0], on_epoch=True, on_step=False)
self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False)
def test_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.test_acc.update(predicted_labels, true_labels)
self.f1ScoreTest.update(predicted_labels, true_labels)
self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
if self.mean_IMG is not None and self.std_IMG is not None:
imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1)
imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :]
grid = torchvision.utils.make_grid(imgReconstruction)
self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
return optimizer
class RARP_NVB_ResNet50_VAN(RARP_NVB_Model):
def __init__(
self,
PseudoEstimator:str = None,
threshold:float = 0.5,
InitWeight=None,
TypeLoss=TypeLossFunction.CrossEntropy,
PseudoLables:bool=True,
HParameter = {},
**kwargs
) -> None:
super().__init__(InitWeight, TypeLoss, **kwargs)
self.RARP_RestNet50 = RARP_NVB_ResNet50.load_from_checkpoint(PseudoEstimator) if PseudoEstimator is not None else RARP_NVB_ResNet50()
self.threshold = threshold
self.PseudoLables = PseudoLables
match (TypeError):
case TypeLossFunction.ContrastiveLoss:
self.lossFN = ContrastiveLoss()
case _:
pass
for parms in self.RARP_RestNet50.model.parameters():
parms.requires_grad = False
self.model = van.van_b2(pretrained = True, num_classes = 1)
self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04)
self.lr = getNVL(HParameter, "lr", 1.0E-4)
self.W_Apha = getNVL(HParameter, "Alpha", 0.60)
self.W_Beta = 1 - getNVL(HParameter, "Alpha", 0.60)
self.thao_KD = getNVL(HParameter, "Thao", 5)
print(f"A={self.W_Apha}; B={self.W_Beta}, T={self.thao_KD}, lr={self.lr}, L1={self.Lambda_L1}")
#self.Lambda_L1 = None
#self.scheduler = True
#self.lr = 1.74E-2
def forward(self, data):
#dataTeacher, dataStudent = data
if isinstance(data, tuple):
dataTeacher, dataStudent = data
elif isinstance(data, torch.Tensor):
dataTeacher, dataStudent = (data, data)
RN50Pred = torch.sigmoid(self.RARP_RestNet50(dataTeacher).flatten())
PseudoLabels = (RN50Pred > self.threshold) * 1.0
dataStudent = dataStudent.float()
pred = self.model(dataStudent)
return pred, PseudoLabels, RN50Pred
def _shared_step(self, batch):
img, label = batch
if self.InitWeight is not None:
self.lossFN.weight = self.InitWeight[label]
label = label.float()
prediction, PseudoLabels, predictionRN50 = self(img) #.flatten()
prediction = prediction.flatten()
predicted_labels = torch.sigmoid(prediction)
#
# L[B] => BCEWithLogits
# L[C] => ContrastiveLoss
#Loss L[1] = L[B_HL] + L[B_PL]
#Loss L[2] = L[C_HL]
#Loss L[3] = L[C_HL] + L[C_PL]
#Loss L[4] = L[C_y_hat]
#Loss L[5] = L[KLD_PL] + L[B_PL]
#Loss L[6] = L[MSE_PL] + L[B_PL]
#Loss L[7] = L[B_HL] + L[BCE_PL]
#Loss L[8] = FocalLoss*L[JSD]
#Nuveo
#thao_KD = 5.0
#w_alpha, w_beta = (0.60, 0.40)
#L[7]:
softTeacher = torch.sigmoid(predictionRN50/self.thao_KD)
softStudent = torch.sigmoid(prediction/self.thao_KD)
loss_sl = torch.nn.functional.binary_cross_entropy(softStudent, softTeacher)
loss_hl = self.lossFN(prediction, label)
loss = self.W_Apha * loss_hl + self.W_Beta * loss_sl #/ (self.thao_KD ** 2)
#loss = w_alpha * self.lossFN(prediction, label) + w_beta * self.lossFN(prediction, PseudoLabels) #L[1]
#loss = w_alpha * (torch.nn.KLDivLoss()(softStudent, softTeacher) * (thao_KD ** 2)) + w_beta * self.lossFN(prediction, PseudoLabels) #L[5]
#loss = w_alpha * torch.nn.functional.mse_loss(softStudent, softTeacher) + w_beta * self.lossFN(prediction, PseudoLabels) #L[6]
#loss = self.lossFN(predictionRN50, predicted_labels, label) #L[2]
#loss = self.lossFN(predictionRN50, predicted_labels, label) + self.lossFN(predictionRN50, predicted_labels, PseudoLabels) #L[3]
#y_hat = (PseudoLabels != label) * 1
#loss = self.lossFN(predictionRN50, predicted_labels, y_hat) #L[4]
if self.Lambda_L1 is not None:
loss_l1 = 0
for params in self.model.parameters():
loss_l1 += torch.sum(torch.abs(params))
loss += self.Lambda_L1 * loss_l1
return loss, (PseudoLabels if self.PseudoLables else label), predicted_labels, (loss_hl, loss_sl)
def training_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("train_loss", loss, on_epoch=True)
self.train_acc.update(predicted_labels, true_labels)
self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
self.log("train_lh", losses[0], on_step=False, on_epoch=True)
self.log("train_ld", losses[1], on_step=False, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
loss, true_labels, predicted_labels, losses = self._shared_step(batch)
self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
self.val_acc.update(predicted_labels, true_labels)
self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
self.log("val_lh", losses[0], on_step=False, on_epoch=True)
self.log("val_ld", losses[1], on_step=False, on_epoch=True)
def test_step(self, batch, batch_idx):
_, true_labels, predicted_labels, _ = self._shared_step(batch)
self.test_acc.update(predicted_labels, true_labels)
self.f1ScoreTest.update(predicted_labels, true_labels)
self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.Lambda_L2)
if self.scheduler:
scheduler = CosineAnnealingLR(
optimizer,
max_epochs=150,
warmup_epochs=8,
warmup_start_lr=0.001,
eta_min=0.0001
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
#"monitor": "val_loss",
}
}
return optimizer
class RARP_NVB_SSL_RestNet50_Deep(RARP_NVB_ResNet50_VAN):
    """Variant of RARP_NVB_ResNet50_VAN whose student is a ResNet-50 (instead
    of VAN) and whose frozen teacher is a RARP_NVB_ResNet50_Deep."""

    def __init__(
        self,
        PseudoEstimator: str = None,
        threshold: float = 0.5,
        InitWeight=None,
        TypeLoss=TypeLossFunction.CrossEntropy,
        PseudoLables: bool = True,
        **kwargs
    ) -> None:
        # The parent builds its own teacher with PseudoEstimator=None; it is
        # replaced below with the deep variant.
        super().__init__(None, threshold, InitWeight, TypeLoss, PseudoLables, **kwargs)
        if PseudoEstimator is not None:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep.load_from_checkpoint(PseudoEstimator, strict=False)
        else:
            self.RARP_RestNet50 = RARP_NVB_ResNet50_Deep()
        self.model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        #self.model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
        # Single-logit head on top of the 2048-dim ResNet-50 features.
        head_layers = [
            torch.nn.Dropout(0.2),
            torch.nn.Linear(in_features=2048, out_features=128),
            torch.nn.SiLU(True),
            torch.nn.Linear(128, 8),
            torch.nn.SiLU(True),
            torch.nn.Linear(8, 1),
        ]
        self.model.fc = torch.nn.Sequential(*head_layers)
        # The teacher never receives gradients.
        for frozen_param in self.RARP_RestNet50.model.parameters():
            frozen_param.requires_grad = False
class RARP_NVB_MobileNetV2(RARP_NVB_Model):
    """RARP_NVB_Model backed by an ImageNet-pretrained MobileNetV2 whose
    classifier is replaced with a single binary logit."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        backbone = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
        # Swap the 1000-way ImageNet classifier for a single-logit head.
        in_features = backbone.classifier[1].in_features
        backbone.classifier[1] = torch.nn.Linear(in_features=in_features, out_features=1)
        self.model = backbone
class RARP_NVB_EfficientNetV2(RARP_NVB_Model):
    """RARP_NVB_Model backed by an ImageNet-pretrained EfficientNetV2-S whose
    classifier is replaced with a single binary logit."""

    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        backbone = torchvision.models.efficientnet_v2_s(weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT)
        # Swap the 1000-way ImageNet classifier for a single-logit head.
        in_features = backbone.classifier[1].in_features
        backbone.classifier[1] = torch.nn.Linear(in_features=in_features, out_features=1)
        self.model = backbone
class RARP_NVB_Vit_b_16(RARP_NVB_Model):
    """RARP_NVB_Model backed by a pretrained ViT-B/16 with a single-logit head.

    Improvement: forwards ``**kwargs`` to the base class, matching the sibling
    backbone wrappers (MobileNetV2 / EfficientNetV2); existing callers that
    pass no extra keywords are unaffected.
    """
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        # Replace the 1000-way ImageNet head with a single binary logit.
        tempFC_ft = self.model.heads[0].in_features
        self.model.heads[0] = torch.nn.Linear(in_features=tempFC_ft, out_features=1)
class RARP_NVB_DenseNet169(RARP_NVB_Model):
    """RARP_NVB_Model backed by a pretrained DenseNet-169 with a single-logit head.

    Improvement: forwards ``**kwargs`` to the base class, matching the sibling
    backbone wrappers (MobileNetV2 / EfficientNetV2); existing callers that
    pass no extra keywords are unaffected.
    """
    def __init__(self, InitWeight=None, TypeLoss=TypeLossFunction.CrossEntropy, **kwargs) -> None:
        super().__init__(InitWeight, TypeLoss, **kwargs)
        self.model = torchvision.models.densenet169(weights=torchvision.models.DenseNet169_Weights.DEFAULT)
        # Replace the 1000-way ImageNet classifier with a single binary logit.
        inFeatures = self.model.classifier.in_features
        self.model.classifier = torch.nn.Linear(in_features=inFeatures, out_features=1)
class RARP_NVB_RestNet50_old(L.LightningModule):
    """Legacy baseline: ResNet-50 (ImageNet V1 weights) with a single-logit
    head, trained with BCEWithLogitsLoss and a fixed Adam optimizer."""

    def __init__(self, InitialWeigth = 1) -> None:
        super().__init__()
        net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
        # Replace the 1000-way ImageNet head with a single binary logit.
        net.fc = torch.nn.Linear(in_features=net.fc.in_features, out_features=1)
        self.model = net
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        # pos_weight counteracts the positive/negative class imbalance.
        self.lossFN = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([InitialWeigth]))
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')

    def forward(self, data):
        return self.model(data.float())

    def training_step(self, batch, batch_idx):
        img, target = batch
        target = target.float()
        logits = self(img)[:, 0]
        step_loss = self.lossFN(logits, target)
        self.log("Train Loss", step_loss)
        self.log("Step Train Acc", self.train_acc(torch.sigmoid(logits), target.int()))
        return step_loss

    def on_train_epoch_end(self):
        self.log("Train Acc", self.train_acc.compute())

    def validation_step(self, batch, batch_idx):
        img, target = batch
        target = target.float()
        logits = self(img)[:, 0]
        step_loss = self.lossFN(logits, target)
        self.log("Val Loss", step_loss)
        self.log("Step Val Acc", self.val_acc(torch.sigmoid(logits), target.int()))
        return step_loss

    def on_validation_epoch_end(self):
        self.log("Val Acc", self.val_acc.compute())

    def configure_optimizers(self):
        return [self.optimizer]