diff --git a/Models.py b/Models.py index a2e1b98..ad6a74e 100644 --- a/Models.py +++ b/Models.py @@ -6,11 +6,15 @@ import torchmetrics.classification import lightning as L from enum import Enum +from sklearn.preprocessing import LabelEncoder + + import timm import van import numpy as np from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt from noah import NOAH +import piq def js_divergence_sigmoid(p_logits, q_logits): @@ -76,7 +80,7 @@ return x class DynamicDecoder(torch.nn.Module): - def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=[1024, 512, 256, 64]): + def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=[1024, 512, 256, 64], drop_out:float = None): super(DynamicDecoder, self).__init__() # Ensure the number of hidden channels matches the number of blocks @@ -90,6 +94,8 @@ layers.append(torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)) layers.append(torch.nn.BatchNorm2d(out_channels)) layers.append(torch.nn.ReLU(inplace=True)) + if drop_out is not None: + layers.append(torch.nn.Dropout(drop_out)) in_channels = out_channels # Final layer to get the output image @@ -1168,7 +1174,7 @@ self.C_Momentum = center_momentum self.register_buffer("center", torch.zeros(1, out_dim)) - def forward(self, s_Output, t_Output): + def forward(self, s_Output, t_Output, validation_step:bool = False): sTemp = [s / self.S_Thao for s in s_Output] tTemp = [(t - self.center) / self.T_Thao for t in t_Output] @@ -1188,7 +1194,9 @@ n_loss_terms += 1 total_loss /= n_loss_terms - self.update_center(t_Output) + + if not validation_step: + self.update_center(t_Output) return total_loss @@ -1472,6 +1480,119 @@ def forward(self, x): return self.head(x) +class RARP_Encoder_DINO(L.LightningModule): + def __init__(self, + momentum_teacher:float = 0.9995, + lr:float = 1e-4, + Teacher_T:float = 0.04, + Student_T:float = 0.1, + max_epochs:int = 100, + total_steps:int = None + ) -> None: + super().__init__() + self.save_hyperparameters() + + self.lr = lr + self.momentum_teacher = momentum_teacher + self.out_dim = 65536 + self.in_dim = 512 + + self.student = van.van_b1(num_classes = 0) + self.teacher = van.van_b1(num_classes = 0) + + self.student = RARP_NVB_DINO_Wrapper( + self.student, + RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True) + ) + + self.teacher = RARP_NVB_DINO_Wrapper( + self.teacher, + RARP_NVB_MLP(self.in_dim, self.out_dim, hidden_dim=2048, bottleneck=256, norm_last_layer=True) + ) + + self.teacher.load_state_dict(self.student.state_dict()) + + for parms in self.teacher.parameters(): + parms.requires_grad = False + + self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher) + + def forward(self, data, val_step=False): + if val_step: + dataTeacher, dataStudent = data.float(), data.float() + else: + data = [d.float() for d in data] + dataTeacher, dataStudent = data[0:3], data + + teacher_features = self.teacher(dataTeacher) + student_features = self.student(dataStudent) + + return teacher_features, student_features + + def _shared_step(self, batch, val_step=False): + img, _ = batch + t, s = self(img, val_step) + + loss_Dino = self.lossFN_DINO(s, t, validation_step=val_step) + + return loss_Dino + + def training_step(self, batch, batch_idx): + loss = self._shared_step(batch, False) + + self.log("train_loss", loss, on_epoch=True, sync_dist=True) + + return loss + + def validation_step(self, batch, batch_idx): + loss = self._shared_step(batch, True) + self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True) + + def on_train_batch_end(self, outputs, batch, batch_idx): + #step = self.global_step + # + #m = 1.0 - (1.0 - self.hparams.momentum_teacher) * ( + # (1 + math.cos(math.pi * step / self.hparamstotal_steps)) / 2 + #) + + with torch.no_grad(): + for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher.parameters()): + teacher_ps.data.mul_(self.momentum_teacher) + teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data) + + def on_after_backward(self): + total_norm = 0.0 + for p in self.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** 0.5 + + self.log("grad_norm", total_norm) + + if total_norm < 1e-8: + self.log("grad_warning", "Vanishing gradient suspected!") + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.student.parameters(), lr=self.lr) + + #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + # optimizer, + # T_max=self.hparams.max_epochs, # decays from epoch 0 → epoch max_epochs + # eta_min=0.0 + #) + + #return { + # "optimizer": optimizer, + # "lr_scheduler": { + # "scheduler": scheduler, + # "interval": "epoch", # <-- step once per epoch + # "frequency": 1, + # }, + #} + + return [optimizer] + class RARP_NVB_DINO_MultiTask(L.LightningModule): # Define a hook function to capture the output def _hook_fn_Student(self, module, input, output): @@ -1524,8 +1645,14 @@ self.test_acc = torchmetrics.Accuracy('binary') self.f1ScoreTest = torchmetrics.F1Score('binary') - self.student = van.van_b2(pretrained = True, num_classes = 0) - self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0) + #self.student = van.van_b2(pretrained = True, num_classes = 0) + #self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0) + + self.student = torchvision.models.convnext_small(weights=torchvision.models.ConvNeXt_Small_Weights.DEFAULT) + self.student.classifier[-1] = torch.nn.Identity() + + self.teacher_Features = torchvision.models.convnext_small(weights=torchvision.models.ConvNeXt_Small_Weights.DEFAULT) + self.teacher_Features.classifier[-1] = torch.nn.Identity() self.decoder = DynamicDecoder(input_channels=1024) @@ -1745,6 +1872,12 @@ if self.mean_IMG is not None and self.std_IMG is not None: imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1) imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :] + + imgOrig = torch.clip(self._denormalize(batch[0])/255, 0, 1) + imgOrig = imgOrig[:, [2, 1, 0], :, :] + + imgReconstruction = torch.cat((imgOrig, imgReconstruction), dim=0) + grid = torchvision.utils.make_grid(imgReconstruction) self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step) @@ -1752,6 +1885,68 @@ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) #, weight_decay=self.Lambda_L2 return [optimizer] + +class RARP_Hybrid_TS_LR(torch.nn.Module): + def __init__( + self, + base_TS_Model:str = "", + std: float = None, + mean: float = None, + stretch: bool = False, + masked: bool = False + ): + super().__init__() + self.mean = mean + self.std = std + self.stretch = stretch + self.masked = masked + + self.mid_length = 0 + + self.labels = ["L", "R"] + + self.baseModel = RARP_NVB_DINO_MultiTask.load_from_checkpoint(base_TS_Model) + self.baseModel.eval() + + def _mask_LR(self, image:torch.Tensor, Left:bool= True): + halfImg = image[:, :, :, :self.mid_length] if Left else image[:, :, :, self.mid_length:] + pad_zeros = torch.zeros_like(halfImg) #Agv. Color + listImgs = [halfImg, pad_zeros] if Left else [pad_zeros, halfImg] + return torch.cat(listImgs, dim=-1) + + def _crop_LR(self, image:torch.Tensor, Left:bool = True): + if Left: + return image[:, :, :, :self.mid_length] if not self.stretch else torch.nn.functional.interpolate( + image[:, :, :, :self.mid_length], + size=(224, 224), + mode='bicubic', + align_corners=False + ) + else: + return image[:, :, :, self.mid_length:] if not self.stretch else torch.nn.functional.interpolate( + image[:, :, :, self.mid_length:], + size=(224, 224), + mode='bicubic', + align_corners=False + ) + + def forward(self, x): + _, _, _, w = x.shape #[B, C, H, W] + self.mid_length = w // 2 + + LR_Img = { + "L":self._crop_LR(x, True) if not self.masked else self._mask_LR(x, True), + "R":self._crop_LR(x, False) if not self.masked else self._mask_LR(x, False) + } + + pred = [] + for label in self.labels: + with torch.no_grad(): + raw_pred, _, _ = self.baseModel(LR_Img[label]) + pred.append(raw_pred) + + return torch.cat(pred, dim=-1) + #Ablation Models """T-S Multi-task model With out Recostruccion (V3R1_A1) @@ -2078,12 +2273,410 @@ self.clasiffier = RARP_NVB_Classification_Head(1024, 1, layeres, torch.nn.SiLU(True)) #end Ablation Models + +class RARP_CLIP_loss(torch.nn.Module): + def __init__(self, temperature): + super().__init__() + self.temp = temperature + + def forward(self, z_s:torch.Tensor, z_t:torch.Tensor): + logits = torch.matmul(z_s, z_t.t()) / self.temp + lables = torch.arange(z_s.size(0), device=logits.device) + + loss_s2t = torch.nn.functional.cross_entropy(logits, lables) + loss_t2s = torch.nn.functional.cross_entropy(logits.t(), lables) + + return 0.5 * (loss_s2t + loss_t2s) + +class RARP_CLIP(L.LightningModule): + def __init__( + self, + student_backbone: str = "", + teacher_backbone: str = "", + proj_dim: int = 256, + embeddings: int = 512, + temperature: float = 0.07, + lr: float = 1e-4, + ): + super().__init__() + + self.save_hyperparameters() + + match(student_backbone): + case "van_b1": + self.student = van.van_b1(pretrained=False, num_classes=0) + self.student_dim = 512 + case _: + raise Exception(f"{student_backbone} Not Implemented") + + if len(teacher_backbone) > 0: + self.teacher = van.van_b2(pretrained=False, num_classes=0) + self.teacher.load_state_dict(torch.load(teacher_backbone)) + self.teacher_dim = 512 + else: + self.teacher = van.van_b2(pretrained=True, num_classes=0) + self.teacher_dim = 512 + + for p in self.teacher.parameters(): + p.requires_grad = False + + self.proj_s = torch.nn.Sequential( + torch.nn.Linear(self.student_dim, proj_dim), + torch.nn.LayerNorm(proj_dim), + torch.nn.GELU(), + torch.nn.Linear(proj_dim, embeddings) + ) + + self.loss_fn = RARP_CLIP_loss(temperature) + + def forward(self, data): + x_s = self.student(data) + x_s = self.proj_s(x_s) + x_s = torch.nn.functional.normalize(x_s, dim=-1) + + x_t = self.teacher(data) + x_t = torch.nn.functional.normalize(x_t, dim=-1) + + return x_s, x_t + + def _shared_step(self, batch): + img, _ = batch + z_s, z_t = self(img) + + loss = self.loss_fn(z_s, z_t) + + return loss + + def training_step(self, batch, batch_idx): + loss = self._shared_step(batch) + + self.log("train/clip_loss", loss, on_step=True, on_epoch=True, prog_bar=True) + return loss + + def configure_optimizers(self): + params = ( + list(self.student.parameters()) + + list(self.proj_s.parameters()) + #list(self.proj_t.parameters()) # now included + ) + + return torch.optim.AdamW(params, lr=self.hparams.lr) + +class DecoderBlock(torch.nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.conv_expand = torch.nn.Conv2d(in_ch, out_ch*4, 3, padding=1) + self.pixel_shuffle = torch.nn.PixelShuffle(2) + self.bn1 = torch.nn.BatchNorm2d(out_ch) + self.act1 = torch.nn.GELU() + self.conv_refine = torch.nn.Conv2d(out_ch, out_ch, 3, padding=1) + self.bn2 = torch.nn.BatchNorm2d(out_ch) + self.act2 = torch.nn.GELU() + + def forward(self, x): + x = self.conv_expand(x) + x = self.pixel_shuffle(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv_refine(x) + x = self.bn2(x) + x = self.act2(x) + return x + +class DynamicDecoder_PixelShuffle(torch.nn.Module): + def __init__(self, input_channels=2048, output_channels=3, num_blocks=4, hidden_channels=[1024, 512, 256, 64], drop_out:float = None): + super().__init__() + + # Ensure the number of hidden channels matches the number of blocks + assert len(hidden_channels) == num_blocks, "Number of hidden channels must match the number of blocks." + + layers = [] + in_channels = input_channels + + # Loop to create the decoder blocks + for out_channels in hidden_channels: + layers.append(DecoderBlock(in_channels, out_channels)) + if drop_out is not None: + layers.append(torch.nn.Dropout(drop_out)) + in_channels = out_channels + + # Final layer to get the output image + #layers.append(torch.nn.Conv2d(in_channels, output_channels, kernel_size=3, padding=1)) + layers.append(torch.nn.ConvTranspose2d(in_channels, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1)) + #layers.append(torch.nn.Sigmoid()) # To get pixel values between 0 and 1 + + # Combine all layers into a Sequential module + self.decoder = torch.nn.Sequential(*layers) + + def forward(self, x): + return self.decoder(x) + +class RARP_NVB_DINO_MultiTask_Pretrain(RARP_NVB_DINO_MultiTask): + def __init__( + self, + TypeLoss=TypeLossFunction.CrossEntropy, + momentum_teacher = 0.9995, + lr = 0.0001, + L1 = None, + L2 = 0, + std = None, + mean = None, + SoftAdptAlgo = 0, + SoftAdptBeta = 0.1, + Teacher_T = 0.04, + Student_T = 0.1, + intermittent = False, + pre_train_pth:str = "", + HParameter = {}, + ): + super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent) + + #self.Lambda_L1 = getNVL(HParameter, "L1", 1.31E-04) + #self.lr = getNVL(HParameter, "lr", 1.0E-4) + + self.student = van.van_b2(pretrained = False, num_classes = 0) + self.teacher_Features = van.van_b2(pretrained = False, num_classes = 0) + + self.student = RARP_NVB_DINO_Wrapper( + self.student, + RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2) + ) + + self.teacher_Features = RARP_NVB_DINO_Wrapper( + self.teacher_Features, + RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2) + ) + + if len(pre_train_pth) > 0: + self.student.backbone.load_state_dict(torch.load(pre_train_pth)) + + self.teacher_Features.load_state_dict(self.student.state_dict()) + for parms in self.teacher_Features.parameters(): + parms.requires_grad = False + + self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher) + self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student) + + self.decoder = DynamicDecoder(input_channels=1024) + + +class RARP_MAE(L.LightningModule): + def __init__( + self, + backbone:str, + mask_ratio:float = 0.75, + patch_size: int = 16, + #img_size: int = 224, + lr: float = 1e-4, + hiden_channels = [512, 256, 128, 64, 32], + img_mean_std = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]], + activation_fn = torch.nn.ReLU(True) + ): + super().__init__() + + self.save_hyperparameters(ignore=["activation_fn", "img_mean_std"]) + + + + if img_mean_std is not None: + mean, std = img_mean_std + self.register_buffer("mean_IMG", torch.tensor(mean).view(3, 1, 1)) + self.register_buffer("std_IMG", torch.tensor(std).view(3, 1, 1)) + else: + self.mean_IMG = None + self.std_IMG = None + + self.loss_fn = torch.nn.L1Loss() + + match backbone: + case "resnet": + model = torchvision.models.resnet18() + model.fc = torch.nn.Identity() + self.encoder = model + self.encoder_out_dim = 512 + case "van": + model = van.van_b1() + model.head = torch.nn.Identity() + self.encoder = model + self.encoder_out_dim = 512 + case "levit": + self.encoder = timm.create_model("levit_192.fb_dist_in1k", pretrained=False, num_classes=0) + self.encoder_out_dim = 384 + hiden_channels = [384, 256, 128, 64, 32] + case "van_l1_loss": + model = van.van_b1() + model.head = torch.nn.Identity() + self.encoder = model + self.encoder_out_dim = 512 + self.loss_fn = torch.nn.L1Loss() + case "van_2": + model = van.van_b2(num_classes=0) + self.encoder = model + self.encoder_out_dim = 512 + + + in_channel = self.encoder_out_dim + layers = [ + torch.nn.Linear(in_channel, hiden_channels[0]*7*7), + torch.nn.Unflatten(1, (hiden_channels[0], 7, 7)), + ] + + for out_channel in hiden_channels[1:]: + layers.append(torch.nn.ConvTranspose2d(in_channel, out_channel, kernel_size=3, stride=2, padding=1, output_padding=1)) + layers.append(torch.nn.BatchNorm2d(out_channel)) + layers.append(activation_fn) + in_channel = out_channel + + #Last layer + layers.append(torch.nn.ConvTranspose2d(in_channel, 3, kernel_size=3, stride=2, padding=1, output_padding=1)) + + self.decoder = torch.nn.Sequential(*layers) + + def _denormalize(self, tensor:torch.Tensor): + return tensor * self.std_IMG + self.mean_IMG + + def _mask_patches(self, imgs): + B, C, H, W = imgs.shape + ph, pw = self.hparams.patch_size, self.hparams.patch_size + gh, gw = H // ph, W // pw + + mask:torch.Tensor = (torch.rand(B, gh * gw, device=imgs.device) >= self.hparams.mask_ratio) + mask = mask.reshape(B, 1, gh, gw) # [B,1,gh,gw] + + # expand mask to full image + mask = mask.repeat_interleave(ph, dim=2) # [B,1,gh*ph,gw] + mask = mask.repeat_interleave(pw, dim=3) # [B,1,gh*ph,gw*pw] == [B,1,H,W] + + mask = mask.expand(-1, C, -1, -1) + + imgs_masked = imgs * mask + return imgs_masked + + def forward(self, data, val_step=False): + if not val_step: + data = [d.float() for d in data] #[0] original Img, [1] augmented Img + else: + data = [data.float(), data.float()] + + imgs_masked = self._mask_patches(data[1]) + latent = self.encoder(imgs_masked) # [B, 512] + reconstructed_image = self.decoder(latent) # [B, 3, 224, 224] + + return reconstructed_image, data[0] + + def _shared_step(self, batch, val_step=False): + img, _ = batch + res, true_img = self(img, val_step) + + loss_MSE = self.loss_fn(res, true_img) + + return loss_MSE, res, true_img + + def training_step(self, batch, batch_idx): + loss, img, _ = self._shared_step(batch, False) + + self.log("train_loss", loss, on_epoch=True, sync_dist=True) + + if batch_idx % 500 == 0 and self.mean_IMG is not None and self.std_IMG is not None: + imgReconstruction = torch.clip(self._denormalize(img), 0, 1) + grid = torchvision.utils.make_grid(imgReconstruction) + self.logger.experiment.add_image('reconstructed_images', grid, self.global_step) + + return loss + + def validation_step(self, batch, batch_idx): + loss, img, target_img = self._shared_step(batch, True) + + self.log("val_loss", loss, on_epoch=True, prog_bar=True) + + + if self.mean_IMG is not None and self.std_IMG is not None: + img = torch.clip(self._denormalize(img), 0, 1) + target_img = torch.clip(self._denormalize(target_img), 0, 1) + + ssim = piq.ssim(img, target_img, data_range=1.0) + psnr = piq.psnr(img, target_img, data_range=1.0) + + self.log("val_ssim", ssim, on_epoch=True) + self.log("val_psnr", psnr, on_epoch=True) + + if batch_idx % 100 == 0: + grid = torchvision.utils.make_grid(img) + self.logger.experiment.add_image('val_reconstructed_images', grid, self.global_step) + + + def on_after_backward(self): + total_norm = 0.0 + for p in self.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** 0.5 + + self.log("grad_norm", total_norm) + + if total_norm < 1e-8: + self.log("grad_warning", "Vanishing gradient suspected!") + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr) + + return [optimizer] + +class RARP_NVB_DINO_MultiTask_LeViT(RARP_NVB_DINO_MultiTask): + def _hook_fn_Student(self, module, input, output): + B, N, C = output.shape + H = W = int (N**0.5) + + self.last_conv_output_S = output.transpose(1, 2) + self.last_conv_output_S = self.last_conv_output_S.contiguous().view(B, C, H, W) + + def _hook_fn_Teacher(self, module, input, output): + B, N, C = output.shape + H = W = int (N**0.5) + + self.last_conv_output_T = output.transpose(1, 2) + self.last_conv_output_T = self.last_conv_output_T.contiguous().view(B, C, H, W) + + def __init__(self, TypeLoss=TypeLossFunction.CrossEntropy, momentum_teacher = 0.9995, lr = 0.0001, L1 = None, L2 = 0, std = None, mean = None, SoftAdptAlgo = 0, SoftAdptBeta = 0.1, Teacher_T = 0.04, Student_T = 0.1, intermittent = False): + super().__init__(TypeLoss, momentum_teacher, lr, L1, L2, std, mean, SoftAdptAlgo, SoftAdptBeta, Teacher_T, Student_T, intermittent) + + self.in_dim = 768 + self.out_dim = 2048 + + self.student = timm.create_model("levit_384.fb_dist_in1k", pretrained=True, num_classes=0) + self.teacher_Features = timm.create_model("levit_384.fb_dist_in1k", pretrained=True, num_classes=0) + + self.decoder = DynamicDecoder(input_channels = 1024) + + self.student = RARP_NVB_DINO_Wrapper( + self.student, + RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=3) + ) + + self.teacher_Features = RARP_NVB_DINO_Wrapper( + self.teacher_Features, + RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=3) + ) + + self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, Teacher_T, Student_T, momentum_teacher) + + self.teacher_Features.load_state_dict(self.student.state_dict()) + for parms in self.teacher_Features.parameters(): + parms.requires_grad = False + + self.teacher_Features.backbone.stages[-2].register_forward_hook(self._hook_fn_Teacher) + self.student.backbone.stages[-2].register_forward_hook(self._hook_fn_Student) + + self.clasiffier = RARP_NVB_Classification_Head(1024, 1, [128, 8], torch.nn.SiLU(True)) + class RARP_NVB_DINO_MultiTask_Unet(RARP_NVB_DINO_MultiTask): def _encoder_hool_fn(self, module, input, output): self.feature_maps.append(output) - def _register_encoder_hooks(self): - for layer in self.list_blocks: + def _register_encoder_hooks(self, block_list:list): + for layer in block_list: self.hooks.append(layer.register_forward_hook(self._encoder_hool_fn)) def __init__( @@ -2110,7 +2703,15 @@ self.student.backbone.block4[-2], ] - self._register_encoder_hooks() + #self.list_blocks_T = [ + # self.teacher_Features.backbone.block1[-1], + # self.teacher_Features.backbone.block2[-1], + # self.teacher_Features.backbone.block3[-1], + # self.teacher_Features.backbone.block4[-2], + #] + + self._register_encoder_hooks(self.list_blocks) + #self._register_encoder_hooks(self.list_blocks_T) self.decoder = DecoderUnet(1024) @@ -2126,22 +2727,26 @@ TeacherDino = self.teacher_Features(dataTeacher) Student = self.student(dataStudent) + + _temp = [] + num_blocks = len(self.list_blocks) + NumChunks = len(dataStudent) + for i in range(num_blocks): + if not val_step: + S_GlogalViews = torch.cat(self.feature_maps[i + num_blocks].chunk(NumChunks)[1:3], dim=0) + else: + S_GlogalViews = self.feature_maps[i + num_blocks] + + _temp.append(self.feature_maps[i] * S_GlogalViews) + + self.feature_maps = _temp if not val_step: - NumChunks = len(dataStudent) S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3] self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0) cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1) - - if not val_step: - NumChunks = len(dataStudent) - temp = [] - for f_maps in self.feature_maps: - temp.append(torch.cat(f_maps.chunk(NumChunks)[1:3], dim=0)) - self.feature_maps = temp - del temp - + self.feature_maps.append(cat_features) reconstructed_image = self.decoder(self.feature_maps)