from .classification.model.build import EfficientViT_M5
import torch
import torch.nn as nn


class EfficientViT_GSViT(nn.Module):
    """EfficientViT-M5 backbone used as a GSViT feature extractor.

    An implementation based on the original paper and repo:
    https://github.com/SamuelSchmidgall/GSViT.git

    The classification head of EfficientViT-M5 is dropped and the output of
    the last stage is global-average-pooled into a (B, C) feature vector.
    """

    def __init__(self, pre_trained: str = "EfficientViT_GSViT.pth", force_fp32: bool = False):
        """
        Args:
            pre_trained: Path to a checkpoint to load ('' or None skips
                loading). The file may be a raw state dict or a dict holding
                a "state_dict" entry.
            force_fp32: If True, run the backbone in float32 even inside an
                autocast region (CUDA inputs only).
        """
        super().__init__()
        self.force_fp32 = force_fp32
        self.evit = EfficientViT_M5(pretrained="efficientvit_m5")
        # Drop the classification head (the last child module).
        self.evit = nn.Sequential(*list(self.evit.children())[:-1])

        if pre_trained:
            ckpt = torch.load(pre_trained, map_location="cpu")
            if isinstance(ckpt, dict) and "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            # strict=False: checkpoint may contain keys for the removed head.
            self.evit.load_state_dict(ckpt, strict=False)

    def forward(self, x):
        """Return pooled (B, C) features for a (B, 3, H, W) image batch."""
        if self.force_fp32 and x.is_cuda:
            # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast.
            with torch.amp.autocast("cuda", enabled=False):
                x = self.evit(x.float())
        else:
            x = self.evit(x)
        return nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
from .classification.model.build import EfficientViT_M5
import torch
import torch.nn as nn
from collections import OrderedDict


class EfficientViT_GSViT(nn.Module):
    """EfficientViT-M5 backbone with GSViT weights (RARP variant).

    An implementation based on the original paper and repo:
    https://github.com/SamuelSchmidgall/GSViT.git

    Keeps the original EfficientViT module structure (classification head
    replaced by Identity) so checkpoints that were saved from an
    nn.Sequential-wrapped copy of the backbone — whose keys are prefixed by
    child *indices* — can be remapped back to child *names* before loading.
    """

    def _remap_sequential_keys(self, sd):
        """Translate 'module.'-prefixed / index-prefixed state-dict keys.

        Keys starting with a child index (e.g. '0.c.weight') are renamed to
        the corresponding child name (e.g. 'patch_embed.c.weight') using
        ``self.index_to_name``; indices with no mapping are dropped.
        Non-indexed keys pass through unchanged.
        """
        out = OrderedDict()
        for k, v in sd.items():
            if k.startswith("module."):
                k = k[len("module."):]
            parts = k.split(".")
            if parts[0].isdigit():
                idx = int(parts[0])
                if idx not in self.index_to_name:
                    continue  # drop keys for children we did not keep
                parts[0] = self.index_to_name[idx]
                k = ".".join(parts)
            out[k] = v
        return out

    def __init__(self, pre_trained: str = "EfficientViT_GSViT.pth", force_oo_model: bool = True):
        """
        Args:
            pre_trained: Checkpoint path ('' or None skips loading). The file
                may be a raw state dict or a dict holding "state_dict".
            force_oo_model: If True, remap index-prefixed checkpoint keys
                (saved from a Sequential wrapper) onto the named children.
        """
        super().__init__()
        self._force = force_oo_model
        self.evit = EfficientViT_M5(pretrained="efficientvit_m5")
        self.evit.head = nn.Identity()
        # Built unconditionally: _remap_sequential_keys depends on this map,
        # while the original only created it when a checkpoint was loaded
        # with force_oo_model=True (latent AttributeError otherwise).
        self.index_to_name = {
            i: name for i, (name, _) in enumerate(self.evit.named_children())
        }

        if pre_trained:
            ckpt = torch.load(pre_trained, map_location="cpu")
            if isinstance(ckpt, dict) and "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            if force_oo_model:
                ckpt = self._remap_sequential_keys(ckpt)
            self.evit.load_state_dict(ckpt, strict=False)

    def forward(self, x):
        """Return (B, C) features: EfficientViT.forward pools internally and
        the head is Identity, so the pooled stage-3 features come out."""
        return self.evit(x)
# ---------------------------------------------------------------------------
# EfficientViT model configurations (m2-m5) and registered factory functions.
# Each EfficientViT_M* factory builds the model, optionally downloads the
# official pretrained checkpoint, and optionally fuses Conv+BN for inference.
# The six factories previously duplicated the same loading body verbatim;
# that logic now lives in _load_pretrained / _create_efficientvit.
# ---------------------------------------------------------------------------

EfficientViT_m2 = {
        'img_size': 224,
        'patch_size': 16,
        'embed_dim': [128, 192, 224],
        'depth': [1, 2, 3],
        'num_heads': [4, 3, 2],
        'window_size': [7, 7, 7],
        'kernels': [7, 5, 3, 3],
    }

EfficientViT_m3 = {
        'img_size': 224,
        'patch_size': 16,
        'embed_dim': [128, 240, 320],
        'depth': [1, 2, 3],
        'num_heads': [4, 3, 4],
        'window_size': [7, 7, 7],
        'kernels': [5, 5, 5, 5],
    }

EfficientViT_m4 = {
        'img_size': 224,
        'patch_size': 16,
        'embed_dim': [128, 256, 384],
        'depth': [1, 2, 3],
        'num_heads': [4, 4, 4],
        'window_size': [7, 7, 7],
        'kernels': [7, 5, 3, 3],
    }

EfficientViT_m5 = {
        'img_size': 224,
        'patch_size': 16,
        'embed_dim': [192, 288, 384],
        'depth': [1, 3, 4],
        'num_heads': [3, 3, 4],
        'window_size': [7, 7, 7],
        'kernels': [7, 5, 3, 3],
    }


def _load_pretrained(model, name):
    """Download the official checkpoint ``name`` and load it into ``model``.

    Checkpoints store some 1x1-conv weights as 2D linear weights; any tensor
    whose shape disagrees with the model gets two trailing singleton dims
    appended before loading.
    """
    url = _checkpoint_url_format.format(name)
    checkpoint = torch.hub.load_state_dict_from_url(url, map_location='cpu')
    d = checkpoint['model']
    D = model.state_dict()
    for k in d.keys():
        if D[k].shape != d[k].shape:
            d[k] = d[k][:, :, None, None]
    model.load_state_dict(d)


def _create_efficientvit(model_cfg, num_classes, pretrained, distillation, fuse):
    """Shared builder used by every registered EfficientViT_M* factory.

    Args:
        model_cfg: One of the EfficientViT_m* configuration dicts.
        num_classes: Size of the classification head (0 disables it).
        pretrained: Falsy, or the checkpoint name to download (e.g.
            'efficientvit_m5').
        distillation: Add a distillation head.
        fuse: Fold BatchNorm layers into their preceding convolutions.
    """
    model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg)
    if pretrained:
        _load_pretrained(model, pretrained)
    if fuse:
        replace_batchnorm(model)
    return model


@register_model
def EfficientViT_M0(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m0):
    return _create_efficientvit(model_cfg, num_classes, pretrained, distillation, fuse)


@register_model
def EfficientViT_M1(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m1):
    return _create_efficientvit(model_cfg, num_classes, pretrained, distillation, fuse)


@register_model
def EfficientViT_M2(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m2):
    return _create_efficientvit(model_cfg, num_classes, pretrained, distillation, fuse)


@register_model
def EfficientViT_M3(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m3):
    return _create_efficientvit(model_cfg, num_classes, pretrained, distillation, fuse)


@register_model
def EfficientViT_M4(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m4):
    return _create_efficientvit(model_cfg, num_classes, pretrained, distillation, fuse)


@register_model
def EfficientViT_M5(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m5):
    return _create_efficientvit(model_cfg, num_classes, pretrained, distillation, fuse)
fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m5): + model = EfficientViT(num_classes=num_classes, distillation=distillation, **model_cfg) + if pretrained: + pretrained = _checkpoint_url_format.format(pretrained) + checkpoint = torch.hub.load_state_dict_from_url( + pretrained, map_location='cpu') + d = checkpoint['model'] + D = model.state_dict() + for k in d.keys(): + if D[k].shape != d[k].shape: + d[k] = d[k][:, :, None, None] + model.load_state_dict(d) + if fuse: + replace_batchnorm(model) + return model + +def replace_batchnorm(net): + for child_name, child in net.named_children(): + if hasattr(child, 'fuse'): + setattr(net, child_name, child.fuse()) + elif isinstance(child, torch.nn.BatchNorm2d): + setattr(net, child_name, torch.nn.Identity()) + else: + replace_batchnorm(child) + +_checkpoint_url_format = \ + 'https://github.com/xinyuliu-jeffrey/EfficientViT_Model_Zoo/releases/download/v1.0/{}.pth' diff --git a/EfficientViT/classification/model/efficientvit.py b/EfficientViT/classification/model/efficientvit.py new file mode 100644 index 0000000..2cf39c7 --- /dev/null +++ b/EfficientViT/classification/model/efficientvit.py @@ -0,0 +1,356 @@ +# -------------------------------------------------------- +# EfficientViT Model Architecture +# Copyright (c) 2022 Microsoft +# Build the EfficientViT Model +# Written by: Xinyu Liu +# -------------------------------------------------------- +import torch +import itertools + +from timm.models.vision_transformer import trunc_normal_ +from timm.layers import SqueezeExcite + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', torch.nn.BatchNorm2d(b)) + torch.nn.init.constant_(self.bn.weight, bn_weight_init) + torch.nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def 
fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class BN_Linear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + self.add_module('l', torch.nn.Linear(a, b, bias=bias)) + trunc_normal_(self.l.weight, std=std) + if bias: + torch.nn.init.constant_(self.l.bias, 0) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps)**0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class PatchMerging(torch.nn.Module): + def __init__(self, dim, out_dim, input_resolution): + super().__init__() + hid_dim = int(dim * 4) + self.conv1 = Conv2d_BN(dim, hid_dim, 1, 1, 0, resolution=input_resolution) + self.act = torch.nn.ReLU() + self.conv2 = Conv2d_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, resolution=input_resolution) + self.se = SqueezeExcite(hid_dim, .25) + self.conv3 = Conv2d_BN(hid_dim, out_dim, 1, 1, 0, resolution=input_resolution // 2) + + def forward(self, x): + x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x)))))) + return x + + +class Residual(torch.nn.Module): + def __init__(self, m, drop=0.): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) 
* torch.rand(x.size(0), 1, 1, 1, + device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class FFN(torch.nn.Module): + def __init__(self, ed, h, resolution): + super().__init__() + self.pw1 = Conv2d_BN(ed, h, resolution=resolution) + self.act = torch.nn.ReLU() + self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0, resolution=resolution) + + def forward(self, x): + x = self.pw2(self.act(self.pw1(x))) + return x + + +class CascadedGroupAttention(torch.nn.Module): + r""" Cascaded Group Attention. + + Args: + dim (int): Number of input channels. + key_dim (int): The dimension for query and key. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution, correspond to the window size. + kernels (List[int]): The kernel size of the dw conv on query. + """ + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=14, + kernels=[5, 5, 5, 5],): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.d = int(attn_ratio * key_dim) + self.attn_ratio = attn_ratio + + qkvs = [] + dws = [] + for i in range(num_heads): + qkvs.append(Conv2d_BN(dim // (num_heads), self.key_dim * 2 + self.d, resolution=resolution)) + dws.append(Conv2d_BN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i]//2, groups=self.key_dim, resolution=resolution)) + self.qkvs = torch.nn.ModuleList(qkvs) + self.dws = torch.nn.ModuleList(dws) + self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN( + self.d * num_heads, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + 
idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,C,H,W) + B, C, H, W = x.shape + trainingab = self.attention_biases[:, self.attention_bias_idxs] + feats_in = x.chunk(len(self.qkvs), dim=1) + feats_out = [] + feat = feats_in[0] + for i, qkv in enumerate(self.qkvs): + if i > 0: # add the previous output to the input + feat = feat + feats_in[i] + feat = qkv(feat) + q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1) # B, C/h, H, W + q = self.dws[i](q) + q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) # B, C/h, N + attn = ( + (q.transpose(-2, -1) @ k) * self.scale + + + (trainingab[i] if self.training else self.ab[i]) + ) + attn = attn.softmax(dim=-1) # BNN + feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W) # BCHW + feats_out.append(feat) + x = self.proj(torch.cat(feats_out, 1)) + return x + + +class LocalWindowAttention(torch.nn.Module): + r""" Local Window Attention. + + Args: + dim (int): Number of input channels. + key_dim (int): The dimension for query and key. + num_heads (int): Number of attention heads. + attn_ratio (int): Multiplier for the query dim for value dimension. + resolution (int): Input resolution. + window_resolution (int): Local window resolution. + kernels (List[int]): The kernel size of the dw conv on query. 
+ """ + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=14, + window_resolution=7, + kernels=[5, 5, 5, 5],): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.resolution = resolution + assert window_resolution > 0, 'window_size must be greater than 0' + self.window_resolution = window_resolution + + window_resolution = min(window_resolution, resolution) + self.attn = CascadedGroupAttention(dim, key_dim, num_heads, + attn_ratio=attn_ratio, + resolution=window_resolution, + kernels=kernels,) + + def forward(self, x): + H = W = self.resolution + B, C, H_, W_ = x.shape + # Only check this for classifcation models + assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_)) + + if H <= self.window_resolution and W <= self.window_resolution: + x = self.attn(x) + else: + x = x.permute(0, 2, 3, 1) + pad_b = (self.window_resolution - H % + self.window_resolution) % self.window_resolution + pad_r = (self.window_resolution - W % + self.window_resolution) % self.window_resolution + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_resolution + nW = pW // self.window_resolution + # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw + x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3).reshape( + B * nH * nW, self.window_resolution, self.window_resolution, C + ).permute(0, 3, 1, 2) + x = self.attn(x) + # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC + x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, + C).transpose(2, 3).reshape(B, pH, pW, C) + if padding: + x = x[:, :H, :W].contiguous() + x = x.permute(0, 3, 1, 2) + return x + + +class EfficientViTBlock(torch.nn.Module): + """ A basic EfficientViT building block. 
class EfficientViTBlock(torch.nn.Module):
    """A basic EfficientViT building block.

    DWConv -> FFN -> local-window attention -> DWConv -> FFN, each wrapped
    in a Residual.

    Args:
        type (str): Token-mixer type; 's' enables local-window self-attention.
        ed (int): Number of input channels.
        kd (int): Query/key dimension in the token mixer.
        nh (int): Number of attention heads.
        ar (int): Multiplier for the query dim for the value dimension.
        resolution (int): Input resolution.
        window_resolution (int): Local window resolution.
        kernels (List[int]): Kernel sizes of the dw conv on queries, per head.
    """

    def __init__(self, type,
                 ed, kd, nh=8,
                 ar=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5]):
        super().__init__()
        self.dw0 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed,
                                      bn_weight_init=0., resolution=resolution))
        self.ffn0 = Residual(FFN(ed, int(ed * 2), resolution))
        if type == 's':
            self.mixer = Residual(LocalWindowAttention(
                ed, kd, nh, attn_ratio=ar, resolution=resolution,
                window_resolution=window_resolution, kernels=kernels))
        self.dw1 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed,
                                      bn_weight_init=0., resolution=resolution))
        self.ffn1 = Residual(FFN(ed, int(ed * 2), resolution))

    def forward(self, x):
        return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))


class EfficientViT(torch.nn.Module):
    """EfficientViT backbone: patch embedding, three block stages separated
    by downsample blocks, global average pooling, and a BN+Linear head.

    Args:
        img_size (int): Input image size.
        patch_size (int): Total patch-embedding downsample factor.
        in_chans (int): Input channels.
        num_classes (int): Head size; 0 replaces the head with Identity.
        stages (list[str]): Token-mixer type per stage.
        embed_dim/key_dim/depth/num_heads/window_size/kernels: per-stage
            architecture hyper-parameters (see EfficientViTBlock).
        down_ops (list): Per-stage downsample spec, ['subsample', stride].
        distillation (bool): Add a second (distillation) head.
    """

    def __init__(self, img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 stages=['s', 's', 's'],
                 embed_dim=[64, 128, 192],
                 key_dim=[16, 16, 16],
                 depth=[1, 2, 3],
                 num_heads=[4, 4, 4],
                 window_size=[7, 7, 7],
                 kernels=[5, 5, 5, 5],
                 down_ops=[['subsample', 2], ['subsample', 2], ['']],
                 distillation=False):
        super().__init__()

        resolution = img_size
        # Patch embedding: four stride-2 convs (16x spatial downsample).
        self.patch_embed = torch.nn.Sequential(
            Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1, resolution=resolution),
            torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1, resolution=resolution // 2),
            torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1, resolution=resolution // 4),
            torch.nn.ReLU(),
            Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1, resolution=resolution // 8))

        resolution = img_size // patch_size
        attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i])
                      for i in range(len(embed_dim))]

        # Build the three stages into plain lists (replaces the original
        # eval('self.blocks' + str(i+1)) indirection), then wrap each in a
        # Sequential under the same attribute names so state dicts match.
        stage_blocks = [[], [], []]
        for i, (stg, ed, kd, dpth, nh, ar, wd, do) in enumerate(
                zip(stages, embed_dim, key_dim, depth, num_heads,
                    attn_ratio, window_size, down_ops)):
            for _ in range(dpth):
                stage_blocks[i].append(
                    EfficientViTBlock(stg, ed, kd, nh, ar, resolution, wd, kernels))
            if do[0] == 'subsample':
                # EfficientViT downsample block ('subsample' with stride do[1]);
                # it is appended to the *next* stage, as in the original.
                blk = stage_blocks[i + 1]
                resolution_ = (resolution - 1) // do[1] + 1
                blk.append(torch.nn.Sequential(
                    Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1,
                                       groups=embed_dim[i], resolution=resolution)),
                    Residual(FFN(embed_dim[i], int(embed_dim[i] * 2), resolution))))
                blk.append(PatchMerging(*embed_dim[i:i + 2], resolution))
                resolution = resolution_
                blk.append(torch.nn.Sequential(
                    Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1,
                                       groups=embed_dim[i + 1], resolution=resolution)),
                    Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2), resolution))))
        self.blocks1 = torch.nn.Sequential(*stage_blocks[0])
        self.blocks2 = torch.nn.Sequential(*stage_blocks[1])
        self.blocks3 = torch.nn.Sequential(*stage_blocks[2])

        # Classification head (Identity when num_classes == 0).
        self.head = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
        self.distillation = distillation
        if distillation:
            self.head_dist = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self):
        # Relative-position bias tables are excluded from weight decay.
        return {x for x in self.state_dict().keys() if 'attention_biases' in x}

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        if self.distillation:
            x = self.head(x), self.head_dist(x)
            if not self.training:
                # Average the two heads at inference time.
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x
torch.backends.cudnn.deterministic = True


def setup_seed(seed):
    """Seed every RNG source used here (torch, CUDA, numpy, Lightning)."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True


def rolling_mean_std(a, w):
    """Rolling mean and std of `a` over axis 0 with window length `w`.

    Uses prefix sums of the values and their squares, so the whole
    computation is O(T) per feature. Variance is clamped at 1e-12 before
    the square root to avoid NaNs from floating-point cancellation.

    Args:
        a: array of shape [T, F].
        w: window length in samples.

    Returns:
        (mean, std), each of shape [T - w + 1, F].
    """
    prefix = np.pad(np.cumsum(a, axis=0), ((1, 0), (0, 0)), mode="constant")
    window_sums = prefix[w:] - prefix[:-w]
    mean = window_sums / float(w)

    prefix_sq = np.pad(np.cumsum(a ** 2, axis=0), ((1, 0), (0, 0)), mode="constant")
    window_sq_sums = prefix_sq[w:] - prefix_sq[:-w]
    variance = window_sq_sums / float(w) - mean ** 2
    std = np.sqrt(np.maximum(variance, 1e-12))
    return mean, std
+ """ + # --- check input --- + if not torch.is_tensor(x): + raise ValueError("x must be a torch.Tensor") + if x.ndim != 2: + raise ValueError("x must have shape [T, F]") + + T, F = x.shape + time_idx = np.arange(T) + time_sec = time_idx / float(fps) + + arr = x.detach().cpu().numpy() + + # --- rolling mean/std --- + if win is None: + win = max(3, fps) # default = ~1 second + half = win // 2 + + roll_mean, roll_std = rolling_mean_std(arr, win) + roll_t = time_sec[half:half+len(roll_mean)] + + # ---------- 1) Time series ---------- + fig_ts, axes = plt.subplots(F, 1, figsize=(10, 2.5*F), sharex=True) + if F == 1: + axes = [axes] + + for f in range(F): + ax = axes[f] + ax.plot(time_sec, arr[:, f], alpha=0.35, linewidth=1.0, label=f'Feature {f}') + ax.plot(roll_t, roll_mean[:, f], linewidth=2.0, label=f'Rolling mean (w={win})') + ax.fill_between(roll_t, + roll_mean[:, f] - roll_std[:, f], + roll_mean[:, f] + roll_std[:, f], + alpha=0.2, label='±1 std (rolling)') + ax.set_ylabel(f'Feature {f}') + ax.grid(True, linestyle='--', alpha=0.3) + axes[-1].set_xlabel('Time (s)') + axes[0].legend(loc='upper right') + fig_ts.suptitle('Per-feature time series with rolling mean ± std', y=1.02) + fig_ts.tight_layout() + fig_ts.savefig(f"output/{out_prefix}_time_series.png", dpi=200) + + # ---------- 2) Heatmap ---------- + fig_hm, ax = plt.subplots(figsize=(10, 2.8)) + arr_min = arr.min(axis=0, keepdims=True) + arr_max = arr.max(axis=0, keepdims=True) + arr_norm = (arr - arr_min) / (arr_max - arr_min + 1e-12) + + im = ax.imshow(arr_norm.T, aspect='auto', interpolation='nearest', + extent=[time_sec[0], time_sec[-1], F-0.5, -0.5]) + ax.set_yticks(np.arange(F)) + ax.set_yticklabels([f'Feat {f}' for f in range(F)]) + ax.set_xlabel('Time (s)') + ax.set_title('Heatmap (per-feature normalized)') + fig_hm.colorbar(im, ax=ax, fraction=0.025, pad=0.02) + fig_hm.tight_layout() + fig_hm.savefig(f"output/{out_prefix}_heatmap.png", dpi=200) + + # ---------- 3) Boxplots ---------- + fig_box, ax = 
plt.subplots(figsize=(7, 3.5)) + ax.boxplot([arr[:, f] for f in range(F)], showmeans=True) + ax.set_xticklabels([f'Feat {f}' for f in range(F)]) + ax.set_ylabel('Value') + ax.set_title('Distribution across time (boxplot per feature)') + ax.grid(True, axis='y', linestyle='--', alpha=0.3) + fig_box.tight_layout() + fig_box.savefig(f"output/{out_prefix}_boxplots.png", dpi=200) + + print(f"Saved: {out_prefix}_time_series.png, {out_prefix}_heatmap.png, {out_prefix}_boxplots.png") + +def Calc_Eval_table( + TrainModel, + TestDataLoadre:DataLoader, + Youden=False, + modelName="", +): + TrainModel.to(device) + TrainModel.eval() + + Predictions = [] + Labels = [] + + with torch.no_grad(): + for data, label, mask, key_frame in tqdm(iter(TestDataLoadre)): + + data = data.to(device, dtype=torch.float32) + key_frame = key_frame.to(device, dtype=torch.float32) + mask = mask.to(device) + label = label.to(device) + + #pred, *_ = TrainModel(data) + pred, _ = TrainModel(data, key_frame, mask) + pred = pred.flatten() + + Predictions.append(torch.sigmoid(pred)) + Labels.append(label) + + Predictions = torch.cat(Predictions) + Labels = torch.cat(Labels).int() + + #print(Predictions, Labels) + + acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels) + precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels) + recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels) + auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels) + f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels) + specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels) + + table = [ + ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""] + ] + + if Youden: + for i in range(2): + aucCurve = torchmetrics.ROC("binary").to(device) + fpr, tpr, thhols = aucCurve(Predictions, Labels) + index = torch.argmax(tpr - fpr) + th2 
= (recall + specificty - 1).item() + th2 = 0.5 if th2 <= 0 else th2 + th1 = thhols[index].item() if i == 0 else th2 + accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels) + precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels) + recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels) + specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels) + f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels) + #cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device) + #cm2.update(Predictions, Labels) + #_, ax = cm2.plot() + #ax.set_title(f"NVB Classifier (th={th1:.4f})") + table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", modelName]) + + + return table + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'", required=True) + parser.add_argument("--Fold", type=int, default=0) + parser.add_argument("-lv","--Log_version", type=int, default=None) + parser.add_argument("--Workers", type=int, default=0) + parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp") + parser.add_argument("--CNN_name", type=str, default=None, ) + parser.add_argument("--Temp_Head", type=str, default=None, ) + parser.add_argument("-me", "--maxEpochs", type=int, default=None) + parser.add_argument("-b", "--Batch_size", type=int, default=8) + parser.add_argument("--GPU", type=int, default=0) + parser.add_argument("--pre_train", type=int, default=0) + parser.add_argument("-k", "--k_windows", type=int, default=1) + parser.add_argument("--Window_Size", type=int, default=64) + parser.add_argument("--Num_Window", type=int, default=8) + 
parser.add_argument("--cached_features", type=bool, default=False) + + args = parser.parse_args() + + setup_seed(2023) + device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu") + + df = pd.read_csv("../dataset/Dataset_RARP_video/dataset_videos_folds.csv") + + FOLD = args.Fold + WORKERS = args.Workers + BATCH_SIZE = args.Batch_size + MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs + PRE_TRAIN = args.pre_train != 0 + K_WIN = args.k_windows + KEY_FRAME = True + WIN_LENGTH = args.Window_Size + NUM_WIN = args.Num_Window + CACHED_FEATURES = args.cached_features + + Mean = [0.485, 0.456, 0.406] + Std = [0.229, 0.224, 0.225] + + print(f"Fold_{FOLD}") + + ckpt_paths = [ + Path("./log_XAblation_van_DINO/lightning_logs/version_0/checkpoints/RARP-epoch=20.ckpt"), + Path("./log_XAblation_van_DINO/lightning_logs/version_1/checkpoints/RARP-epoch=32.ckpt"), + Path("./log_XAblation_van_DINO/lightning_logs/version_2/checkpoints/RARP-epoch=28.ckpt"), + Path("./log_XAblation_van_DINO/lightning_logs/version_3/checkpoints/RARP-epoch=27.ckpt"), + Path("./log_XAblation_van_DINO/lightning_logs/version_4/checkpoints/RARP-epoch=30.ckpt"), + ] + + train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records") + val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records") + test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records") + + traintransformT2 = torch.nn.Sequential( + transforms.CenterCrop(300), + transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=0), + transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1), + transforms.GaussianBlur(kernel_size=3), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ).to(device) + + traintransform_frame = 
torch.nn.Sequential( + transforms.RandomApply([ + transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)), + transforms.RandomErasing(1.0, value="random") + ], 0.3) #small noise + ).to(device) + + testVal_transform = torch.nn.Sequential( + transforms.CenterCrop(300), + transforms.Resize((224, 224), antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ).to(device) + + key_frame_transform = torch.nn.Sequential( + transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.Normalize([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373]) + ).to(device) + + train_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset( + train_set, + train_mode=True, + num_windows=NUM_WIN, + window_length=WIN_LENGTH, + transform=traintransformT2, + transform_frame=traintransform_frame, + key_frames=KEY_FRAME, + key_frame_transform=key_frame_transform, + load_key_frame_cache=CACHED_FEATURES, + Fold_index=FOLD + ) + val_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset( + val_set, + train_mode=False, + num_windows=NUM_WIN, + window_length=WIN_LENGTH, + transform=testVal_transform, + key_frames=KEY_FRAME, + key_frame_transform=key_frame_transform, + load_key_frame_cache=CACHED_FEATURES, + Fold_index=FOLD + ) + test_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset( + test_set, + train_mode=False, + num_windows=NUM_WIN, + window_length=WIN_LENGTH, + transform=testVal_transform, + key_frames=KEY_FRAME, + key_frame_transform=key_frame_transform, + load_key_frame_cache=CACHED_FEATURES, + Fold_index=FOLD + ) + + train_loader = DataLoader( + train_dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + prefetch_factor=1 if WORKERS>0 else None, + #pin_memory=True, + num_workers=WORKERS, + persistent_workers=WORKERS>0 + ) + + val_loader = DataLoader( + val_dataset, + batch_size=BATCH_SIZE, 
+ shuffle=False, + #pin_memory=True, + num_workers=WORKERS, + persistent_workers=WORKERS>0 + ) + + test_loader = DataLoader( + test_dataset, + batch_size=BATCH_SIZE, + shuffle=False, + #pin_memory=True, + num_workers=WORKERS, + persistent_workers=WORKERS>0 + ) + + LogFileName = f"{args.Log_Name}" + + checkPtCallback = [ + callbk.ModelCheckpoint(monitor='val_wind_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'), + #callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5) + ] + + trainer = L.Trainer( + precision="32-true" if args.CNN_name == "gsvit" else "16-mixed", + deterministic=True, + accelerator="gpu", + devices=[args.GPU], + #devices=[0, 1], strategy="ddp", + logger=TensorBoardLogger(save_dir=f"./{LogFileName}") if args.Phase == "train" else CSVLogger(save_dir=f"./{LogFileName}/Test", version=args.Log_version), + log_every_n_steps=5, + callbacks=checkPtCallback, + max_epochs=MAX_EPOCHS + ) + + match(args.Phase): + case "cache_key_frame": + from Models import RARP_NVB_DINO_MultiTask + + print (f"Load Export model for the FOLD #{FOLD}") + Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(ckpt_paths[FOLD], map_location=device) + Hybrid_TS.eval() + + namelist = ["TRAIN", "VAL", "TEST"] + + for _i, _s in enumerate([train_set, val_set, test_set]): + print (f"[{namelist[_i]} Set] of FOLD # {FOLD}") + key_frame_dataset = Loaders.RARP_Windowed_Video_MIL_Dataset( + _s, + key_frames=True, + key_frame_transform=key_frame_transform, + key_frame_only=True, + ) + key_frameloader = DataLoader( + key_frame_dataset, + batch_size=BATCH_SIZE, + shuffle=False, + pin_memory=True, + num_workers=WORKERS, + persistent_workers=WORKERS>0 + ) + + print (f"[SAVE] caching Image features and Soft lables from Expert Model in FOLD #{FOLD}") + with torch.no_grad(): + for img, case_id in tqdm(iter(key_frameloader)): + B = img.shape[0] + img = img.to(device, dtype=torch.float) + + Soft_label, _, _ = Hybrid_TS(img) + Img_features = torch.cat((Hybrid_TS.last_conv_output_S, 
Hybrid_TS.last_conv_output_T), dim=1) + Img_features = torch.nn.functional.adaptive_avg_pool2d(Img_features, (1,1)).flatten(1) + + for i in range(B): + parent_path = next((r for r in _s if r.get("case") == case_id[i]), None) + parent_path = Path(parent_path["path"]).resolve().parent + parent_path = parent_path / "chache" + parent_path.mkdir(exist_ok=True) + np.savez((parent_path / f"F{FOLD}_{case_id[i]}.npz"), soft_label=Soft_label[i].cpu().numpy(), img_features=Img_features[i].cpu().numpy()) + + print (f"[DONE] FOLD #{FOLD}") + + case "train": + Model = M.RARP_NVB_Multi_MOD_MIL( + num_classes=1, + temporal=args.Temp_Head, + cnn_name=args.CNN_name, + dropout=0.3, + lr=1e-4, #3e-4, + weight_decay=0.1, #0.05 + epochs=MAX_EPOCHS, + pre_train=PRE_TRAIN, + Hybrid_TS_weights=str(ckpt_paths[FOLD].resolve()) if not CACHED_FEATURES else None + ) + + print(f"Model Used: {type(Model).__name__}") + print("Train Phase") + trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader) + trainer.test(Model, dataloaders=test_loader, ckpt_path="best") + case "eval_all": + print("Evaluation Phase") + rows = [] + pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/") + for ckpFile in sorted(pathCkptFile.glob("*.ckpt")): + print(ckpFile.name) + #trainer.test(Model, dataloaders=test_loader, ckpt_path=ckpFile) + #Model = M.RARP_NVB_DINO_MultiTask_A5_Video.load_from_checkpoint(ckpFile) + + hp_fiel = pathCkptFile.parent / "hparams.yaml" + Model = M.RARP_NVB_Multi_MOD_MIL_TESTMode.load_from_checkpoint(ckpFile, map_location=device, hparams_file=hp_fiel) + trainer.test(Model, dataloaders=test_loader) + + #temp = Calc_Eval_table(Model, test_loader, True, ckpFile.name) + temp = Model._test_results + rows += temp + + df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"]) + #df.style.highlight_max(color="red", axis=0) + output_file = Path(f"./{LogFileName}/output.xlsx") + if not 
output_file.exists(): + df.to_excel(output_file, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}") + else: + with pd.ExcelWriter(output_file, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer: + df.to_excel(writer, sheet_name=f"Fold_{FOLD}_ver_{args.Log_version}") + print("[END] File saved ... ") \ No newline at end of file diff --git a/Models_video.py b/Models_video.py index 58140f9..3e78bf4 100644 --- a/Models_video.py +++ b/Models_video.py @@ -5,7 +5,6 @@ import torch.utils.checkpoint as torch_ckp import torchvision import torchmetrics -import torchmetrics.classification import lightning as L import van import numpy as np @@ -13,6 +12,8 @@ from collections import defaultdict from Models import RARP_NVB_DINO_MultiTask from pathlib import Path +import pandas as pd +from EfficientViT.GSViT import EfficientViT_GSViT @@ -238,6 +239,10 @@ feature_dim = backbone.head.in_features backbone.head = nn.Identity() self.layers_to_unfreeze = ["block3", "block4"] + case "gsvit": + backbone = EfficientViT_GSViT(str(Path("./EfficientViT/EfficientViT_GSViT.pth").resolve())) + feature_dim = 384 + self.layers_to_unfreeze = [] case _: raise NotImplementedError(f"CNN name = '{cnn_name}' is not implemented yet") @@ -530,22 +535,22 @@ self.T = t def forward(self, z_video, z_key): - p_video = torch.sigmoid(z_video / self.T) - p_key = torch.sigmoid(z_key / self.T).detach() - loss = nn.functional.binary_cross_entropy(p_video, p_key) + z_video = z_video.float() + z_key = z_key.float() + + # teacher probabilities (no gradient) + with torch.no_grad(): + p_key = torch.sigmoid(z_key / self.T) # [B] in [0,1] + + # stable BCE with logits; equivalent to BCE(sigmoid(z_video/T), p_key) + loss = nn.functional.binary_cross_entropy_with_logits(z_video / self.T, p_key) return self.Lambda * (self.T ** 2) * loss class RARP_NVB_Multi_MOD(RARP_NVB_Wind_video): def _unfreeze_last_n_layers(self, model:nn.Module): - # collect layer names - #all_layers = [name for name, _ in 
model.named_children()] - # last n layers - #layers_to_unfreeze = all_layers[-self.num_layers_cnn:] - # freeze everything - for p in model.parameters(): p.requires_grad = False @@ -572,12 +577,15 @@ ): super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn) - assert len(Hybrid_TS_weights) > 0, "The Key frame model require pre-trained weigths" - - self.Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(Path(Hybrid_TS_weights), map_location=self.device) - self.Hybrid_TS.eval() - for p in self.Hybrid_TS.parameters(): - p.requires_grad = False + if Hybrid_TS_weights is not None: + assert len(Hybrid_TS_weights) > 0, "The Key frame model require pre-trained weigths" + + self.Hybrid_TS = RARP_NVB_DINO_MultiTask.load_from_checkpoint(Path(Hybrid_TS_weights), map_location=self.device) + self.Hybrid_TS.eval() + for p in self.Hybrid_TS.parameters(): + p.requires_grad = False + else: + self.Hybrid_TS = None self._unfreeze_last_n_layers(self.cnn) @@ -696,7 +704,310 @@ ]) return optimizer + +class WindowAttentionMIL(nn.Module): + def __init__(self, dim, att_dim=128): + super().__init__() + self.att_v = nn.Linear(dim, att_dim) + self.att_u = nn.Linear(att_dim, 1) -class RARP_NVB_Multi_MOD_A1(RARP_NVB_Multi_MOD): - def __init__(self, num_classes, temporal="gru", cnn_name="resnet18", dropout=0.2, pre_train=False, lr=0.0003, weight_decay=0.05, epochs=50, warmup_epochs=3, label_smoothing=0, frizze_cnn=True, Hybrid_TS_weights = ""): - super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn, Hybrid_TS_weights) \ No newline at end of file + def forward(self, H): + A = torch.tanh(self.att_v(H)) + logits = self.att_u(A) + + alpha = torch.softmax(logits, dim=1) + v = (alpha * H).sum(dim=1) + return v, alpha + +class AttentionEntropyRangeLoss(nn.Module): + def __init__(self, target_entropy:float, eps:float = 1e-8): + 
super().__init__()
+
+        self.H_0 = target_entropy
+        self.eps = eps
+
+    def forward(self, alpha:torch.Tensor)->torch.Tensor:
+        if alpha.dim() == 3:
+            alpha = alpha.squeeze(-1)
+
+        alpha = alpha.clamp(min=self.eps)  # clamp returns a new tensor; the result was discarded before, so alpha.log() could hit log(0) -> -inf/NaN
+
+        H = -(alpha * alpha.log()).sum(dim=1)
+        W = alpha.shape[1]
+
+        H_norm = H / torch.log(torch.tensor(W, device=alpha.device))
+
+        loss = (H_norm - self.H_0) ** 2
+
+        return loss.mean()
+
+class RARP_NVB_Multi_MOD_MIL(RARP_NVB_Multi_MOD):
+    def __init__(
+        self,
+        num_classes,
+        temporal="gru",
+        cnn_name="resnet18",
+        dropout=0.2,
+        pre_train=False,
+        lr=0.0003,
+        weight_decay=0.05,
+        epochs=50,
+        warmup_epochs=3,
+        label_smoothing=0,
+        frizze_cnn=True,
+        Hybrid_TS_weights = "",
+        attn_reg_weight:float=0.02,
+        attn_entropy_target:float=0.40,
+        attn_reg_warmup_epochs:int=5,
+        FOLD:int=None
+    ):
+        super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn, Hybrid_TS_weights)
+
+        self.win_mil_att = WindowAttentionMIL(self.mid_dim, att_dim=128)
+        self.win_pool = nn.Sequential(
+            nn.LayerNorm(self.mid_dim),
+            nn.Dropout(dropout),
+        )
+
+        self.attn_reg_weight = attn_reg_weight
+        self.attn_reg_warmup_epochs = attn_reg_warmup_epochs
+        self.attn_loss = AttentionEntropyRangeLoss(attn_entropy_target)
+
+    def _attn_reg_lambda(self) -> float:
+        # Linear warmup from 0 to attn_reg_weight over attn_reg_warmup_epochs
+        if self.attn_reg_warmup_epochs <= 0:
+            return float(self.attn_reg_weight)
+        t = min(1.0, self.current_epoch / self.attn_reg_warmup_epochs)
+        return float(self.attn_reg_weight * t)
+
+    def _frame_wise_pass(self, data:torch.Tensor, key_frame:torch.Tensor, mask:torch.Tensor):
+
+        B_N, L, C, H, W = data.shape
+
+        data = data.view(B_N * L, C, H, W)  # Flatten all window frames into one big batch
+        data = data.contiguous()
+
+        B = key_frame.shape[0]
+        n_win = B_N // B
+
+        cnn_features = self.cnn(data)  # [B_N*L, F]
+        cnn_features = cnn_features.view(B_N, L, -1)  # [B_N, L, F]
+
+        time_features =
self.temporal_head(cnn_features, mask) + + # --- FiLM --- + h_mid = self.proy_video(time_features) + + if self.Hybrid_TS is not None: + with torch.no_grad(): + pred_key_frame, _, _ = self.Hybrid_TS(key_frame) + img_features = torch.cat((self.Hybrid_TS.last_conv_output_S, self.Hybrid_TS.last_conv_output_T), dim=1) + img_features = nn.functional.adaptive_avg_pool2d(img_features, (1,1)).flatten(1) + else: + pred_key_frame = None + img_features = key_frame + + img_features = ( + img_features + .unsqueeze(1) # [B, 1, F_img] + .expand(B, n_win, img_features.size(1)) # [B, n_win, F_img] + .contiguous() + .view(B_N, -1) # [B_N, F_img] + ) + + h_film = self.film(h_mid, img_features) + + # --- Mask Pooling --- + video_features = self._mask_pooling(h_film, mask) + + video_features = self.pool(video_features) + + return video_features, pred_key_frame + + def forward(self, data:torch.Tensor, key_frame:torch.Tensor, mask:torch.Tensor): + B, N, L, C, H, W = data.shape + BM, NM, V = mask.shape + + data = data.view(B*N, L, C, H, W) # Flaten bags of windows + mask = mask.view(BM*NM, V) + data = data.contiguous() + mask = mask.contiguous() + + video_features, pred_key_frame = self._frame_wise_pass(data, key_frame, mask) #[B*N, D] + + # --- Window-wise pass --- + video_features = video_features.view(B, N, -1) + vid_emb, alpha = self.win_mil_att(video_features) + vid_emb = self.win_pool(vid_emb) + + logits = self.classifier(vid_emb) + + return logits, pred_key_frame, alpha + + def _shared_step(self, batch, val_stpe:bool=False): + + match (len(batch)): + case 5: + data, label, mask, key_frame, meta = batch + case 6: + data, label, mask, key_frame, soft_label, meta = batch + case _: + raise NotImplementedError() + + #self._log_x_stats(data, "TRAIN/x" if not val_stpe else "VAL/X") + + logits, key_frame_logits, alpha_w = self(data, key_frame, mask) + + key_frame_logits = soft_label if key_frame_logits is None else key_frame_logits + + label = label.float() + logits = logits.flatten() + + 
soft_loss = self.kd_loss(logits, key_frame_logits.flatten()) + hard_loss = self.loss(logits, label) + attn_win_loss = self._attn_reg_lambda() * self.attn_loss(alpha_w) + + total_loss = hard_loss + soft_loss + attn_win_loss + + return total_loss, label, logits, [soft_loss, alpha_w, meta["case_id"], attn_win_loss] + + def _log_x_stats(self, x, tag): + self.print( + f"{tag} mode={'train' if self.training else 'eval'} " + f"dtype={x.dtype} min={x.min().item():.3f} max={x.max().item():.3f} " + f"mean={x.mean().item():.3f} std={x.std().item():.3f}" + ) + + def training_step(self, batch, batch_idx): + loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, False) + + + + self.log("train_loss", loss, on_epoch=True) + self.log("train_soft_loss", extra_losses[0], on_epoch=True) + self.log("train_attn_loss", extra_losses[3], on_epoch=True) + self.train_acc.update(predicted_labels, true_labels) + self.log("train_acc", self.train_acc, on_epoch=True, on_step=False) + + return loss + + def validation_step(self, batch, batch_idx): + loss, true_labels, predicted_labels, extra_losses = self._shared_step(batch, True) + + val_main_loss = loss - extra_losses[3] # remove attention regularizer + + self.log("val_loss", val_main_loss, on_epoch=True, on_step=False) + self.log("val_soft_loss", extra_losses[0], on_epoch=True) + self.log("val_attn_loss", extra_losses[3], on_epoch=True) + self.log("val_total_loss", loss, on_epoch=True, on_step=False) + self.val_acc.update(predicted_labels, true_labels) + self.log("val_acc", self.val_acc, on_epoch=True, on_step=False) + + def test_step(self, batch, batch_idx): + _, true_labels, predicted_labels, _ = self._shared_step(batch, True) + + self.test_acc.update(predicted_labels, true_labels) + self.f1ScoreTest.update(predicted_labels, true_labels) + self.log("test_acc", self.test_acc, on_epoch=True, on_step=False) + self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False) + + def on_validation_epoch_end(self): + pass + 
+class RARP_NVB_Multi_MOD_MIL_TESTMode(RARP_NVB_Multi_MOD_MIL): + def __init__(self, num_classes, temporal="gru", cnn_name="resnet18", dropout=0.2, pre_train=False, lr=0.0003, weight_decay=0.05, epochs=50, warmup_epochs=3, label_smoothing=0, frizze_cnn=True, Hybrid_TS_weights="", attn_reg_weight = 0.02, attn_entropy_target = 0.4, attn_reg_warmup_epochs = 5, FOLD = None): + super().__init__(num_classes, temporal, cnn_name, dropout, pre_train, lr, weight_decay, epochs, warmup_epochs, label_smoothing, frizze_cnn, Hybrid_TS_weights, attn_reg_weight, attn_entropy_target, attn_reg_warmup_epochs, FOLD) + + self.FOLD = FOLD + self.Predictions = [] + self.Labels = [] + self._test_results = None + self.loaded_ckpt_epoch = None + + self.test_records = [] + + def on_load_checkpoint(self, checkpoint: dict): + self.loaded_ckpt_epoch = checkpoint.get("epoch", None) + + def on_test_epoch_start(self): + self.Predictions = [] + self.Labels = [] + self._test_results = None + self.test_records = [] + + def test_step(self, batch, batch_idx): + _, true_labels, predicted_labels, extra = self._shared_step(batch, True) + + probs = torch.sigmoid(predicted_labels) + B = probs.shape[0] + + self.Predictions.append(probs) + self.Labels.append(true_labels) + + for b in range(B): + rec = { + "case_id": extra[2][b].item(), + "y_true": int(true_labels[b].item()), + "y_pred": (probs[b] > 0.5).int().item(), + "prob": float(probs[b].item()), + "alpha": extra[1][b].flatten().cpu().numpy() + } + + self.test_records.append(rec) + + def on_test_epoch_end(self): + out_dir = Path(self.trainer.default_root_dir) / f"test_reports/FOLD_{self.FOLD}" + out_dir.mkdir(parents=True, exist_ok=True) + + rows = [] + for r in self.test_records: + row = { + "case_id": r["case_id"], + "y_true": r["y_true"], + "y_pred": r["y_pred"], + "prob": r["prob"], + } + # store attention weights as separate columns alpha_0...alpha_{W-1} + alpha = r["alpha"] + for i, a in enumerate(alpha): + row[f"alpha_{i:02d}"] = float(a) + 
rows.append(row) + + df = pd.DataFrame(rows) + df.to_csv(out_dir / f"mil_test_predictions_epoch{self.loaded_ckpt_epoch}.csv", index=False) + + predictions = torch.cat(self.Predictions) + labels = torch.cat(self.Labels).int() + + device = self.device + + acc = torchmetrics.Accuracy('binary').to(device)(predictions, labels) + precision = torchmetrics.Precision('binary').to(device)(predictions, labels) + recall = torchmetrics.Recall('binary').to(device)(predictions, labels) + auc = torchmetrics.AUROC('binary').to(device)(predictions, labels) + f1Score = torchmetrics.F1Score('binary').to(device)(predictions, labels) + specificty = torchmetrics.Specificity("binary").to(device)(predictions, labels) + + table = [ + ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""] + ] + + for i in range(2): + aucCurve = torchmetrics.ROC("binary").to(device) + fpr, tpr, thhols = aucCurve(predictions, labels) + index = torch.argmax(tpr - fpr) + th2 = (recall + specificty - 1).item() + th2 = 0.5 if th2 <= 0 else th2 + th1 = thhols[index].item() if i == 0 else th2 + accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(predictions, labels) + precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(predictions, labels) + recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(predictions, labels) + specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(predictions, labels) + f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(predictions, labels) + #cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device) + #cm2.update(Predictions, Labels) + #_, ax = cm2.plot() + #ax.set_title(f"NVB Classifier (th={th1:.4f})") + table.append([f"{th1:.4f}", f"{accY.item():.4f}", f"{precisionY.item():.4f}", f"{recallY.item():.4f}", f"{f1ScoreY.item():.4f}", f"{auc.item():.4f}", f"{specifictyY.item():.4f}", 
self.loaded_ckpt_epoch]) + + self._test_results = table \ No newline at end of file