# Source: RARP / van_3d_i3d_inflate.py
# Last change: @delAguila, "Final Commit." (19 KB)
"""
VAN → VAN3D (I3D-style inflation)

- Accepts video tensors [B, C, T, H, W]
- Replaces 2D ops (Conv2d/BN2d/DWConv) with 3D counterparts
- Inflates 2D weights into 3D following I3D: replicate over time and divide by kT
- Preserves pretrained 2D VAN weights from the provided van.py

Usage
-----
from van_3d_i3d_inflate import build_van3d_from_2d
model3d = build_van3d_from_2d(
    arch="van_b2",          # one of: van_b0..van_b6 from your van.py
    pretrained=True,
    temporal_kernel=3,       # kT used in depthwise convs (temporal modeling)
    temporal_stride_stages=(1,1,1,1), # optional temporal downsampling per stage
)

x = torch.randn(2, 3, 8, 224, 224)
y = model3d(x)  # [B, num_classes]

Notes
-----
* 1x1 convs become 1x1x1 (no temporal mixing). Depthwise 3x3 becomes kT×3×3.
* LKA spatial kernels become (1×5×5) and (1×7×7 dilated). You can set kT>1 for these too if desired.
* LayerNorm is left as nn.LayerNorm(C) and is applied over the flattened (T*H*W, C), then restored.
"""

from typing import Tuple, Dict, Any
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# import the user's 2D VAN definitions
import van as van2d


# -------------------------
# Utility: weight inflation
# -------------------------

def inflate_conv_weight_2d_to_3d(w2d: torch.Tensor, kT: int) -> torch.Tensor:
    """Inflate a 2D conv kernel [out, in, kH, kW] to 3D [out, in, kT, kH, kW].

    I3D rule: the 2D kernel is replicated ``kT`` times along a new temporal
    axis and divided by ``kT``, so summing the result over the temporal axis
    gives back the 2D kernel.

    Parameters
    ----------
    w2d : torch.Tensor
        2D convolution weight of shape ``[out, in, kH, kW]``.
    kT : int
        Temporal kernel size; must be >= 1.

    Returns
    -------
    torch.Tensor
        Weight of shape ``[out, in, kT, kH, kW]``. For ``kT == 1`` this is the
        result of ``w2d.unsqueeze(2)`` (no values are copied or rescaled).

    Raises
    ------
    ValueError
        If ``w2d`` is not 4-dimensional or ``kT`` is smaller than 1.
    """
    if w2d.ndim != 4:
        raise ValueError("Expected 4D [out,in,kH,kW] weight for 2D conv")
    # kT == 0 would otherwise yield an empty kernel (and a division by zero),
    # which only fails much later when the state dict is loaded.
    if kT < 1:
        raise ValueError(f"Temporal kernel size kT must be >= 1, got {kT}")

    w3d = w2d.unsqueeze(2)  # [out,in,1,kH,kW]
    if kT == 1:
        return w3d
    # replicate over time and normalize so the summed response is preserved
    return w3d.repeat(1, 1, kT, 1, 1) / float(kT)


def copy_bn2d_to_bn3d(bn3d: nn.BatchNorm3d, state2d: Dict[str, torch.Tensor], prefix2d: str):
    """Copy the BatchNorm tensors stored under ``prefix2d`` into ``bn3d`` in place.

    ``state2d`` must contain ``<prefix2d>.weight``, ``.bias``, ``.running_mean``
    and ``.running_var`` (a missing one raises ``KeyError``);
    ``<prefix2d>.num_batches_tracked`` is copied only when it is present.
    """
    # affine parameters (gamma, beta) first, then the running statistics
    for field in ("weight", "bias", "running_mean", "running_var"):
        getattr(bn3d, field).data.copy_(state2d[f"{prefix2d}.{field}"])

    # num_batches_tracked may or may not exist in the source state
    counter_key = f"{prefix2d}.num_batches_tracked"
    if counter_key in state2d:
        bn3d.num_batches_tracked.data.copy_(state2d[counter_key])  # type: ignore


# -------------------------
# 3D building blocks
# -------------------------

class DWConv3D(nn.Module):
    """Depthwise 3D convolution with a ``kT x 3 x 3`` kernel.

    The convolution is stored as ``self.op`` (``groups=dim``, stride 1,
    padding ``(kT // 2, 1, 1)``, with bias). With an odd ``kT`` the output has
    the same ``(T, H, W)`` as the input; an even ``kT`` yields ``T + 1`` frames
    because of the ``kT // 2`` padding.

    Parameters
    ----------
    dim : int
        Number of input (and output) channels.
    kT : int
        Temporal kernel size, default 3. Can be changed later through
        :meth:`set_temporal_kernel`.
    """

    def __init__(self, dim: int, kT: int = 3):
        super().__init__()
        self.kT = kT
        self.op = self._build_conv(dim, kT)

    @staticmethod
    def _build_conv(dim: int, kT: int) -> nn.Conv3d:
        """Create the depthwise ``kT x 3 x 3`` convolution for ``dim`` channels."""
        if kT < 1:
            raise ValueError(f"Temporal kernel size kT must be >= 1, got {kT}")
        return nn.Conv3d(
            dim,
            dim,
            kernel_size=(kT, 3, 3),
            stride=1,
            padding=(kT // 2, 1, 1),
            groups=dim,
            bias=True,
        )

    def set_temporal_kernel(self, kT: int):
        """Replace ``self.op`` with a freshly initialized conv of temporal size ``kT``.

        The previous weights are discarded. The new convolution is placed on
        the same device and uses the same dtype as the one it replaces, so the
        method is safe to call after the module has been moved or cast.

        Raises
        ------
        ValueError
            If ``kT`` is smaller than 1.
        """
        previous = self.op
        rebuilt = self._build_conv(previous.in_channels, kT)
        # a plain nn.Conv3d(...) would land on the default device/dtype
        self.op = rebuilt.to(device=previous.weight.device, dtype=previous.weight.dtype)
        self.kT = kT

    def forward(self, x):
        return self.op(x)


class Mlp3D(nn.Module):
    """Convolutional MLP for 5D inputs.

    Pipeline: 1x1x1 conv (``fc1``) -> depthwise conv (``dwconv``) ->
    activation -> dropout -> 1x1x1 conv (``fc2``) -> dropout.
    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are falsy.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        # pointwise expansion, depthwise mixing, pointwise projection
        self.fc1 = nn.Conv3d(in_features, hidden, kernel_size=1)
        self.dwconv = DWConv3D(hidden)
        self.act = act_layer()
        self.fc2 = nn.Conv3d(hidden, out, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one sub-module; module types not listed are left untouched."""
        if isinstance(m, nn.Conv3d):
            # normal init with std sqrt(2 / fan_out), fan_out counted per group
            kt, kh, kw = m.kernel_size
            fan_out = (kt * kh * kw * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Linear):
            van2d.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.dwconv(self.fc1(x))
        x = self.drop(self.act(x))
        return self.drop(self.fc2(x))


class LKA3D(nn.Module):
    """Large Kernel Attention for 5D inputs.

    Two depthwise convolutions (``kT x 5 x 5``, then ``kT x 7 x 7`` with
    spatial dilation 3) followed by a 1x1x1 convolution produce a map that
    multiplies the input element-wise. ``kT`` defaults to 1, i.e. no temporal
    mixing inside the attention; pass ``kT > 1`` to mix across time.
    """

    def __init__(self, dim: int, kT: int = 1):
        super().__init__()
        self.kT = kT
        t_pad = kT // 2
        self.conv0 = nn.Conv3d(
            dim, dim,
            kernel_size=(kT, 5, 5),
            padding=(t_pad, 2, 2),
            groups=dim,
        )
        # dilation 3 on a 7x7 kernel -> padding 9 keeps the spatial size
        self.conv_spatial = nn.Conv3d(
            dim, dim,
            kernel_size=(kT, 7, 7),
            stride=1,
            padding=(t_pad, 9, 9),
            dilation=(1, 3, 3),
            groups=dim,
        )
        self.conv1 = nn.Conv3d(dim, dim, kernel_size=1)

    def forward(self, x):
        gate = self.conv1(self.conv_spatial(self.conv0(x)))
        return x * gate


class Attention3D(nn.Module):
    """Attention sub-block: 1x1x1 conv -> GELU -> LKA3D -> 1x1x1 conv, plus a skip connection.

    ``lka_kT`` is forwarded to ``LKA3D`` as its temporal kernel size.
    """

    def __init__(self, d_model: int, lka_kT: int = 1):
        super().__init__()
        self.proj_1 = nn.Conv3d(d_model, d_model, kernel_size=1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LKA3D(d_model, kT=lka_kT)
        self.proj_2 = nn.Conv3d(d_model, d_model, kernel_size=1)

    def forward(self, x):
        gated = self.spatial_gating_unit(self.activation(self.proj_1(x)))
        # residual around the whole attention path
        return self.proj_2(gated) + x


class Block3D(nn.Module):
    """One VAN block for 5D inputs.

    Two residual branches are applied in sequence, each scaled per channel by
    a learnable layer-scale vector (initialized to 1e-2) and passed through
    ``drop_path``:

    1. ``BatchNorm3d`` -> ``Attention3D`` (temporal kernel ``lka_kT``)
    2. ``BatchNorm3d`` -> ``Mlp3D`` whose depthwise conv uses temporal kernel ``dw_kT``
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, lka_kT: int = 1, dw_kT: int = 3):
        super().__init__()
        # attention branch
        self.norm1 = nn.BatchNorm3d(dim)
        self.attn = Attention3D(dim, lka_kT)
        self.drop_path = van2d.DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP branch
        self.norm2 = nn.BatchNorm3d(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp3D(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer, drop=drop)
        # the MLP builds its depthwise conv with a default temporal kernel;
        # rebuild it with the requested one
        self.mlp.dwconv.set_temporal_kernel(dw_kT)

        init_scale = 1e-2
        self.layer_scale_1 = nn.Parameter(init_scale * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_scale * torch.ones((dim)), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one sub-module; module types not listed are left untouched."""
        if isinstance(m, nn.Conv3d):
            # normal init with std sqrt(2 / fan_out), fan_out counted per group
            kt, kh, kw = m.kernel_size
            fan_out = (kt * kh * kw * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Linear):
            van2d.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # x: (B, C, T, H, W); layer-scale vectors broadcast over B, T, H, W
        attn_scale = self.layer_scale_1.view(1, -1, 1, 1, 1)
        x = x + self.drop_path(attn_scale * self.attn(self.norm1(x)))

        mlp_scale = self.layer_scale_2.view(1, -1, 1, 1, 1)
        x = x + self.drop_path(mlp_scale * self.mlp(self.norm2(x)))
        return x


class OverlapPatchEmbed3D(nn.Module):
    """Overlapping patch embedding for 5D inputs.

    A ``Conv3d`` with kernel ``(1, p, p)``, stride ``(t_stride, stride, stride)``
    and padding ``(0, p // 2, p // 2)`` followed by ``BatchNorm3d``. The
    temporal kernel is 1, so frames are only subsampled (by ``t_stride``),
    never mixed. ``img_size`` and ``t_size`` are accepted but not used here.
    """

    def __init__(self, img_size=224, t_size=8, patch_size=7, stride=4, t_stride=1, in_chans=3, embed_dim=768):
        super().__init__()
        patch_hw = van2d.to_2tuple(patch_size)
        patch_h, patch_w = patch_hw[0], patch_hw[1]
        self.proj = nn.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=(1, patch_h, patch_w),  # no temporal mixing here
            stride=(t_stride, stride, stride),
            padding=(0, patch_h // 2, patch_w // 2),
        )
        self.norm = nn.BatchNorm3d(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one sub-module; module types not listed are left untouched."""
        if isinstance(m, nn.Conv3d):
            # normal init with std sqrt(2 / fan_out), fan_out counted per group
            kt, kh, kw = m.kernel_size
            fan_out = (kt * kh * kw * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Linear):
            van2d.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Embed ``x`` of shape (B, C, T, H, W); return the normalized tensor and its T, H, W."""
        embedded = self.proj(x)
        T, H, W = embedded.shape[2:]
        return self.norm(embedded), T, H, W


class VAN3D(nn.Module):
    """VAN classifier operating on video tensors of shape (B, C, T, H, W).

    The network has ``num_stages`` stages. Stage ``i`` (1-based) consists of
    ``patch_embed{i}`` (``OverlapPatchEmbed3D``), ``block{i}`` (a
    ``ModuleList`` of ``Block3D``) and ``norm{i}`` (``norm_layer`` applied over
    the channel dimension). After the last stage the tokens are averaged over
    T*H*W and fed to ``head``.

    Parameters
    ----------
    img_size, t_size :
        Forwarded to the patch embeddings.
    in_chans : int
        Number of input channels of the first stage.
    num_classes : int
        Output size of the linear head; ``<= 0`` replaces the head with
        ``nn.Identity``.
    embed_dims, mlp_ratios, depths :
        Per-stage channel width, MLP expansion ratio and number of blocks.
        Each needs at least ``num_stages`` entries.
    drop_rate : float
        Dropout probability inside the MLPs.
    drop_path_rate : float
        Maximum drop-path rate; per-block rates grow linearly from 0 to this
        value over all ``sum(depths)`` blocks.
    norm_layer :
        Constructor of the per-stage normalization, called with the stage width.
    num_stages : int
        Number of stages to build.
    temporal_stride_stages : sequence of int
        Temporal stride of each stage's patch embedding.
    lka_kT, dw_kT : int
        Temporal kernel sizes of the LKA convolutions and of the MLP depthwise
        convolution in every block.
    flag : bool
        When true, the ``num_classes`` attribute is not stored on the module.
    """

    def __init__(
        self,
        img_size=224,
        t_size=8,
        in_chans=3,
        num_classes=1000,
        embed_dims=[64, 128, 256, 512],
        mlp_ratios=[4, 4, 4, 4],
        drop_rate=0.,
        drop_path_rate=0.,
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3],
        num_stages=4,
        temporal_stride_stages=(1, 1, 1, 1),
        lka_kT: int = 1,
        dw_kT: int = 3,
        flag: bool = False,
    ):
        super().__init__()
        if not flag:
            self.num_classes = num_classes
        # copy so the (shared) default list is never exposed through the attribute
        self.depths = list(depths)
        self.num_stages = num_stages

        # stochastic-depth rate for every block, increasing with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0

        for i in range(num_stages):
            patch_embed = OverlapPatchEmbed3D(
                img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                t_size=t_size,
                patch_size=7 if i == 0 else 3,
                stride=4 if i == 0 else 2,
                t_stride=temporal_stride_stages[i],
                in_chans=in_chans if i == 0 else embed_dims[i - 1],
                embed_dim=embed_dims[i],
            )

            block = nn.ModuleList([
                Block3D(
                    dim=embed_dims[i],
                    mlp_ratio=mlp_ratios[i],
                    drop=drop_rate,
                    drop_path=dpr[cur + j],
                    lka_kT=lka_kT,
                    dw_kT=dw_kT,
                )
                for j in range(depths[i])
            ])
            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            # names mirror the 2D model so state-dict keys line up
            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        # The head consumes the output of the LAST stage; a fixed index 3 is
        # only correct when num_stages == 4.
        head_dim = embed_dims[num_stages - 1]
        self.head = nn.Linear(head_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one sub-module; module types not listed are left untouched."""
        if isinstance(m, nn.Linear):
            van2d.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, (nn.Conv3d,)):
            # normal init with std sqrt(2 / fan_out), fan_out counted per group
            k = m.kernel_size
            fan_out = k[0] * k[1] * k[2] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward_features(self, x):
        """Run all stages on ``x`` (B, C, T, H, W) and return pooled features (B, C_last)."""
        B = x.shape[0]
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, T, H, W = patch_embed(x)  # (B,C,T,H,W)
            for blk in block:
                x = blk(x)
            # the stage norm acts on the last dim, so move C last: (B, THW, C)
            x = x.permute(0, 2, 3, 4, 1).contiguous()       # (B,T,H,W,C)
            x = x.view(B, T * H * W, -1)
            x = norm(x)
            if i != self.num_stages - 1:
                # restore (B,C,T,H,W) for the next stage's patch embedding
                x = x.view(B, T, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
        # final token average: global mean over T,H,W
        x = x.mean(dim=1)  # (B, C)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


# ------------------------------------------------
# Builder: create VAN3D and inflate from 2D weights
# ------------------------------------------------

def _get_2d_model(arch: str, pretrained: bool, num_classes: int) -> nn.Module:
    """Look up the constructor named ``arch`` in the 2D ``van`` module and call it.

    ``pretrained`` and ``num_classes`` are forwarded as keyword arguments. An
    unknown ``arch`` surfaces as the ``AttributeError`` raised by ``getattr``.
    """
    build_2d = getattr(van2d, arch)
    return build_2d(pretrained=pretrained, num_classes=num_classes)


def _infer_mlp_ratio(stage_blocks, dim: int, default=4):
    """Return the MLP expansion ratio of a 2D stage, or ``default`` if it cannot be read.

    The ratio is taken from the first block as ``mlp.fc1.out_channels / dim``.
    An integer is returned when the hidden width is an exact multiple of ``dim``.
    """
    try:
        hidden = stage_blocks[0].mlp.fc1.out_channels
    except (AttributeError, IndexError, TypeError):
        # unexpected 2D layout: keep the previous fixed ratio
        return default
    return hidden // dim if hidden % dim == 0 else hidden / dim


def _make_3d_from_2d_cfg(m2d: nn.Module, temporal_stride_stages, lka_kT, dw_kT) -> VAN3D:
    """Create an (uninflated) ``VAN3D`` whose layout mirrors the 2D model ``m2d``.

    Read from ``m2d``: ``depths``, ``num_stages``, the width of every stage
    (``patch_embed{i}.norm.num_features``), the MLP ratio of every stage
    (``block{i}[0].mlp.fc1``), the input channel count
    (``patch_embed1.proj.in_channels``) and ``num_classes`` (default 1000).
    The MLP ratio has to match the 2D model, otherwise the inflated MLP
    weights do not fit the 3D layers. ``img_size`` and ``t_size`` are fixed to
    224 and 8; drop rates are 0.
    """
    depths = m2d.depths
    num_stages = m2d.num_stages

    embed_dims = []
    mlp_ratios = []
    for i in range(num_stages):
        pe: nn.Module = getattr(m2d, f"patch_embed{i+1}")
        dim = pe.norm.num_features
        embed_dims.append(dim)
        mlp_ratios.append(_infer_mlp_ratio(getattr(m2d, f"block{i+1}", None), dim))

    # fall back to RGB when the first projection cannot be inspected
    first_proj = getattr(getattr(m2d, "patch_embed1", None), "proj", None)
    in_chans = getattr(first_proj, "in_channels", 3)

    return VAN3D(
        img_size=224,
        t_size=8,
        in_chans=in_chans,
        num_classes=getattr(m2d, "num_classes", 1000),
        embed_dims=embed_dims,
        mlp_ratios=mlp_ratios,
        drop_rate=0.,
        drop_path_rate=0.,
        depths=depths,
        num_stages=num_stages,
        temporal_stride_stages=temporal_stride_stages,
        lka_kT=lka_kT,
        dw_kT=dw_kT,
    )


def _inflate_and_load(m3d: VAN3D, m2d_state: Dict[str, torch.Tensor], dw_kT: int, lka_kT: int):
    """Load 2D weights into ``m3d``, inflating conv kernels along time.

    For every entry of ``m3d.state_dict()`` the 2D tensor of the same name is
    looked up in ``m2d_state``:

    * 5D targets (Conv3d weights) are filled with the 2D kernel inflated to the
      target's own temporal size (``target.shape[2]``);
    * every other target (biases, BatchNorm/LayerNorm tensors, layer scales,
      classifier head) is copied when the 2D tensor has the same shape.

    Entries without a usable 2D source keep their current values and are
    reported in the returned ``missing`` list.

    Parameters
    ----------
    m3d : VAN3D
        Model to fill in place.
    m2d_state : dict
        ``state_dict()`` of the 2D model.
    dw_kT, lka_kT : int
        Kept for backward compatibility; the temporal kernel size is read from
        the shape of each 3D weight instead.

    Returns
    -------
    (missing, unexpected)
        Key lists reported by ``load_state_dict(..., strict=False)``.

    Raises
    ------
    KeyError
        If a 3D conv weight has no 2D counterpart.
    RuntimeError
        Raised by ``load_state_dict`` if an inflated kernel does not fit the
        3D layer.
    """
    sd3d = m3d.state_dict()
    new_sd: Dict[str, torch.Tensor] = {}

    def candidate_2d_names(name: str):
        """List the 2D names that may hold the counterpart of a 3D entry."""
        names = [name]
        if ".dwconv.op." in name:
            # DWConv3D stores its conv as `.dwconv.op`; the 2D model is
            # expected to store it as `.dwconv`.
            names.append(name.replace(".dwconv.op.", ".dwconv."))
            # NOTE(review): some 2D VAN definitions wrap the conv in a module
            # that names it `dwconv` again (`.dwconv.dwconv`) -- confirm
            # against van.py.
            names.append(name.replace(".dwconv.op.", ".dwconv.dwconv."))
        return names

    def find_2d(name: str):
        """Return the first matching 2D tensor, or None."""
        for candidate in candidate_2d_names(name):
            if candidate in m2d_state:
                return m2d_state[candidate]
        return None

    for k3d, v3d in sd3d.items():
        src = find_2d(k3d)

        if v3d.ndim == 5:
            # Conv3d weight: a pretrained counterpart is required
            if src is None:
                raise KeyError(k3d)
            new_sd[k3d] = inflate_conv_weight_2d_to_3d(src.detach(), kT=v3d.shape[2])
            continue

        # everything else is a plain copy; shape mismatches (e.g. a head with
        # a different number of classes) are left untouched
        if src is not None and src.shape == v3d.shape:
            new_sd[k3d] = src

    # Load only what was inflated so that `missing` lists the entries that
    # did NOT receive pretrained values.
    missing, unexpected = m3d.load_state_dict(new_sd, strict=False)
    return missing, unexpected


def build_van3d_from_2d(
    arch: str = "van_b2",
    pretrained: bool = True,
    num_classes: int = 1000,
    temporal_kernel: int = 3,
    temporal_stride_stages: Tuple[int, int, int, int] = (1, 1, 1, 1),
    lka_kT: int = 1,
) -> VAN3D:
    """
    Build VAN3D and inflate from a 2D VAN checkpoint (from your van.py URLs).

    The 2D model is always created with 1000 classes; its weights are inflated
    into a 3D model with the same layout. The lists of state-dict keys reported
    by the loading step are attached to the result as
    ``missing_keys_from_inflation`` and ``unexpected_keys_from_inflation``.

    Parameters
    ----------
    arch : str
        Name of a constructor in van.py, e.g. one of van_b0..van_b6.
    pretrained : bool
        Whether to download/load 2D ImageNet weights before inflation.
    num_classes : int
        Number of classes for the 3D head. If different from 1000, the inflated
        head is replaced by a freshly initialized one (``nn.Identity`` when
        ``num_classes <= 0``).
    temporal_kernel : int
        kT used in the depthwise kT×3×3 convs of the MLPs (temporal modeling).
        Must be a positive odd integer.
    temporal_stride_stages : tuple
        Temporal stride (downsampling) to apply per stage in patch embedding.
    lka_kT : int
        Temporal kernel for LKA convs (default 1 = spatial-only). Must be a
        positive odd integer.

    Raises
    ------
    ValueError
        If ``temporal_kernel`` or ``lka_kT`` is not a positive odd integer.
    """
    # The 3D convs pad time by kT // 2. With an even kT the output has one more
    # frame than the input, which breaks the residual additions in every block.
    for label, k in (("temporal_kernel", temporal_kernel), ("lka_kT", lka_kT)):
        if k < 1 or k % 2 == 0:
            raise ValueError(f"{label} must be a positive odd integer, got {k}")

    m2d = _get_2d_model(arch, pretrained=pretrained, num_classes=1000)
    # create 3D counterpart mirroring depths/embed_dims
    m3d = _make_3d_from_2d_cfg(m2d, temporal_stride_stages, lka_kT=lka_kT, dw_kT=temporal_kernel)

    # inflate and load
    state2d = m2d.state_dict()
    missing, unexpected = _inflate_and_load(m3d, state2d, dw_kT=temporal_kernel, lka_kT=lka_kT)

    # reset / resize classifier if needed
    if num_classes != 1000 and isinstance(m3d.head, nn.Linear):
        in_f = m3d.head.in_features
        # same convention as VAN3D: no classifier when num_classes <= 0
        m3d.head = nn.Linear(in_f, num_classes) if num_classes > 0 else nn.Identity()
        # keep the bookkeeping attribute in sync with the new head
        m3d.num_classes = num_classes

    # helpful attributes
    m3d.missing_keys_from_inflation = missing
    m3d.unexpected_keys_from_inflation = unexpected
    return m3d


if __name__ == "__main__":
    # Smoke test: build van_b2 without pretrained weights, push one random
    # 4-frame 224x224 clip through it and print the output shape.
    smoke_model = build_van3d_from_2d("van_b2", pretrained=False, temporal_kernel=3)
    clip = torch.randn(1, 3, 4, 224, 224)
    logits = smoke_model(clip)
    print(logits.shape)