"""
VAN → VAN3D (I3D-style inflation)
- Accepts video tensors [B, C, T, H, W]
- Replaces 2D ops (Conv2d/BN2d/DWConv) with 3D counterparts
- Inflates 2D weights into 3D following I3D: replicate over time and divide by kT
- Preserves pretrained 2D VAN weights from the provided van.py
Usage
-----
from van3d_i3d_inflate import build_van3d_from_2d
model3d = build_van3d_from_2d(
arch="van_b2", # one of: van_b0..van_b6 from your van.py
pretrained=True,
temporal_kernel=3, # kT used in depthwise convs (temporal modeling)
temporal_stride_stages=(1,1,1,1), # optional temporal downsampling per stage
)
x = torch.randn(2, 3, 8, 224, 224)
y = model3d(x) # [B, num_classes]
Notes
-----
* 1x1 convs become 1x1x1 (no temporal mixing). Depthwise 3x3 becomes kT×3×3.
* LKA spatial kernels become (1×5×5) and (1×7×7 dilated). You can set kT>1 for these too if desired.
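* The per-stage LayerNorm stays nn.LayerNorm(C): features are flattened to (B, T*H*W, C), normalized over C, then reshaped back to (B, C, T, H, W) (after the last stage the tokens are averaged instead).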
"""
from typing import Tuple, Dict, Any
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# import the user's 2D VAN definitions
import van as van2d
# -------------------------
# Utility: weight inflation
# -------------------------
def inflate_conv_weight_2d_to_3d(w2d: torch.Tensor, kT: int) -> torch.Tensor:
    """Inflate a 2D conv kernel [out, in, kH, kW] to 3D [out, in, kT, kH, kW].

    I3D rule: the 2D kernel is replicated ``kT`` times along a new temporal
    axis and divided by ``kT``, so summing the inflated kernel over time
    gives back the original 2D kernel.

    Args:
        w2d: 4D weight tensor of a 2D convolution.
        kT: temporal kernel size; must be >= 1.

    Returns:
        The 5D kernel. For ``kT == 1`` this is a view of ``w2d`` with a
        singleton temporal axis (no copy); otherwise it is a new tensor.

    Raises:
        ValueError: if ``w2d`` is not 4D or ``kT`` is smaller than 1.
    """
    if w2d.ndim != 4:
        raise ValueError("Expected 4D [out,in,kH,kW] weight for 2D conv")
    if kT < 1:
        # kT == 0 would otherwise silently produce an empty kernel.
        raise ValueError(f"Temporal kernel size kT must be >= 1, got {kT}")
    w3d = w2d.unsqueeze(2)  # [out,in,1,kH,kW]
    if kT == 1:
        return w3d
    # replicate over time and normalize so the temporal sum equals w2d
    return w3d.repeat(1, 1, kT, 1, 1) / float(kT)
def copy_bn2d_to_bn3d(bn3d: nn.BatchNorm3d, state2d: Dict[str, torch.Tensor], prefix2d: str):
    """Copy batch-norm parameters and running statistics from a 2D state dict into ``bn3d``.

    Args:
        bn3d: destination module; its tensors are overwritten in place.
        state2d: state dict holding the 2D entries.
        prefix2d: key prefix of the 2D batch norm inside ``state2d``
            (without the trailing dot).

    Raises:
        KeyError: if ``weight``, ``bias``, ``running_mean`` or ``running_var``
            is absent under ``prefix2d``.
    """
    # gamma, beta and the running statistics are always read from state2d
    for field in ("weight", "bias", "running_mean", "running_var"):
        source = state2d[f"{prefix2d}.{field}"]
        getattr(bn3d, field).data.copy_(source)
    # the step counter is copied only when the 2D state dict carries it
    counter_key = f"{prefix2d}.num_batches_tracked"
    if counter_key in state2d:
        bn3d.num_batches_tracked.data.copy_(state2d[counter_key])  # type: ignore
# -------------------------
# 3D building blocks
# -------------------------
class DWConv3D(nn.Module):
    """Depthwise 3D convolution with a ``kT x 3 x 3`` kernel.

    The convolution is stored as ``self.op`` (``groups == dim``, stride 1,
    padding ``(kT // 2, 1, 1)``), so for odd ``kT`` the output keeps the
    input's T, H and W.

    Args:
        dim: number of input (and output) channels.
        kT: temporal kernel size. Defaults to 3.
    """

    def __init__(self, dim: int, kT: int = 3):
        super().__init__()
        self.kT = kT
        self.op = self._make_conv(dim, kT)

    @staticmethod
    def _make_conv(dim: int, kT: int) -> nn.Conv3d:
        """Build the depthwise ``kT x 3 x 3`` convolution for ``dim`` channels."""
        return nn.Conv3d(
            dim,
            dim,
            kernel_size=(kT, 3, 3),
            stride=1,
            padding=(kT // 2, 1, 1),
            groups=dim,
            bias=True,
        )

    def set_temporal_kernel(self, kT: int):
        """Replace ``self.op`` with a freshly initialised conv of temporal size ``kT``.

        The previous weights are discarded (the kernel shape changes), but
        the new conv is placed on the same device and dtype as the old one
        so a module that was already moved or cast stays consistent.
        """
        previous = self.op
        self.kT = kT
        rebuilt = self._make_conv(previous.in_channels, kT)
        self.op = rebuilt.to(device=previous.weight.device, dtype=previous.weight.dtype)

    def forward(self, x):
        return self.op(x)
class Mlp3D(nn.Module):
    """Convolutional feed-forward branch operating on (B, C, T, H, W) tensors.

    Pipeline: 1x1x1 conv -> depthwise conv -> activation -> dropout ->
    1x1x1 conv -> dropout. ``hidden_features`` and ``out_features`` fall
    back to ``in_features`` when left unset (or falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_hidden = hidden_features or in_features
        width_out = out_features or in_features
        # pointwise expansion, depthwise mixing, pointwise projection
        self.fc1 = nn.Conv3d(in_features, width_hidden, kernel_size=1)
        self.dwconv = DWConv3D(width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Conv3d(width_hidden, width_out, kernel_size=1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; used as the callback for ``self.apply``."""
        if isinstance(m, nn.Linear):
            van2d.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv3d):
            # normal init with variance 2 / fan_out, fan_out counted per group
            k_t, k_h, k_w = m.kernel_size
            fan_out = k_t * k_h * k_w * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        hidden = self.drop(self.act(self.dwconv(self.fc1(x))))
        return self.drop(self.fc2(hidden))
class LKA3D(nn.Module):
    """Large Kernel Attention for (B, C, T, H, W) tensors.

    A depthwise ``kT x 5 x 5`` conv, a depthwise ``kT x 7 x 7`` conv with
    spatial dilation 3, and a 1x1x1 conv produce a gate that multiplies the
    input element-wise. With the default ``kT == 1`` the gate is computed
    from spatial context only; pass ``kT > 1`` to include temporal context.
    """

    def __init__(self, dim: int, kT: int = 1):
        super().__init__()
        self.kT = kT
        t_pad = kT // 2
        self.conv0 = nn.Conv3d(
            dim, dim, kernel_size=(kT, 5, 5), padding=(t_pad, 2, 2), groups=dim
        )
        # spatial padding 9 matches the 7-tap kernel at dilation 3
        self.conv_spatial = nn.Conv3d(
            dim,
            dim,
            kernel_size=(kT, 7, 7),
            stride=1,
            padding=(t_pad, 9, 9),
            groups=dim,
            dilation=(1, 3, 3),
        )
        self.conv1 = nn.Conv3d(dim, dim, kernel_size=1)

    def forward(self, x):
        gate = self.conv1(self.conv_spatial(self.conv0(x)))
        return x * gate
class Attention3D(nn.Module):
    """Attention branch: 1x1x1 conv -> GELU -> LKA3D gating -> 1x1x1 conv, plus a residual.

    Args:
        d_model: channel count, unchanged by the module.
        lka_kT: temporal kernel size forwarded to ``LKA3D``.
    """

    def __init__(self, d_model: int, lka_kT: int = 1):
        super().__init__()
        self.proj_1 = nn.Conv3d(d_model, d_model, kernel_size=1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LKA3D(d_model, kT=lka_kT)
        self.proj_2 = nn.Conv3d(d_model, d_model, kernel_size=1)

    def forward(self, x):
        # x is kept untouched so it can serve as the skip connection
        gated = self.spatial_gating_unit(self.activation(self.proj_1(x)))
        return self.proj_2(gated) + x
class Block3D(nn.Module):
    """One VAN3D block: BatchNorm3d -> attention and BatchNorm3d -> MLP, both residual.

    Each branch output is multiplied by a learnable per-channel scale
    (initialised to 1e-2) and passed through ``drop_path`` before being
    added back to the input.

    Args:
        dim: channel count.
        mlp_ratio: hidden width of the MLP is ``int(dim * mlp_ratio)``.
        drop: dropout probability inside the MLP.
        drop_path: rate handed to ``van2d.DropPath``; values <= 0 use ``nn.Identity``.
        act_layer: activation class used by the MLP.
        lka_kT: temporal kernel size of the LKA convs in the attention branch.
        dw_kT: temporal kernel size of the depthwise conv inside the MLP.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, lka_kT: int = 1, dw_kT: int = 3):
        super().__init__()
        self.norm1 = nn.BatchNorm3d(dim)
        self.attn = Attention3D(dim, lka_kT)
        if drop_path > 0.:
            self.drop_path = van2d.DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.BatchNorm3d(dim)
        self.mlp = Mlp3D(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        # Mlp3D builds its depthwise conv with a default kernel; resize it here
        self.mlp.dwconv.set_temporal_kernel(dw_kT)
        scale_init = 1e-2
        self.layer_scale_1 = nn.Parameter(scale_init * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(scale_init * torch.ones((dim)), requires_grad=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; used as the callback for ``self.apply``."""
        if isinstance(m, nn.Linear):
            van2d.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv3d):
            # normal init with variance 2 / fan_out, fan_out counted per group
            k_t, k_h, k_w = m.kernel_size
            fan_out = k_t * k_h * k_w * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # x: (B, C, T, H, W); the scales broadcast over batch, time and space
        attn_scale = self.layer_scale_1.view(1, -1, 1, 1, 1)
        x = x + self.drop_path(attn_scale * self.attn(self.norm1(x)))
        mlp_scale = self.layer_scale_2.view(1, -1, 1, 1, 1)
        x = x + self.drop_path(mlp_scale * self.mlp(self.norm2(x)))
        return x
class OverlapPatchEmbed3D(nn.Module):
    """Overlapping patch embedding for video: strided Conv3d followed by BatchNorm3d.

    The projection kernel is ``1 x pH x pW`` (no temporal mixing) with
    stride ``(t_stride, stride, stride)`` and spatial padding of half the
    patch size. ``img_size`` and ``t_size`` are accepted but not used here.
    """

    def __init__(self, img_size=224, t_size=8, patch_size=7, stride=4, t_stride=1, in_chans=3, embed_dim=768):
        super().__init__()
        spatial = van2d.to_2tuple(patch_size)
        patch_h, patch_w = spatial[0], spatial[1]
        self.proj = nn.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=(1, patch_h, patch_w),  # temporal extent 1: frames are embedded independently
            stride=(t_stride, stride, stride),
            padding=(0, patch_h // 2, patch_w // 2),
        )
        self.norm = nn.BatchNorm3d(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; used as the callback for ``self.apply``."""
        if isinstance(m, nn.Linear):
            van2d.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv3d):
            # normal init with variance 2 / fan_out, fan_out counted per group
            k_t, k_h, k_w = m.kernel_size
            fan_out = k_t * k_h * k_w * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Embed ``x`` of shape (B, C, T, H, W); return the normalised tensor and its T, H, W."""
        embedded = self.proj(x)
        T, H, W = embedded.shape[2:]
        return self.norm(embedded), T, H, W
class VAN3D(nn.Module):
    """Video VAN backbone + classifier operating on (B, C, T, H, W) tensors.

    Each of the ``num_stages`` stages is an ``OverlapPatchEmbed3D``, a list of
    ``Block3D`` modules and a ``norm_layer`` applied over the channel axis.
    They are registered as ``patch_embed{i}``, ``block{i}`` and ``norm{i}``
    (1-based). After the last stage the tokens are averaged over T*H*W and
    fed to ``head``.

    Args:
        img_size, t_size: forwarded to the patch embeddings.
        in_chans: channels of the input video.
        num_classes: classifier outputs; values <= 0 make ``head`` an ``nn.Identity``.
        embed_dims: channel width per stage.
        mlp_ratios: MLP expansion ratio per stage.
        drop_rate: dropout probability inside the blocks' MLPs.
        drop_path_rate: final drop-path rate; per-block rates are spaced
            linearly from 0 to this value over ``sum(depths)`` blocks.
        norm_layer: constructor of the per-stage norm, called with the channel width.
        depths: number of blocks per stage.
        num_stages: number of stages to build.
        temporal_stride_stages: temporal stride of each stage's patch embedding.
        lka_kT: temporal kernel size of the LKA convs.
        dw_kT: temporal kernel size of the MLP depthwise convs.
        flag: when false, ``num_classes`` is stored on the module.

    Raises:
        ValueError: if ``num_stages < 1`` or a per-stage sequence has fewer
            than ``num_stages`` entries.
    """

    def __init__(
        self,
        img_size=224,
        t_size=8,
        in_chans=3,
        num_classes=1000,
        embed_dims=(64, 128, 256, 512),
        mlp_ratios=(4, 4, 4, 4),
        drop_rate=0.,
        drop_path_rate=0.,
        norm_layer=nn.LayerNorm,
        depths=(3, 4, 6, 3),
        num_stages=4,
        temporal_stride_stages=(1, 1, 1, 1),
        lka_kT: int = 1,
        dw_kT: int = 3,
        flag: bool = False,
    ):
        super().__init__()
        if num_stages < 1:
            raise ValueError(f"num_stages must be >= 1, got {num_stages}")
        per_stage = {
            "embed_dims": embed_dims,
            "mlp_ratios": mlp_ratios,
            "depths": depths,
            "temporal_stride_stages": temporal_stride_stages,
        }
        for arg_name, values in per_stage.items():
            if len(values) < num_stages:
                raise ValueError(
                    f"{arg_name} needs at least {num_stages} entries (one per stage), got {len(values)}"
                )

        if not flag:
            self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        # linearly increasing stochastic-depth rate, one value per block
        dpr = [r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(num_stages):
            patch_embed = OverlapPatchEmbed3D(
                img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                t_size=t_size,
                patch_size=7 if i == 0 else 3,
                stride=4 if i == 0 else 2,
                t_stride=temporal_stride_stages[i],
                in_chans=in_chans if i == 0 else embed_dims[i - 1],
                embed_dim=embed_dims[i],
            )
            block = nn.ModuleList([
                Block3D(
                    dim=embed_dims[i],
                    mlp_ratio=mlp_ratios[i],
                    drop=drop_rate,
                    drop_path=dpr[cur + j],
                    lka_kT=lka_kT,
                    dw_kT=dw_kT,
                )
                for j in range(depths[i])
            ])
            norm = norm_layer(embed_dims[i])
            cur += depths[i]
            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        # The head consumes the features of the LAST BUILT stage, which is
        # embed_dims[num_stages - 1] (not necessarily index 3).
        head_dim = embed_dims[num_stages - 1]
        self.head = nn.Linear(head_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; used as the callback for ``self.apply``."""
        if isinstance(m, nn.Linear):
            van2d.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            # normal init with variance 2 / fan_out, fan_out counted per group
            k = m.kernel_size
            fan_out = k[0] * k[1] * k[2] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward_features(self, x):
        """Run all stages on ``x`` (B, C, T, H, W) and return pooled features (B, C_last)."""
        B = x.shape[0]
        last_stage = self.num_stages - 1
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, T, H, W = patch_embed(x)  # (B, C, T, H, W)
            for blk in block:
                x = blk(x)
            # the stage norm acts on the last axis, so move C last and flatten to (B, THW, C)
            x = x.permute(0, 2, 3, 4, 1).contiguous()  # (B, T, H, W, C)
            x = x.view(B, T * H * W, -1)
            x = norm(x)
            if i != last_stage:
                # back to channels-first for the next stage's Conv3d
                x = x.view(B, T, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
        # after the last stage x is still (B, THW, C): average over all tokens
        return x.mean(dim=1)  # (B, C)

    def forward(self, x):
        """Return ``head(forward_features(x))``."""
        return self.head(self.forward_features(x))
# ------------------------------------------------
# Builder: create VAN3D and inflate from 2D weights
# ------------------------------------------------
def _get_2d_model(arch: str, pretrained: bool, num_classes: int) -> nn.Module:
    """Look up the constructor named ``arch`` in ``van2d`` and call it.

    ``pretrained`` and ``num_classes`` are forwarded as keyword arguments;
    an unknown ``arch`` surfaces as the ``AttributeError`` from the lookup.
    """
    build_fn = getattr(van2d, arch)
    return build_fn(pretrained=pretrained, num_classes=num_classes)
def _infer_mlp_ratio(stage_blocks, dim: int, default=4):
    """Read the MLP expansion ratio (hidden width / ``dim``) from the first block of a 2D stage.

    Returns ``default`` when ``stage_blocks[0].mlp.fc1.weight`` cannot be
    reached (attribute path assumed to mirror the 3D blocks in this file).
    """
    try:
        hidden = int(stage_blocks[0].mlp.fc1.weight.shape[0])
    except (AttributeError, IndexError, KeyError, TypeError):
        return default
    if hidden % dim == 0:
        return hidden // dim
    # Block3D computes int(dim * ratio); nudging by half a channel keeps
    # float rounding from truncating the result to hidden - 1.
    return (hidden + 0.5) / dim


def _make_3d_from_2d_cfg(m2d: nn.Module, temporal_stride_stages, lka_kT, dw_kT) -> VAN3D:
    """Build an untrained ``VAN3D`` whose layout mirrors the 2D model ``m2d``.

    Read from ``m2d``: ``num_stages``, ``depths``, the per-stage width
    (``patch_embed{i}.norm.num_features``), the per-stage MLP ratio (from the
    first block's ``mlp.fc1`` weight, falling back to 4), the input channel
    count (``patch_embed1.proj.in_channels``, falling back to 3) and
    ``num_classes`` (falling back to 1000). Inferring the MLP ratio instead of
    assuming 4 keeps the 3D MLP shapes identical to the 2D ones, which the
    subsequent weight inflation relies on.

    Args:
        m2d: source 2D model.
        temporal_stride_stages: temporal stride per stage for the 3D patch embeddings.
        lka_kT: temporal kernel size of the LKA convs.
        dw_kT: temporal kernel size of the MLP depthwise convs.
    """
    num_stages = m2d.num_stages
    depths = m2d.depths
    embed_dims = []
    mlp_ratios = []
    for i in range(num_stages):
        pe: nn.Module = getattr(m2d, f"patch_embed{i+1}")
        dim = pe.norm.num_features
        embed_dims.append(dim)
        mlp_ratios.append(_infer_mlp_ratio(getattr(m2d, f"block{i+1}", None), dim))

    first_proj = getattr(getattr(m2d, "patch_embed1", None), "proj", None)
    in_chans = getattr(first_proj, "in_channels", 3)

    return VAN3D(
        img_size=224,
        t_size=8,
        in_chans=in_chans,
        num_classes=getattr(m2d, "num_classes", 1000),
        embed_dims=embed_dims,
        mlp_ratios=mlp_ratios,
        drop_rate=0.,
        drop_path_rate=0.,
        depths=depths,
        num_stages=num_stages,
        temporal_stride_stages=temporal_stride_stages,
        lka_kT=lka_kT,
        dw_kT=dw_kT,
    )
def _inflate_and_load(m3d: VAN3D, m2d_state: Dict[str, torch.Tensor], dw_kT: int, lka_kT: int):
    """Load 2D VAN weights into ``m3d``, inflating conv kernels to 3D where needed.

    Every entry of ``m3d.state_dict()`` is matched to a 2D entry by name: the
    3D key itself is tried first, then the key with ``.dwconv.op.`` rewritten
    to ``.dwconv.`` (the 3D depthwise conv is wrapped one level deeper). A 4D
    source feeding a 5D target is inflated with the I3D rule, using the
    target's own temporal extent. All other tensors (biases, BatchNorm
    parameters and statistics, the top-level per-stage norms, layer scales,
    the classifier head) are copied as-is. Entries whose shape still differs
    from the target are skipped, so the load is best-effort and never raises
    on a partial match.

    Args:
        m3d: destination model; modified in place.
        m2d_state: state dict of the 2D model.
        dw_kT, lka_kT: kept for interface compatibility. The temporal kernel
            size is read from each target tensor instead, so it always agrees
            with how ``m3d`` was actually built.

    Returns:
        ``(missing, unexpected)``: the 3D keys that received no 2D weight
        (absent in ``m2d_state`` or shape mismatch) and the 2D keys that were
        not consumed.
    """
    sd3d = m3d.state_dict()
    new_sd = {}
    missing = []
    consumed_2d = set()

    for k3d, v3d in sd3d.items():
        candidates = (k3d, k3d.replace(".dwconv.op.", ".dwconv."))
        k2d = next((name for name in candidates if name in m2d_state), None)
        if k2d is None:
            missing.append(k3d)
            continue

        source = m2d_state[k2d].detach()
        if v3d.ndim == 5 and source.ndim == 4:
            # conv kernel: add the temporal axis sized like the target's
            source = inflate_conv_weight_2d_to_3d(source, kT=v3d.shape[2])
        if source.shape != v3d.shape:
            # e.g. a head with a different class count; keep the 3D init
            missing.append(k3d)
            continue

        new_sd[k3d] = source
        consumed_2d.add(k2d)

    # Entries not overridden keep their current 3D values, so the merged dict
    # matches the model's keys exactly and load_state_dict cannot fail.
    m3d.load_state_dict({**sd3d, **new_sd}, strict=False)
    unexpected = [name for name in m2d_state if name not in consumed_2d]
    return missing, unexpected
def build_van3d_from_2d(
    arch: str = "van_b2",
    pretrained: bool = True,
    num_classes: int = 1000,
    temporal_kernel: int = 3,
    temporal_stride_stages: Tuple[int, int, int, int] = (1, 1, 1, 1),
    lka_kT: int = 1,
) -> VAN3D:
    """
    Build VAN3D and inflate from a 2D VAN checkpoint (from your van.py URLs).

    Parameters
    ----------
    arch : str
        Name of a constructor in van.py (e.g. one of van_b0..van_b6).
    pretrained : bool
        Forwarded to the 2D constructor (whether to load 2D weights before inflation).
    num_classes : int
        Number of classes for the 3D head. The 2D model is always built with
        1000 classes; if this differs from 1000 the inflated head is replaced
        by a freshly initialised one (``nn.Identity`` when ``num_classes <= 0``).
    temporal_kernel : int
        kT used in the depthwise 3×3 convs (temporal modeling). Must be a
        positive odd integer.
    temporal_stride_stages : tuple
        Temporal stride (downsampling) to apply per stage in patch embedding.
    lka_kT : int
        Temporal kernel for LKA convs (default 1 = spatial-only). Must be a
        positive odd integer.

    Returns
    -------
    VAN3D
        The model, with ``missing_keys_from_inflation`` and
        ``unexpected_keys_from_inflation`` attached for inspection.

    Raises
    ------
    ValueError
        If ``temporal_kernel`` or ``lka_kT`` is not a positive odd integer.
    """
    # The convs pad by kT // 2 at stride 1, which only preserves T for odd kT;
    # an even kernel would break the residual and gating paths at forward time.
    for label, k in (("temporal_kernel", temporal_kernel), ("lka_kT", lka_kT)):
        if k < 1 or k % 2 == 0:
            raise ValueError(f"{label} must be a positive odd integer, got {k}")

    m2d = _get_2d_model(arch, pretrained=pretrained, num_classes=1000)
    # create 3D counterpart mirroring depths/embed_dims
    m3d = _make_3d_from_2d_cfg(m2d, temporal_stride_stages, lka_kT=lka_kT, dw_kT=temporal_kernel)
    # inflate and load
    missing, unexpected = _inflate_and_load(
        m3d, m2d.state_dict(), dw_kT=temporal_kernel, lka_kT=lka_kT
    )

    # reset / resize classifier if needed
    if num_classes != 1000 and isinstance(m3d.head, nn.Linear):
        if num_classes > 0:
            new_head = nn.Linear(m3d.head.in_features, num_classes)
            # same init scheme VAN3D applies to its own head
            new_head.apply(m3d._init_weights)
        else:
            # mirror VAN3D: no classifier when there are no classes
            new_head = nn.Identity()
        m3d.head = new_head
        m3d.num_classes = num_classes

    # helpful attributes
    m3d.missing_keys_from_inflation = missing
    m3d.unexpected_keys_from_inflation = unexpected
    return m3d
if __name__ == "__main__":
    # Smoke test: build van_b2 without downloading weights, push one random
    # 4-frame 224x224 clip through it and print the output shape.
    net = build_van3d_from_2d("van_b2", pretrained=False, temporal_kernel=3)
    clip = torch.randn(1, 3, 4, 224, 224)
    logits = net(clip)
    print(logits.shape)