diff --git a/Loaders.py b/Loaders.py index 0b9d64b..0f6199c 100644 --- a/Loaders.py +++ b/Loaders.py @@ -1059,6 +1059,72 @@ return all_Crops +class RARP_DINO_Albumentations(): + def __init__( + self, + GloblaCropsScale=(0.4, 1), + LocalCropsScale=(0.05, 0.4), + NumLocalCrops = 8, + Size = 224, + device=None, + mean = None, + std = None, + Tranform_0=None, + Init_Resize=(512, 512), + Seed=505 + ): + self.NumLocal_Crops= NumLocalCrops + + self.globalCrop1 = A.Compose([ + A.Resize(Init_Resize[0], Init_Resize[1], interpolation=cv2.INTER_CUBIC), + A.Affine(scale=GloblaCropsScale, rotate=(-15, 15), interpolation=cv2.INTER_CUBIC), + A.RandomCrop(Size, Size), + A.HorizontalFlip(0.6), + A.ColorJitter(brightness=1.1, contrast=0.4, saturation=0.2, hue=0.1, p=0.8), + A.ToGray(p=0.4), + A.GaussianBlur(p=1), + A.RandomFog(p=0.5), + A.Normalize(mean, std), + ToTensorV2() + ], seed=Seed) + + self.globalCrop2 = A.Compose([ + A.Resize(Init_Resize[0], Init_Resize[1], interpolation=cv2.INTER_CUBIC), + A.Affine(scale=GloblaCropsScale, rotate=(-15, 15), interpolation=cv2.INTER_CUBIC), + A.RandomCrop(Size, Size), + A.HorizontalFlip(0.5), + A.ColorJitter(brightness=1.1, contrast=0.4, saturation=0.2, hue=0.1, p=0.8), + A.ToGray(p=0.4), + A.GaussianBlur(p=1), + A.Solarize(p=0.3), + A.RandomFog(p=0.5), + A.Normalize(mean, std), + ToTensorV2() + ], seed=Seed) + + self.local = A.Compose([ + A.Resize(Init_Resize[0], Init_Resize[1], interpolation=cv2.INTER_CUBIC), + A.Affine(scale=LocalCropsScale, rotate=(-15, 15), interpolation=cv2.INTER_CUBIC), + A.RandomCrop(Size, Size), + A.ColorJitter(brightness=1.1, contrast=0.4, saturation=0.2, hue=0.1, p=0.8), + A.Solarize(p=0.3), + A.RandomFog(p=0.5), + A.Normalize(mean, std), + ToTensorV2() + ], seed=Seed) + + self.classification = Tranform_0 + + def __call__(self, img): + all_Crops = [] + all_Crops.append(self.classification(image=img)["image"]) + all_Crops.append(self.globalCrop1(image=img)["image"]) + all_Crops.append(self.globalCrop2(image=img)["image"]) + + all_Crops.extend([self.local(image=img)["image"] for _ in range(self.NumLocal_Crops)]) + + return all_Crops + class RARP_DataSetType(Enum): train = 0 val = 1 diff --git a/Models.py b/Models.py index 5c5bbea..a2e1b98 100644 --- a/Models.py +++ b/Models.py @@ -3,12 +3,14 @@ import torch import torchvision import torchmetrics +import torchmetrics.classification import lightning as L from enum import Enum import timm import van import numpy as np from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt +from noah import NOAH def js_divergence_sigmoid(p_logits, q_logits): @@ -92,15 +94,83 @@ # Final layer to get the output image layers.append(torch.nn.ConvTranspose2d(in_channels, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1)) - layers.append(torch.nn.Sigmoid()) # To get pixel values between 0 and 1 + #layers.append(torch.nn.Sigmoid()) # To get pixel values between 0 and 1 # Combine all layers into a Sequential module self.decoder = torch.nn.Sequential(*layers) def forward(self, x): return self.decoder(x) + +class DecoderUnet(torch.nn.Module): + def __init__(self, input_channels=2048, output_channels=3): + super().__init__() + self.dropout = torch.nn.Dropout2d(0.4) + self.upConv_0 = torch.nn.ConvTranspose2d(input_channels, 512, kernel_size=2, stride=2) + self.decoder_0 = self._conv_block(1024, 512) + + self.upConv_1 = torch.nn.ConvTranspose2d(512, 320, kernel_size=2, stride=2) + self.decoder_1 = self._conv_block(640, 320) + + self.upConv_2 = torch.nn.ConvTranspose2d(320, 128, kernel_size=2, stride=2) + self.decoder_2 = self._conv_block(256, 128) + + self.upConv_3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) + self.decoder_3 = self._conv_block(128, 64) + + self.upConv_4 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2) + self.decoder_4 = self._conv_block(32, 16) + + self.enc3_upsample = torch.nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2) + self.enc2_upsample = torch.nn.ConvTranspose2d(320, 320, kernel_size=2, stride=2) + self.enc1_upsample = torch.nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2) + self.enc0_upsample = torch.nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2) + + self.last_conv = torch.nn.Conv2d(16, output_channels, kernel_size=1) + + def _conv_block(self, in_ch, out_ch): + return torch.nn.Sequential( + torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + torch.nn.SiLU(inplace=True), + torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), + torch.nn.SiLU(inplace=True), + ) + + def forward(self, x): + encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck = x + + decoder_l3 = self.upConv_0(btlneck) + encoder_l3 = self.enc3_upsample(encoder_l3) + decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1) + decoder_l3 = self.decoder_0(decoder_l3) + + decoder_l3 = self.dropout(decoder_l3) + + decoder_l2 = self.upConv_1(decoder_l3) + encoder_l2 = self.enc2_upsample(encoder_l2) + decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1) + decoder_l2 = self.decoder_1(decoder_l2) + + decoder_l2 = self.dropout(decoder_l2) + + decoder_l1 = self.upConv_2(decoder_l2) + encoder_l1 = self.enc1_upsample(encoder_l1) + decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1) + decoder_l1 = self.decoder_2(decoder_l1) + + decoder_l1 = self.dropout(decoder_l1) + + decoder_l0 = self.upConv_3(decoder_l1) + encoder_l0 = self.enc0_upsample(encoder_l0) + decoder_l0 = torch.cat((decoder_l0, encoder_l0), dim=1) + decoder_l0 = self.decoder_3(decoder_l0) + + decoder_last = self.upConv_4(decoder_l0) + decoder_last = self.decoder_4(decoder_last) + + return self.last_conv(decoder_last) class ModelsList(Enum): RestNet50 = 1 @@ -244,6 +314,177 @@ cam = (cam - cam.min()) / (cam.max() - cam.min()) return cam +class UNet_RN18(torch.nn.Module): + def _conv_block(self, in_ch, out_ch): + return torch.nn.Sequential( + torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + torch.nn.SiLU(inplace=True), + torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), + torch.nn.SiLU(inplace=True), + ) + + def _hook_fn(self, module, input, output): + self.feature_maps.append(output) + + def _register_encoder_hooks(self): + for layer in self.list_blocks: + self.hooks.append(layer.register_forward_hook(self._hook_fn)) + + def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.hooks = [] + self.feature_maps = [] + + self.encoder_base = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT) + self.encoder_base.fc = torch.nn.Identity() + + #for parms in self.encoder_base.parameters(): + # parms.requires_grad = False + + self.list_blocks = [ + self.encoder_base.conv1, + self.encoder_base.layer1, + self.encoder_base.layer2, + self.encoder_base.layer3, + self.encoder_base.layer4 + ] + + self._register_encoder_hooks() + + self.dropout = torch.nn.Dropout2d(0.4) + + self.upConv_0 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) + self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) + self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) + self.upConv_3 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2) + + self.upConv_extra = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2) + + self.decoder_0 = self._conv_block(512, 256) #0 + self.decoder_1 = self._conv_block(256, 128) #1 + self.decoder_2 = self._conv_block(128, 64) #2 + self.decoder_3 = self._conv_block(32 + 64, 32) #3 + + self.decoder_extra = self._conv_block(16, 16) + + self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1) + + def forward(self, x): + self.feature_maps = [] + + _ = self.encoder_base(x) # forward to encoder and call hooks + + encoder_l0, encoder_l1, encoder_l2, encoder_l3, btlneck = self.feature_maps + + decoder_l3 = self.upConv_0(btlneck) + decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1) + decoder_l3 = self.decoder_0(decoder_l3) + + decoder_l3 = self.dropout(decoder_l3) + + decoder_l2 = self.upConv_1(decoder_l3) + decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1) + decoder_l2 = self.decoder_1(decoder_l2) + + decoder_l2 = self.dropout(decoder_l2) + + decoder_l1 = self.upConv_2(decoder_l2) + decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1) + decoder_l1 = self.decoder_2(decoder_l1) + + decoder_l1 = self.dropout(decoder_l1) + + decoder_l0 = self.upConv_3(decoder_l1) + decoder_l0 = torch.cat((decoder_l0, encoder_l0), dim=1) + decoder_l0 = self.decoder_3(decoder_l0) + + decoder_last = self.upConv_extra(decoder_l0) + decoder_last = self.decoder_extra(decoder_last) + + return self.last_conv(decoder_last) + +class RARP_NVB_ROI_Mask_Unet(L.LightningModule): + def __init__(self,*args, **kwargs): + super().__init__(*args, **kwargs) + + self.model = UNet_RN18(in_channels=3, out_channels=1) + + self.lr = 1E-4 + self.Lambda_L1 = None + self.lossFN = torch.nn.BCEWithLogitsLoss() + + self.train_IoU = torchmetrics.classification.BinaryJaccardIndex() + self.val_IoU = torchmetrics.classification.BinaryJaccardIndex() + + def forward(self, data): + data = data.float() + pred = self.model(data) + return pred + + def _shared_step(self, batch, val_step:bool = True): + img, mask = batch + + mask = mask.float() + mask = mask.unsqueeze(1) + prediction = self(img) + + loss = self.lossFN(prediction, mask) + + predicted_labels = torch.sigmoid(prediction) + + if not val_step: + if self.Lambda_L1 is not None: + loss_l1 = 0 + for name, params in self.model.named_parameters(): + if "decoder" in name or "upConv" in name: + loss_l1 += torch.norm(params, p=1) + + loss += self.Lambda_L1 * loss_l1 + + return loss, mask, predicted_labels + + def training_step(self, batch, batch_idx): + loss, true_labels, predicted_labels = self._shared_step(batch, False) + + self.train_IoU.update(predicted_labels, true_labels) + + self.log("train_loss", loss, on_epoch=True) + self.log("train_acc_IoU", self.train_IoU, on_epoch=True, on_step=False) + + return loss + + def on_after_backward(self): + total_norm = 0.0 + for p in self.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** 0.5 + + self.log("grad_norm", total_norm) + + if total_norm < 1e-8: + self.log("grad_warning", "Vanishing gradient suspected!") + + def on_train_epoch_start(self): + for parms in self.model.encoder_base.parameters(): + parms.requires_grad = (self.current_epoch % 2 == 0) + + + def validation_step(self, batch, batch_idx): + loss, true_labels, predicted_labels = self._shared_step(batch) + + self.val_IoU.update(predicted_labels, true_labels) + + self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True) + self.log("val_acc_IoU", self.val_IoU, on_epoch=True, on_step=False) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) + + return [optimizer] + class RARP_NVB_ResNet50_CAM(L.LightningModule): def __init__(self) -> None: super().__init__() @@ -871,8 +1112,8 @@ self.mlp = torch.nn.Linear(in_dim, bottleneck) else: layers = [torch.nn.Linear(in_dim, hidden_dim)] - layers.append(torch.nn.Dropout(0.30)) layers.append(self.activationFN) + layers.append(torch.nn.Dropout(0.30)) for _ in range(n_layers - 2): layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) layers.append(self.activationFN) @@ -901,7 +1142,7 @@ return x class RARP_NVB_DINO_Wrapper(torch.nn.Module): - def __init__(self, backbone, new_head, *args, **kwargs) -> None: + def __init__(self, backbone:torch.nn.Module, new_head:torch.nn.Module, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.backbone = backbone self.head = new_head @@ -1205,6 +1446,32 @@ for parms in self.teacher_Features.parameters(): parms.requires_grad = False +class RARP_NVB_Classification_Head(torch.nn.Module): + def __init__(self, in_features:int, out_features:int, layer:list=[], activation_fn:torch.nn.Module = torch.nn.ReLU(), *args, **kwargs): + super().__init__(*args, **kwargs) + + self.activation = activation_fn + + if len (layer) == 0: + self.head = torch.nn.Linear(in_features, out_features) + else: + temp_head = [] + next_input = in_features + for num in layer: + temp_head.append(torch.nn.Linear(next_input, num)) + temp_head.append(self.activation) + temp_head.append(torch.nn.Dropout(0.4)) + next_input = num + + temp_head[-1] = torch.nn.Dropout(0.2) + temp_head.append(torch.nn.Linear(next_input, out_features)) + + self.head = torch.nn.Sequential(*temp_head) + del temp_head + + def forward(self, x): + return self.head(x) + class RARP_NVB_DINO_MultiTask(L.LightningModule): # Define a hook function to capture the output def _hook_fn_Student(self, module, input, output): @@ -1223,16 +1490,23 @@ std: float = None, mean: float = None, SoftAdptAlgo:int = 0, - SoftAdptBeta:float = 0.1 + SoftAdptBeta:float = 0.1, + Teacher_T:float = 0.04, + Student_T:float = 0.1, + intermittent:bool = False ) -> None: super().__init__() + self.intermittent_train = intermittent + self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None self.lr = lr self.Lambda_L1 = L1 self.Lambda_L2 = L2 + self.Teacher_t = Teacher_T + self.Studet_T = Student_T self.momentum_teacher = momentum_teacher self.out_dim = 1024 self.in_dim = 512 @@ -1253,7 +1527,7 @@ self.student = van.van_b2(pretrained = True, num_classes = 0) self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0) - self.decoder = DynamicDecoder(input_channels=1024) #Decoder(input_channels=1024, hidden_channels=[512, 256, 128, 64]) + self.decoder = DynamicDecoder(input_channels=1024) self.student = RARP_NVB_DINO_Wrapper( self.student, @@ -1269,7 +1543,7 @@ for parms in self.teacher_Features.parameters(): parms.requires_grad = False - self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, 0.04, 0.1, momentum_teacher) + self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher) self.lossFN = torch.nn.CrossEntropyLoss(label_smoothing=0.5) if TypeLoss == TypeLossFunction.CrossEntropy else torch.nn.BCEWithLogitsLoss() self.ReconstructionLoss = torch.nn.MSELoss() @@ -1281,16 +1555,16 @@ self.clasiffier = torch.nn.Sequential( torch.nn.Linear(1024, 128), - torch.nn.Dropout(0.4), torch.nn.SiLU(True), + torch.nn.Dropout(0.4), torch.nn.Linear(128, 8), - torch.nn.Dropout(0.2), torch.nn.SiLU(True), + torch.nn.Dropout(0.2), torch.nn.Linear(8, 1) ) - + print(f"lr={self.lr}, L1={self.Lambda_L1}") def _denormalize(self, tensor:torch.Tensor): @@ -1299,6 +1573,12 @@ std = self.std_IMG.to(tensor.device) return tensor * std + mean + def _calc_L1(self, params): + l1 = 0 + for p in params: + l1 += torch.sum(torch.abs(p)) + return self.Lambda_L1 * l1 + def _calc_weights(self, log_weights:bool = True): self.weights = self.softAdapt.get_component_weights( torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]), @@ -1349,7 +1629,12 @@ prediction, features, new_image = self(img, val_step) StudentF, TeacherF = features - prediction = prediction.flatten() + if isinstance(self.clasiffier, torch.nn.Sequential): + if self.clasiffier[-1].out_features == 1: + prediction = prediction.flatten() + elif isinstance(self.clasiffier, (NOAH, RARP_NVB_Classification_Head)): + prediction = prediction.flatten() + predicted_labels = torch.sigmoid(prediction) orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float() @@ -1365,10 +1650,7 @@ if not val_step: if self.Lambda_L1 is not None: - loss_l1 = 0 - for params in self.clasiffier.parameters(): # aqui - loss_l1 += torch.sum(torch.abs(params)) - loss_HL += self.Lambda_L1 * loss_l1 + loss_HL += self._calc_L1(self.clasiffier.parameters()) if self.Lambda_L2 > 0: l2_reg = 0.0 @@ -1388,6 +1670,19 @@ if self.current_epoch % 2 == 0 and self.current_epoch != 0: self._calc_weights() + if self.intermittent_train and self.current_epoch != 0: + + par_epoch = (self.current_epoch % 2 == 0) + + for parms in self.student.backbone.parameters(): + parms.requires_grad = par_epoch + + for parms in self.decoder.parameters(): + parms.requires_grad = not par_epoch + + for parms in self.clasiffier.parameters(): + parms.requires_grad = not par_epoch + def training_step(self, batch, batch_idx): loss, true_labels, predicted_labels, losses = self._shared_step(batch, False) @@ -1414,7 +1709,19 @@ teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data) #self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center) - + + def on_after_backward(self): + total_norm = 0.0 + for p in self.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** 0.5 + + self.log("grad_norm", total_norm) + + if total_norm < 1e-8: + self.log("grad_warning", "Vanishing gradient suspected!") def validation_step(self, batch, batch_idx): loss, true_labels, predicted_labels, losses = self._shared_step(batch, True) @@ -1446,6 +1753,438 @@ return [optimizer] +#Ablation Models +"""T-S Multi-task model With out Recostruccion (V3R1_A1) + +Returns: + LightningModule +""" +class RARP_NVB_DINO_MultiTask_A1(RARP_NVB_DINO_MultiTask): + def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False): + super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent) + + self.decoder = torch.nn.Identity() + del self.ReconstructionLoss + + self.mean_IMG = None + self.std_IMG = None + + self.weights = torch.tensor([1,1]) + self.loss_history = { + 'loss_DINO': [], + 'loss_Binary': [], + } + + def _calc_weights(self, log_weights:bool = True): + self.weights = self.softAdapt.get_component_weights( + torch.tensor(self.loss_history["loss_DINO"][:-1] if len(self.loss_history["loss_DINO"]) % 2 == 0 else self.loss_history["loss_DINO"]), + torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]), + verbose=False + ) + + if log_weights: + self.log("W_loss_img", 0, on_epoch=True, on_step=False) + self.log("W_loss_DINO", self.weights[0], on_epoch=True, on_step=False) + self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False) + + self.loss_history = { + 'loss_DINO': [], + 'loss_Binary': [], + } + + def _shared_step(self, batch, val_step:bool = False): + img, label = batch + + prediction, features, new_image = self(img, val_step) + StudentF, TeacherF = features + + if isinstance(self.clasiffier, torch.nn.Sequential): + if self.clasiffier[-1].out_features == 1: + prediction = prediction.flatten() + elif isinstance(self.clasiffier, NOAH): + prediction = prediction.flatten() + + predicted_labels = torch.sigmoid(prediction) + + #orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float() + label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float() + + #DINO Loss + loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32) + #Clasificator + loss_HL = self.lossFN(prediction, label) + #Reconstruction + #loss_img = self.ReconstructionLoss(new_image, orignalImg) + #loss_img = loss_img.float() + + if not val_step: + if self.Lambda_L1 is not None: + loss_HL += self._calc_L1(self.clasiffier.parameters()) + + if self.Lambda_L2 > 0: + l2_reg = 0.0 + for param in self.clasiffier.parameters(): + l2_reg += torch.norm(param, 2) ** 2 + loss_HL += self.Lambda_L2 * l2_reg + + self.loss_history["loss_DINO"].append(loss_Dino.item()) + self.loss_history["loss_Binary"].append(loss_HL.item()) + + loss = self.weights[0] * loss_Dino + self.weights[1] * loss_HL + + return loss, label, predicted_labels, (self.weights[0] * loss_Dino, self.weights[1] * loss_HL, 0, new_image) + +"""T-S Multi-task model With out SoftAdadapt, Fix loss wegth 0.333_ (V3R1_A2) + +Returns: + LightningModule +""" +class RARP_NVB_DINO_MultiTask_A2(RARP_NVB_DINO_MultiTask): + def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False): + super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent) + + del self.softAdapt + self.weights = [1/3, 1/3, 1/3] + + def _calc_weights(self): + self.weights = [1/3, 1/3, 1/3] + + self.loss_history = { + 'loss_DINO': [], + 'loss_Reconstruction': [], + 'loss_Binary': [], + } + +"""S Multi-task model No Dino base encoder VAN_b2, (V3R1_A3_1) + +Returns: + LightningModule +""" +class RARP_NVB_DINO_MultiTask_A3(RARP_NVB_DINO_MultiTask): + def _hook_fn_Student(self, module, input, output): + self.last_conv_output_S = output + self.last_conv_output_T = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + if not self.val_phace: + self.last_conv_output_T = self.last_conv_output_T[:16] # Fixed bach of 8 + + def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False): + super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent) + + self.student = RARP_NVB_DINO_Wrapper( + van.van_b2(pretrained = True, num_classes = 0), + torch.nn.Identity() + ) + + self.teacher_Features = RARP_NVB_DINO_Wrapper( + torch.nn.Identity(), + torch.nn.Identity() + ) + + self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student) + + self.val_phace = True + + self.weights = torch.tensor([1,1]) + self.loss_history = { + 'loss_Reconstruction': [], + 'loss_Binary': [], + } + + def _calc_weights(self, log_weights:bool = True): + self.weights = self.softAdapt.get_component_weights( + torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_DINO"]), + torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]), + verbose=False + ) + + if log_weights: + self.log("W_loss_img", self.weights[0], on_epoch=True, on_step=False) + self.log("W_loss_DINO", 0, on_epoch=True, on_step=False) + self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False) + + self.loss_history = { + 'loss_Reconstruction': [], + 'loss_Binary': [], + } + + def _shared_step(self, batch, val_step:bool = False): + self.val_phace = val_step + + img, label = batch + + prediction, features, new_image = self(img, val_step) + _, TeacherF = features + + if isinstance(self.clasiffier, torch.nn.Sequential): + if self.clasiffier[-1].out_features == 1: + prediction = prediction.flatten() + elif isinstance(self.clasiffier, NOAH): + prediction = prediction.flatten() + + predicted_labels = torch.sigmoid(prediction) + + orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float() + label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float() + + #DINO Loss + #loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32) + #Clasificator + loss_HL = self.lossFN(prediction, label) + #Reconstruction + loss_img = self.ReconstructionLoss(new_image, orignalImg) + loss_img = loss_img.float() + + if not val_step: + if self.Lambda_L1 is not None: + loss_HL += self._calc_L1(self.clasiffier.parameters()) + + if self.Lambda_L2 > 0: + l2_reg = 0.0 + for param in self.clasiffier.parameters(): + l2_reg += torch.norm(param, 2) ** 2 + loss_HL += self.Lambda_L2 * l2_reg + + self.loss_history["loss_Reconstruction"].append(loss_img.item()) + self.loss_history["loss_Binary"].append(loss_HL.item()) + + loss = self.weights[0] * loss_img + self.weights[1] * loss_HL + + return loss, label, predicted_labels, (0, self.weights[1] * loss_HL, self.weights[0] * loss_img, new_image) + +"""S Multi-task model No Dino base encoder RN50, (V3R1_A3_2) + +Returns: + LightningModule +""" +class RARP_NVB_DINO_MultiTask_A3_RN50(RARP_NVB_DINO_MultiTask): + def _hook_fn_Student(self, module, input, output): + self.last_conv_output_S = output + self.last_conv_output_T = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + if not self.val_phace: + self.last_conv_output_T = self.last_conv_output_T[:16] # Fixed bach of 8 + + self.last_conv_output_T = self.last_conv_output_T[:, :0, :, :] + + def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False): + super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent) + + self.student = RARP_NVB_DINO_Wrapper( + torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT), + torch.nn.Identity() + ) + + self.teacher_Features = RARP_NVB_DINO_Wrapper( + torch.nn.Identity(), + torch.nn.Identity() + ) + + self.student.backbone.layer4.register_forward_hook(self._hook_fn_Student) + + self.val_phace = True + + self.weights = torch.tensor([1,1]) + self.loss_history = { + 'loss_Reconstruction': [], + 'loss_Binary': [], + } + + self.decoder = DynamicDecoder(input_channels=2048) + + self.clasiffier = torch.nn.Sequential( + torch.nn.Linear(2048, 128), + torch.nn.SiLU(True), + torch.nn.Dropout(0.4), + + torch.nn.Linear(128, 8), + torch.nn.SiLU(True), + torch.nn.Dropout(0.2), + + torch.nn.Linear(8, 1) + ) + + def _calc_weights(self, log_weights:bool = True): + self.weights = self.softAdapt.get_component_weights( + torch.tensor(self.loss_history["loss_Reconstruction"][:-1] if len(self.loss_history["loss_Reconstruction"]) % 2 == 0 else self.loss_history["loss_DINO"]), + torch.tensor(self.loss_history["loss_Binary"][:-1] if len(self.loss_history["loss_Binary"]) % 2 == 0 else self.loss_history["loss_Binary"]), + verbose=False + ) + + if log_weights: + self.log("W_loss_img", self.weights[0], on_epoch=True, on_step=False) + self.log("W_loss_DINO", 0, on_epoch=True, on_step=False) + self.log("W_loss_GT", self.weights[1], on_epoch=True, on_step=False) + + self.loss_history = { + 'loss_Reconstruction': [], + 'loss_Binary': [], + } + + def _shared_step(self, batch, val_step:bool = False): + self.val_phace = val_step + + img, label = batch + + prediction, features, new_image = self(img, val_step) + _, TeacherF = features + + if isinstance(self.clasiffier, torch.nn.Sequential): + if self.clasiffier[-1].out_features == 1: + prediction = prediction.flatten() + elif isinstance(self.clasiffier, NOAH): + prediction = prediction.flatten() + + predicted_labels = torch.sigmoid(prediction) + + orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float() + label = torch.cat([label.float() for _ in range(len(TeacherF))], dim=0) if not val_step else label.float() + + #DINO Loss + #loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32) + #Clasificator + loss_HL = self.lossFN(prediction, label) + #Reconstruction + loss_img = self.ReconstructionLoss(new_image, orignalImg) + loss_img = loss_img.float() + + if not val_step: + if self.Lambda_L1 is not None: + loss_HL += self._calc_L1(self.clasiffier.parameters()) + + if self.Lambda_L2 > 0: + l2_reg = 0.0 + for param in self.clasiffier.parameters(): + l2_reg += torch.norm(param, 2) ** 2 + loss_HL += self.Lambda_L2 * l2_reg + + self.loss_history["loss_Reconstruction"].append(loss_img.item()) + self.loss_history["loss_Binary"].append(loss_HL.item()) + + loss = self.weights[0] * loss_img + self.weights[1] * loss_HL + + return loss, label, predicted_labels, (0, self.weights[1] * loss_HL, self.weights[0] * loss_img, new_image) + +"""T-S Multi-task model, classification head layer change, (V3R1_A4) + +Returns: + LightningModule +""" +class RARP_NVB_DINO_MultiTask_A4(RARP_NVB_DINO_MultiTask): + def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False): + super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent) + #layeres = [128, 8] #L3 original + #layeres = [256, 128, 8] #L4 + #layeres = [8] #L2 + layeres = [] #L1 + self.clasiffier = RARP_NVB_Classification_Head(1024, 1, layeres, torch.nn.SiLU(True)) + +#end Ablation Models +class RARP_NVB_DINO_MultiTask_Unet(RARP_NVB_DINO_MultiTask): + def _encoder_hool_fn(self, module, input, output): + self.feature_maps.append(output) + + def _register_encoder_hooks(self): + for layer in self.list_blocks: + self.hooks.append(layer.register_forward_hook(self._encoder_hool_fn)) + + def __init__( + self, + TypeLoss=TypeLossFunction.CrossEntropy, + momentum_teacher:float = 0.9995, + lr:float = 1e-4, + L1:float = None, + L2:float = 0, + std: float = None, + mean: float = None, + SoftAdptAlgo:int = 0, + SoftAdptBeta:float = 0.1 + ): + super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta) + + self.hooks = [] + self.feature_maps = [] + + self.list_blocks = [ + self.student.backbone.block1[-1], + self.student.backbone.block2[-1], + self.student.backbone.block3[-1], + self.student.backbone.block4[-2], + ] + + self._register_encoder_hooks() + + self.decoder = DecoderUnet(1024) + + def forward(self, data, val_step:bool = True): + self.feature_maps = [] + + if val_step: + data = data.float() + dataTeacher, dataStudent = data, data + else: + data = [d.float() for d in data] + dataTeacher, dataStudent = data[1:3], data + + TeacherDino = self.teacher_Features(dataTeacher) + Student = self.student(dataStudent) + + if not val_step: + NumChunks = len(dataStudent) + S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3] + self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0) + + cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1) + + if not val_step: + NumChunks = len(dataStudent) + temp = [] + for f_maps in self.feature_maps: + temp.append(torch.cat(f_maps.chunk(NumChunks)[1:3], dim=0)) + self.feature_maps = temp + del temp + + self.feature_maps.append(cat_features) + reconstructed_image = self.decoder(self.feature_maps) + + Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1) + pred = self.clasiffier(Cont_Net) + + return pred, (Student, TeacherDino), reconstructed_image + +class RARP_NVB_DINO_MultiTask_MultiLabel(RARP_NVB_DINO_MultiTask): + def __init__( + self, + TypeLoss=TypeLossFunction.BCEWithLogits, + momentum_teacher = 0.9995, + lr = 0.0001, + L1 = None, + L2 = 0, + std = None, + mean = None, + SoftAdptAlgo = 0, + SoftAdptBeta = 0.1, + Num_Lables = 2 + ): + super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta) + + self.lossFN = torch.nn.BCEWithLogitsLoss() + + self.clasiffier = torch.nn.Sequential( + torch.nn.Linear(1024, 128), + torch.nn.Dropout(0.4), + torch.nn.SiLU(True), + + torch.nn.Linear(128, 8), + torch.nn.Dropout(0.2), + torch.nn.SiLU(True), + + torch.nn.Linear(8, Num_Lables) + ) + + self.train_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables) + self.val_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables) + self.test_acc = torchmetrics.Accuracy('multilabel', num_labels=Num_Lables) + self.f1ScoreTest = torchmetrics.F1Score('multilabel', num_labels=Num_Lables) + class RARP_NVB_DINO_MultiTask_v2(RARP_NVB_DINO_MultiTask): def __init__( self, diff --git a/RARP_NVB.py b/RARP_NVB.py index 6f77122..9226ed8 100644 --- a/RARP_NVB.py +++ b/RARP_NVB.py @@ -23,6 +23,9 @@ import yaml import optuna from optuna.integration import PyTorchLightningPruningCallback +from sklearn.metrics import confusion_matrix +from sklearn.metrics import ConfusionMatrixDisplay +import japanize_matplotlib torch.set_float32_matmul_precision('high') torch.backends.cudnn.deterministic = True @@ -118,7 +121,24 @@ return [acc.item(), precision.item(), recall.item(), f1Score.item(), auc.item()] -def Calc_EvalMulticlass_table(TrainModel:M.RARP_NVB_Model,TestDataLoadre:DataLoader, Youden=False, modelName="", NumClasses:int=2): +def encode_2labels_4classes (x:torch, th:float = 0.5): + if x.dtype == torch.float: + x = (x > 0.5) *1 + + r, l = x + + if r == 0 and l == 0: + return 0 + elif r == 1 and l == 0: + return 1 + elif r == 0 and l == 1: + return 2 + elif r == 1 and l == 1: + return 3 + else: + return -1 + +def Calc_EvalMulticlass_table(TrainModel:M.RARP_NVB_Model,TestDataLoadre:DataLoader, Youden=False, modelName="", NumClasses:int=2, Num_Label:int=None): TrainModel.to(device) TrainModel.eval() @@ -132,10 +152,10 @@ if isinstance(TrainModel, M.RARP_NVB_DINO_MultiTask): pred, _, _ = TrainModel(data) - NumClasses = 4 + NumClasses = 4 if Num_Label is None else None else: pred = TrainModel(data) - Predictions.append(torch.softmax(pred, dim=1)) + Predictions.append(torch.softmax(pred, dim=1) if Num_Label is None else torch.sigmoid(pred)) Labels.append(label) Predictions = torch.cat(Predictions) @@ -143,36 +163,45 @@ print(Predictions, Labels) - acc = torchmetrics.Accuracy("multiclass", num_classes=NumClasses).to(device)(Predictions, Labels) - precision = torchmetrics.Precision("multiclass", num_classes=NumClasses).to(device)(Predictions, Labels) - recall = torchmetrics.Recall("multiclass", num_classes=NumClasses).to(device)(Predictions, Labels) - auc = torchmetrics.AUROC("multiclass", num_classes=NumClasses).to(device)(Predictions, Labels) - f1Score = torchmetrics.F1Score("multiclass", num_classes=NumClasses).to(device)(Predictions, Labels) - specificty = torchmetrics.Specificity("multiclass", num_classes=NumClasses).to(device)(Predictions, Labels) + acc = torchmetrics.Accuracy("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels) + precision = torchmetrics.Precision("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels) + recall = torchmetrics.Recall("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels) + auc = torchmetrics.AUROC("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels) + f1Score = torchmetrics.F1Score("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels) + specificty = torchmetrics.Specificity("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device)(Predictions, Labels) table = [ ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", f"*{modelName}*"] #.item():.4f ] - cm2 = torchmetrics.ConfusionMatrix("multiclass", num_classes=NumClasses).to(device) - cm2.update(Predictions, Labels) - _, ax = cm2.plot() - ax.set_title(f"NVB Classifier {modelName}") + if Num_Label is not None: + single_labels_pred = [encode_2labels_4classes(p.cpu()) for p in Predictions] + single_labels_true = [encode_2labels_4classes(p.cpu()) for p in Labels] + labels_names = ["なし", "右", "左", "右+左"] + cm = confusion_matrix(single_labels_true, single_labels_pred, labels=[0,1,2,3]) + disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels_names) + disp.plot() + + + #cm2 = torchmetrics.ConfusionMatrix("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device) + #cm2.update(Predictions, Labels) + #_, ax = cm2.plot() + #ax.set_title(f"NVB Classifier {modelName}") if Youden: for i in range(2): - aucCurve = torchmetrics.ROC("multiclass", num_classes=NumClasses).to(device) + aucCurve = torchmetrics.ROC("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label).to(device) fpr, tpr, thhols = aucCurve(Predictions, Labels) index = torch.argmax(tpr - fpr) th2 = (recall + specificty - 1).item() th2 = 0.5 if th2 <= 0 else th2 th1 = thhols[index].item() if i == 0 else th2 - accY = torchmetrics.Accuracy("multiclass", num_classes=NumClasses, threshold=th1).to(device)(Predictions, Labels) - precisionY = torchmetrics.Precision("multiclass", num_classes=NumClasses, threshold=th1).to(device)(Predictions, Labels) - recallY = torchmetrics.Recall("multiclass", num_classes=NumClasses, threshold=th1).to(device)(Predictions, Labels) + accY = torchmetrics.Accuracy("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label, threshold=th1).to(device)(Predictions, Labels) + precisionY = torchmetrics.Precision("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label, threshold=th1).to(device)(Predictions, Labels) + recallY = torchmetrics.Recall("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label, threshold=th1).to(device)(Predictions, Labels) specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels) - f1ScoreY = torchmetrics.F1Score("multiclass", num_classes=NumClasses, threshold=th1).to(device)(Predictions, Labels) - #cm2 = torchmetrics.ConfusionMatrix("multiclass", num_classes=NumClasses, threshold=th1).to(device) + f1ScoreY = torchmetrics.F1Score("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, num_labels=Num_Label, threshold=th1).to(device)(Predictions, Labels) + #cm2 = torchmetrics.ConfusionMatrix("multiclass" if Num_Label is None else "multilabel", num_classes=NumClasses, threshold=th1).to(device) #cm2.update(Predictions, Labels) #_, ax = cm2.plot() #ax.set_title(f"NVB Classifier (th={th1:.4f})") @@ -591,7 +620,8 @@ std=std, mean=mean, L1= 1.31E-04, - L2= 0 + L2= 0, + SoftAdptAlgo=0, ) if Ckpt_File is None else M.RARP_NVB_DINO_MultiTask.load_from_checkpoint(ckpFile) ModelCAM = None @@ -658,6 +688,8 @@ parser.add_argument("--DyLr", type=bool, default=False) parser.add_argument("-lr", type=float, default=1e-4) parser.add_argument("--ExtraNeurons", type=int, default=4) + parser.add_argument("--ExtraLabels", type=str, default=None) + parser.add_argument("--Roi_Mask_Model", type=str, default=None) args = parser.parse_args() @@ -880,10 +912,10 @@ cropSize = 720 case 15: Dataset = Loaders.RARP_DatasetCreator( - "./DataSet_Ando_ChageLabels", + "./DataSet", FoldSeed=505, createFile=True, - SavePath="./DataSet_C_L", + SavePath="./DataSet", Fold=5, removeBlackBar=args.Remove_Blackbar, RGBGama=args.BGR2RGB, @@ -895,10 +927,10 @@ cropSize = 720 case 16: Dataset = Loaders.RARP_DatasetCreator( - "./DataSet_Ando_ChageLabels_crop", + "./DataSet_crop", FoldSeed=505, createFile=True, - SavePath="./DataSet_Ando_Crop", + SavePath="./DataSet_Crop", Fold=5, removeBlackBar=args.Remove_Blackbar, RGBGama=args.BGR2RGB, @@ -923,6 +955,38 @@ colorSpace=args.ColorSpace ) cropSize = 720 + case 18: + Dataset = Loaders.RARP_DatasetCreator( + "./DataSet_Ando_All_no20Crop", + FoldSeed=505, + createFile=True, + SavePath="./DataSet_AndoAll20_crop", + Fold=5, + removeBlackBar=args.Remove_Blackbar, + RGBGama=args.BGR2RGB, + SegImage=args.imgSlice_pct, + Num_Img_Slices=args.Num_Slices, + SegmentClass=args.sClass, + colorSpace=args.ColorSpace + ) + cropSize = 720 + case 19: + ROI_model = M.RARP_NVB_ROI_Mask_Unet.load_from_checkpoint(Path(args.Roi_Mask_Model)) + Dataset = Loaders.RARP_DatasetCreator( + "./DataSet_Ando_All_no20", + FoldSeed=505, + createFile=True, + SavePath="./DataSet_AndoAll20_mask", + Fold=5, + removeBlackBar=args.Remove_Blackbar, + RGBGama=args.BGR2RGB, + SegImage=args.imgSlice_pct, + Num_Img_Slices=args.Num_Slices, + SegmentClass=args.sClass, + colorSpace=args.ColorSpace, + ROI_Mask=ROI_model + ) + cropSize = 720 case 5: YoloModel = YOLO(model="RARP_YoloV8_ROI.pt") Dataset = Loaders.RARP_DatasetCreator( @@ -956,11 +1020,11 @@ colorSpace=args.ColorSpace ) cropSize = 720 - + + Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373]) + Dataset.CreateFolds() - Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373]) - setup_seed(2023) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") batchSize = 8 #17 #8, 32 @@ -993,17 +1057,22 @@ transforms.Normalize(Dataset.mean, Dataset.std) ).to(device) + Roi_mask_transform = torch.nn.Sequential( + transforms.Resize((224, 224), antialias=True, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Normalize(Dataset.mean, Dataset.std) + ).to(device) + valtransform = torch.nn.Sequential( transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.Normalize(Dataset.mean, Dataset.std) - ).to(device) + ).to(device) #if not args.Model in [20, 21] else Roi_mask_transform testtransform = torch.nn.Sequential( transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.Normalize(Dataset.mean, Dataset.std) - ).to(device) + ).to(device) #if not args.Model in [20, 21] else Roi_mask_transform TrainDINOTransforms = Loaders.RARP_DINO_Augmentation( GloblaCropsScale = (0.4, 1), @@ -1122,6 +1191,39 @@ extensions="npy", transform=testtransform ) + + if args.ExtraLabels is not None: + DumpCSV = pd.read_csv(Dataset.CVS_File) + Extradata = pd.read_excel(Path(args.ExtraLabels)) + + DumpCSV["raw_name"] = "Img0-" + DumpCSV["id"].astype(str) + ".npy" + DumpCSV = DumpCSV.drop(columns=["mean_1", "mean_2", "mean_3", "std_1", "std_2", "std_3", "path", "class", "label"]) + + outPut = pd.merge(Extradata, DumpCSV, on="name", how="right") + + trainDataset = Loaders.RARP_DatasetFolder_ExtraLabel( + str (rootFile/"train"), + loader=defs.load_file_tensor, + Extra_Data=outPut, + extensions="npy", + transform=traintransform + ) + + valDataset = Loaders.RARP_DatasetFolder_ExtraLabel( + str (rootFile/"val"), + loader=defs.load_file_tensor, + Extra_Data=outPut, + extensions="npy", + transform=valtransform + ) + + testDataset = Loaders.RARP_DatasetFolder_ExtraLabel( + str (rootFile/"test"), + loader=defs.load_file_tensor, + Extra_Data=outPut, + extensions="npy", + transform=testtransform + ) Train_DataLoader = DataLoader( trainDataset, @@ -1215,7 +1317,7 @@ logger=CSVLogger(save_dir=f"./{LogFileName}", name="Tune") if args.Phase == "tune" else TensorBoardLogger(save_dir=f"./{LogFileName}"), log_every_n_steps=5, #callbacks=[checkPtCallback, StepDropout(5, base_drop_rate=0.2, gamma=0.05, ascending=True)],#if args.Model == 4 else checkPtCallback, - callbacks=[checkPtCallback, callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)], #ckpLossBest, ], + callbacks=[checkPtCallback, callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)], max_epochs=MaxEpochs, ) print("Train Phase") @@ -1246,8 +1348,10 @@ #ViewImgDINO(trainDataset, Dataset.std, Dataset.mean) - if isinstance(Model, (M.RARP_NVB_MultiClassModel, M.RARP_NVB_DINO_MultiTask_v2)): - temp = Calc_EvalMulticlass_table(Model, Test_DataLoader, False, ckpFile.name, NumClasses=4) + if isinstance(Model, (M.RARP_NVB_MultiClassModel, M.RARP_NVB_DINO_MultiTask_v2, M.RARP_NVB_DINO_MultiTask_MultiLabel)): + numClass = 4 if isinstance(Model, M.RARP_NVB_DINO_MultiTask_v2) else 2 + numLabel = 2 if isinstance(Model, M.RARP_NVB_DINO_MultiTask_MultiLabel) else None + temp = Calc_EvalMulticlass_table(Model, Test_DataLoader, False, ckpFile.name, NumClasses=numClass, Num_Label=numLabel) else: temp = Calc_Eval_table( Model,