import torch
import torchvision
import inflate
def _as_pair(value):
    """Return *value* as a 2-tuple, duplicating it when it is a single int."""
    if isinstance(value, int):
        return (value, value)
    return tuple(value)


def _require_positive_depth(framesDepth):
    """Reject temporal kernel sizes that cannot form a valid 3D kernel."""
    if framesDepth < 1:
        raise ValueError(f"framesDepth must be >= 1, got {framesDepth}")


def _inflate_conv2d(conv2d, framesDepth):
    """Build a Conv3d whose kernel is the 2D kernel repeated along time."""
    _require_positive_depth(framesDepth)
    if isinstance(conv2d.padding, str):
        # 'same' / 'valid' are understood by Conv3d as well and cover all dims.
        padding = conv2d.padding
    else:
        # Unpadded 2D convs stay unpadded in time; padded ones get the
        # half-kernel padding (which is 1 for the default depth of 3).
        time_padding = framesDepth // 2 if conv2d.padding[0] > 0 else 0
        padding = (time_padding, conv2d.padding[0], conv2d.padding[1])

    conv3d = torch.nn.Conv3d(
        in_channels=conv2d.in_channels,
        out_channels=conv2d.out_channels,
        kernel_size=(framesDepth, conv2d.kernel_size[0], conv2d.kernel_size[1]),
        stride=(1, conv2d.stride[0], conv2d.stride[1]),
        padding=padding,
        dilation=(1, conv2d.dilation[0], conv2d.dilation[1]),
        # Without groups the 3D weight has a different shape than the 2D one
        # and the copy below would fail (or silently broadcast when depthwise).
        groups=conv2d.groups,
        bias=conv2d.bias is not None,
        padding_mode=conv2d.padding_mode,
        device=conv2d.weight.device,
        dtype=conv2d.weight.dtype,
    )
    with torch.no_grad():
        # Dividing by the depth keeps the response to a time-constant clip
        # equal to the 2D response on a single frame.
        scaled = conv2d.weight.unsqueeze(2) / framesDepth
        conv3d.weight.copy_(scaled.expand_as(conv3d.weight))
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
    return conv3d


def _inflate_adaptive_avgpool2d(pool2d):
    """Build an AdaptiveAvgPool3d that keeps the 2D output size and collapses time to 1."""
    size = pool2d.output_size
    if size is None or isinstance(size, int):
        height = width = size
    else:
        height, width = size
    return torch.nn.AdaptiveAvgPool3d((1, height, width))


def _inflate_maxpool2d(pool2d):
    """Build a MaxPool3d whose temporal axis reuses the height settings of the 2D pool."""
    # Work on local copies: the source layer must not be modified.
    kernel = _as_pair(pool2d.kernel_size)
    stride = _as_pair(pool2d.stride)
    padding = _as_pair(pool2d.padding)
    dilation = _as_pair(pool2d.dilation)
    return torch.nn.MaxPool3d(
        kernel_size=(kernel[0], kernel[0], kernel[1]),
        stride=(stride[0], stride[0], stride[1]),
        padding=(padding[0], padding[0], padding[1]),
        dilation=(1, dilation[0], dilation[1]),
        return_indices=pool2d.return_indices,
        ceil_mode=pool2d.ceil_mode,
    )


def _inflate_batchnorm2d(bn2d):
    """Build a BatchNorm3d with the same hyper-parameters, parameters and running stats."""
    bn3d = torch.nn.BatchNorm3d(
        num_features=bn2d.num_features,
        eps=bn2d.eps,
        momentum=bn2d.momentum,
        affine=bn2d.affine,
        track_running_stats=bn2d.track_running_stats,
    )
    state = bn2d.state_dict()
    if state:
        # Module.to only casts floating-point tensors, so integer buffers such
        # as num_batches_tracked keep their dtype.
        reference = next(iter(state.values()))
        bn3d.to(device=reference.device, dtype=reference.dtype)
    # Per-channel tensors have the same shape in 2D and 3D batch norm.
    bn3d.load_state_dict(state)
    bn3d.train(bn2d.training)
    return bn3d


def _inflate_avgpool2d(pool2d, framesDepth):
    """Build an AvgPool3d with a temporal window of *framesDepth* and temporal stride 1."""
    _require_positive_depth(framesDepth)
    # AvgPool2d stores its arguments as given, so they may be plain ints.
    kernel = _as_pair(pool2d.kernel_size)
    stride = _as_pair(pool2d.stride)
    padding = _as_pair(pool2d.padding)
    divisor = pool2d.divisor_override
    if divisor is not None:
        # The 3D window holds framesDepth times more elements than the 2D one.
        divisor = divisor * framesDepth
    return torch.nn.AvgPool3d(
        kernel_size=(framesDepth, kernel[0], kernel[1]),
        stride=(1, stride[0], stride[1]),
        # Half-kernel padding; equals 1 for the default depth of 3.
        padding=(framesDepth // 2, padding[0], padding[1]),
        ceil_mode=pool2d.ceil_mode,
        count_include_pad=pool2d.count_include_pad,
        divisor_override=divisor,
    )


def Inflate_2D_to_3D(Layer: torch.nn.Module, framesDepth: int = 3) -> torch.nn.Module:
    """Return the 3D counterpart of a 2D layer.

    Supported layer types are Conv2d, AdaptiveAvgPool2d, MaxPool2d,
    BatchNorm2d and AvgPool2d. Any other layer is returned unchanged
    (the very same object). The input layer is never modified.

    Args:
        Layer: The 2D layer to convert.
        framesDepth: Temporal kernel size used for Conv2d and AvgPool2d.
            MaxPool2d reuses its height kernel/stride/padding for the
            temporal axis and ignores this value.

    Returns:
        A new 3D layer, or ``Layer`` itself when its type is not supported.

    Raises:
        ValueError: If ``framesDepth`` is smaller than 1 for a Conv2d or
            AvgPool2d layer.
    """
    if isinstance(Layer, torch.nn.Conv2d):
        return _inflate_conv2d(Layer, framesDepth)
    if isinstance(Layer, torch.nn.AdaptiveAvgPool2d):
        return _inflate_adaptive_avgpool2d(Layer)
    if isinstance(Layer, torch.nn.MaxPool2d):
        return _inflate_maxpool2d(Layer)
    if isinstance(Layer, torch.nn.BatchNorm2d):
        return _inflate_batchnorm2d(Layer)
    if isinstance(Layer, torch.nn.AvgPool2d):
        return _inflate_avgpool2d(Layer, framesDepth)
    return Layer
def Recurcive_InflateModel(Model: torch.nn.Module, framesDepth: int = 3) -> None:
    """Replace, in place, every leaf sub-module of *Model* with its inflated version.

    The module tree is walked depth-first. A child that has children of its
    own is recursed into; a child without children is passed to
    ``Inflate_2D_to_3D`` and the result is stored back under the same name.

    ``Model`` itself is never replaced, so a model that is a single leaf
    layer (no children) is left untouched.

    Args:
        Model: Root module whose descendants are converted.
        framesDepth: Temporal kernel size forwarded to ``Inflate_2D_to_3D``.
    """
    # Snapshot the children first: setattr below rewrites the module registry
    # that named_children() iterates over.
    for name, module in list(Model.named_children()):
        has_children = next(module.children(), None) is not None
        if has_children:
            Recurcive_InflateModel(module, framesDepth)
        else:
            setattr(Model, name, Inflate_2D_to_3D(module, framesDepth))
def inflate_downsample(downsample2d, time_stride=1):
    """Convert a 2D shortcut branch into a 3D ``Sequential``.

    Element 0 of ``downsample2d`` is passed to ``inflate.inflate_conv`` with
    ``time_dim=1``, the given ``time_stride`` and ``center=True``; element 1
    is passed to ``inflate.inflate_batch_norm``. Presumably these are the
    convolution and batch-norm of a ResNet downsample branch -- confirm
    against the ``inflate`` module.

    Args:
        downsample2d: Indexable container holding at least two modules.
        time_stride: Forwarded to ``inflate.inflate_conv``.

    Returns:
        A ``torch.nn.Sequential`` of the two converted modules, in order.
    """
    conv3d = inflate.inflate_conv(
        downsample2d[0],
        time_dim=1,
        time_stride=time_stride,
        center=True,
    )
    bn3d = inflate.inflate_batch_norm(downsample2d[1])
    return torch.nn.Sequential(conv3d, bn3d)
class Bottleneck3d(torch.nn.Module):
    """3D bottleneck residual block built from an existing 2D bottleneck.

    Every convolution and batch-norm of the 2D block is converted through the
    ``inflate`` helpers. Only ``conv2`` is requested with ``time_dim=3`` (and
    ``time_padding=1``); ``conv1`` and ``conv3`` use ``time_dim=1``. The
    ``time_stride`` of ``conv2`` and of the shortcut branch is set to the
    first stride value of the 2D ``conv2``.
    """

    def __init__(self, bottleneck2d):
        """Convert the sub-modules of ``bottleneck2d``.

        Args:
            bottleneck2d: Block exposing ``conv1..3``, ``bn1..3``,
                ``downsample`` and ``stride`` attributes (presumably a
                torchvision ``Bottleneck`` -- confirm against callers).
        """
        super().__init__()

        # The time stride mirrors the stride of the 2D conv2.
        stride_hw = bottleneck2d.conv2.stride[0]

        self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=1, center=True)
        self.bn1 = inflate.inflate_batch_norm(bottleneck2d.bn1)

        self.conv2 = inflate.inflate_conv(
            bottleneck2d.conv2,
            time_dim=3,
            time_padding=1,
            time_stride=stride_hw,
            center=True,
        )
        self.bn2 = inflate.inflate_batch_norm(bottleneck2d.bn2)

        self.conv3 = inflate.inflate_conv(bottleneck2d.conv3, time_dim=1, center=True)
        self.bn3 = inflate.inflate_batch_norm(bottleneck2d.bn3)

        self.relu = torch.nn.ReLU(inplace=True)

        shortcut2d = bottleneck2d.downsample
        self.downsample = (
            None
            if shortcut2d is None
            else inflate_downsample(shortcut2d, time_stride=stride_hw)
        )
        self.stride = bottleneck2d.stride

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn-relu, conv-bn, add the shortcut, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Identity shortcut unless a downsample branch was built.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
def inflate_reslayer(reslayer2d):
    """Wrap every block of a 2D residual stage in ``Bottleneck3d``.

    Args:
        reslayer2d: Iterable of 2D bottleneck blocks.

    Returns:
        A ``torch.nn.Sequential`` holding one ``Bottleneck3d`` per input
        block, in the original order.
    """
    return torch.nn.Sequential(*[Bottleneck3d(block2d) for block2d in reslayer2d])
class I3DResNet50(torch.nn.Module):
    """3D network assembled by converting the parts of a 2D ResNet-style model.

    The stem (``conv1``, ``bn1``, ``maxpool``), the four residual stages, the
    average pool and the final linear layer of the given model are each
    converted through the ``inflate`` helpers or ``inflate_reslayer``.
    """

    def __init__(self, RN50Model: torch.nn.Module, *args, **kwargs) -> None:
        """Convert the sub-modules of ``RN50Model``.

        Args:
            RN50Model: Model exposing ``conv1``, ``bn1``, ``maxpool``,
                ``layer1..4``, ``avgpool`` and ``fc`` (presumably a
                torchvision ResNet-50 -- confirm against callers).
            *args: Forwarded to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Stem
        self.conv1 = inflate.inflate_conv(
            RN50Model.conv1,
            time_dim=7,
            time_stride=2,
            time_padding=3,
            center=True,
        )
        self.bn1 = inflate.inflate_batch_norm(RN50Model.bn1)
        self.relu = torch.nn.ReLU(True)
        self.maxpool = inflate.inflate_pool(
            RN50Model.maxpool,
            time_dim=3,
            time_padding=1,
            time_stride=2,
        )

        # Residual stages
        self.layer1 = inflate_reslayer(RN50Model.layer1)
        self.layer2 = inflate_reslayer(RN50Model.layer2)
        self.layer3 = inflate_reslayer(RN50Model.layer3)
        self.layer4 = inflate_reslayer(RN50Model.layer4)

        # Head
        self.avgpool = inflate.inflate_pool(RN50Model.avgpool, time_dim=1)
        self.fc = inflate.inflate_linear(RN50Model.fc, 1)

    def forward(self, x):
        """Run stem, stages and pool, flatten per sample, and apply ``fc``."""
        backbone = (
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.avgpool,
        )
        for stage in backbone:
            x = stage(x)

        # Collapse everything but the batch dimension before the linear layer.
        flattened = x.view(x.size(0), -1)
        return self.fc(flattened)
if __name__ == "__main__":
    # Smoke test: build the 3D model from a pretrained torchvision ResNet-50
    # and push one random batch through it.
    torch.set_float32_matmul_precision('medium')
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model = I3DResNet50(model).to(device)
    model.eval()

    print("Model Test")
    # Created directly on the target device to avoid a host-to-device copy.
    # Indexed as (8, 3, 15, 224, 224); presumably (batch, channels, frames,
    # height, width) -- confirm against the inflate helpers.
    dummy_clip = torch.rand(8, 3, 15, 224, 224, device=device)

    # Inference only: skip autograd bookkeeping so no graph is kept alive for
    # this large batch.
    with torch.no_grad():
        outPut = model(dummy_clip)
        print("out_shape", outPut.shape)
        outPut = torch.softmax(outPut, dim=1)
        print("out", outPut.argmax(dim=1))