# RARP_server / I3D_RestNet50.py
import torch
import torchvision
import inflate
import lightning as L


def _as_pair(value):
    """Return *value* as a 2-tuple, duplicating it when it is a scalar.

    ``nn.MaxPool2d`` / ``nn.AvgPool2d`` / ``nn.AdaptiveAvgPool2d`` keep their
    size arguments exactly as passed in, so they may be plain ints (or None);
    the 3D constructors below need explicit per-axis values.
    """
    if value is None or isinstance(value, int):
        return (value, value)
    return tuple(value)


def Inflate_2D_to_3D(Layer: torch.nn.Module, framesDepth: int = 3) -> torch.nn.Module:
    """Return a 3D (video) counterpart of a single 2D layer.

    Supported layer types and how they are inflated:

    * ``Conv2d`` -> ``Conv3d`` with a ``framesDepth`` temporal kernel and
      temporal stride 1. The 2D weights are repeated along the new time axis
      and divided by ``framesDepth``. A clip whose frames are all identical
      therefore yields the 2D response (away from the temporal borders).
      Groups, dilation, padding mode and bias are carried over. The time axis
      is padded by ``framesDepth // 2`` only when the 2D conv is spatially
      padded. An unpadded conv (e.g. 1x1) shortens the clip by
      ``framesDepth - 1`` frames. String padding ('same' / 'valid') is
      forwarded as-is and so applies to all three axes.
    * ``AdaptiveAvgPool2d`` -> ``AdaptiveAvgPool3d`` that collapses time to a
      single step and keeps the layer's spatial output size.
    * ``MaxPool2d`` -> ``MaxPool3d`` whose temporal kernel/stride/padding/
      dilation reuse the first spatial value (``framesDepth`` is not used).
    * ``BatchNorm2d`` -> ``BatchNorm3d`` with the same hyper-parameters,
      affine parameters and running statistics.
    * ``AvgPool2d`` -> ``AvgPool3d`` with a ``framesDepth`` temporal kernel,
      temporal stride 1 and ``framesDepth // 2`` temporal padding.

    Any other module is returned unchanged (the very same object). The input
    layer itself is never modified.

    Args:
        Layer: The 2D module to inflate.
        framesDepth: Temporal kernel size used for convolutions and average
            pooling.

    Returns:
        The inflated module, or ``Layer`` itself for unsupported types.
    """
    if isinstance(Layer, torch.nn.Conv2d):
        if isinstance(Layer.padding, str):
            padding = Layer.padding
        else:
            # Pad time only when the 2D conv pads space (original behaviour),
            # but size the temporal padding from framesDepth so odd depths
            # preserve the clip length.
            time_padding = framesDepth // 2 if Layer.padding[0] > 0 else 0
            padding = (time_padding, Layer.padding[0], Layer.padding[1])
        conv3d = torch.nn.Conv3d(
            in_channels=Layer.in_channels,
            out_channels=Layer.out_channels,
            kernel_size=(framesDepth, Layer.kernel_size[0], Layer.kernel_size[1]),
            stride=(1, Layer.stride[0], Layer.stride[1]),
            padding=padding,
            dilation=(1, Layer.dilation[0], Layer.dilation[1]),
            # Without groups a depthwise/grouped kernel would be silently
            # broadcast by copy_ into a full (wrong) weight tensor.
            groups=Layer.groups,
            bias=Layer.bias is not None,
            padding_mode=Layer.padding_mode,
        )
        with torch.no_grad():
            # Spread each 2D kernel evenly over the frames so the summed
            # temporal response equals the original 2D response.
            conv3d.weight.copy_(
                Layer.weight.unsqueeze(2).repeat(1, 1, framesDepth, 1, 1) / framesDepth
            )
            if Layer.bias is not None:
                conv3d.bias.copy_(Layer.bias)
        return conv3d

    if isinstance(Layer, torch.nn.AdaptiveAvgPool2d):
        out_h, out_w = _as_pair(Layer.output_size)
        return torch.nn.AdaptiveAvgPool3d((1, out_h, out_w))

    if isinstance(Layer, torch.nn.MaxPool2d):
        # Normalise into locals instead of overwriting the caller's layer.
        kernel = _as_pair(Layer.kernel_size)
        stride = _as_pair(Layer.stride)
        padding = _as_pair(Layer.padding)
        dilation = _as_pair(Layer.dilation)
        return torch.nn.MaxPool3d(
            kernel_size=(kernel[0], kernel[0], kernel[1]),
            stride=(stride[0], stride[0], stride[1]),
            padding=(padding[0], padding[0], padding[1]),
            dilation=(dilation[0], dilation[0], dilation[1]),
            return_indices=Layer.return_indices,
            ceil_mode=Layer.ceil_mode,
        )

    if isinstance(Layer, torch.nn.BatchNorm2d):
        bn3d = torch.nn.BatchNorm3d(
            num_features=Layer.num_features,
            eps=Layer.eps,
            momentum=Layer.momentum,
            affine=Layer.affine,
            track_running_stats=Layer.track_running_stats,
        )
        # Parameter/buffer names and shapes are identical for the 2D and 3D
        # variants, so learned affine params and running stats transfer as-is.
        bn3d.load_state_dict(Layer.state_dict())
        return bn3d

    if isinstance(Layer, torch.nn.AvgPool2d):
        kernel = _as_pair(Layer.kernel_size)
        stride = _as_pair(Layer.stride)
        padding = _as_pair(Layer.padding)
        return torch.nn.AvgPool3d(
            kernel_size=(framesDepth, kernel[0], kernel[1]),
            stride=(1, stride[0], stride[1]),
            padding=(framesDepth // 2, padding[0], padding[1]),
            ceil_mode=Layer.ceil_mode,
            count_include_pad=Layer.count_include_pad,
        )

    return Layer

def Recurcive_InflateModel(Model: torch.nn.Module, framesDepth: int = 3) -> torch.nn.Module:
    """Inflate every leaf sub-module of *Model* from 2D to 3D, in place.

    Walks the module tree depth-first. Each leaf (a module with no children)
    is replaced on its parent by ``Inflate_2D_to_3D(leaf, framesDepth)``, and
    leaves of unsupported types come back unchanged. Modules that have
    children are never replaced themselves, only recursed into.

    NOTE: the (misspelled) name is kept as-is for existing callers.

    Args:
        Model: Module to convert; it is mutated in place.
        framesDepth: Temporal kernel size forwarded to ``Inflate_2D_to_3D``.
            Defaults to 3, the value previously hard-coded here.

    Returns:
        The same ``Model`` object, for convenient chaining.
    """
    for name, module in Model.named_children():
        # Peek for a first child instead of materialising the whole list.
        if next(module.children(), None) is not None:
            Recurcive_InflateModel(module, framesDepth)
        else:
            # Replacing an existing key keeps the _modules dict size fixed,
            # so this is safe while iterating named_children().
            setattr(Model, name, Inflate_2D_to_3D(module, framesDepth))
    return Model
            
def inflate_downsample(downsample2d, time_stride=1):
    """Inflate a 2D residual-shortcut projection into its 3D equivalent.

    Only ``downsample2d[0]`` (inflated with ``inflate.inflate_conv``) and
    ``downsample2d[1]`` (inflated with ``inflate.inflate_batch_norm``) are
    used; they are re-wrapped in a new ``nn.Sequential`` in that order.

    NOTE(review): presumably ``downsample2d`` is torchvision's
    ``Sequential(conv1x1, BatchNorm2d)`` shortcut; any further entries would
    be silently dropped — confirm with callers.

    Args:
        downsample2d: Indexable 2D shortcut module.
        time_stride: Temporal stride handed to ``inflate.inflate_conv``.

    Returns:
        ``nn.Sequential`` holding the inflated conv and batch norm.
    """
    conv2d, bn2d = downsample2d[0], downsample2d[1]
    # time_dim=1: the shortcut conv presumably gets no temporal extent.
    conv3d = inflate.inflate_conv(conv2d, time_dim=1, time_stride=time_stride, center=True)
    bn3d = inflate.inflate_batch_norm(bn2d)
    return torch.nn.Sequential(conv3d, bn3d)

class Bottleneck3d(torch.nn.Module):
    """3D counterpart of a 2D ResNet bottleneck block.

    Built from any block exposing ``conv1``-``conv3``, ``bn1``-``bn3``,
    ``downsample`` and ``stride`` (the attribute layout of torchvision's
    ``Bottleneck``). Every layer is converted with the ``inflate`` helpers
    and stored under the same attribute name as in the 2D block.
    """

    def __init__(self, bottleneck2d):
        super().__init__()

        # The spatial stride of the middle conv is reused as its temporal
        # stride (and for the shortcut), so a block that downsamples H/W also
        # strides along time.
        stride_2d = bottleneck2d.conv2.stride[0]

        self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=1, center=True)
        self.bn1 = inflate.inflate_batch_norm(bottleneck2d.bn1)

        # Only the middle conv gets a temporal extent (time_dim=3 with one
        # step of temporal padding, per the inflate arguments).
        self.conv2 = inflate.inflate_conv(
            bottleneck2d.conv2,
            time_dim=3,
            time_padding=1,
            time_stride=stride_2d,
            center=True,
        )
        self.bn2 = inflate.inflate_batch_norm(bottleneck2d.bn2)

        self.conv3 = inflate.inflate_conv(bottleneck2d.conv3, time_dim=1, center=True)
        self.bn3 = inflate.inflate_batch_norm(bottleneck2d.bn3)

        self.relu = torch.nn.ReLU(inplace=True)

        # Projection shortcut only when the 2D block had one; otherwise the
        # identity is used in forward().
        self.downsample = (
            inflate_downsample(bottleneck2d.downsample, time_stride=stride_2d)
            if bottleneck2d.downsample is not None
            else None
        )

        self.stride = bottleneck2d.stride

    def forward(self, x):
        """Run conv-bn-relu twice, conv-bn, add the shortcut, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

def inflate_reslayer(reslayer2d):
    """Convert a 2D ResNet stage into a stage of ``Bottleneck3d`` blocks.

    Args:
        reslayer2d: Iterable of 2D bottleneck blocks (e.g. a ResNet ``layerN``
            ``nn.Sequential``).

    Returns:
        ``nn.Sequential`` with one ``Bottleneck3d`` per input block, in order.
    """
    return torch.nn.Sequential(*map(Bottleneck3d, reslayer2d))
            
class I3DResNet50(torch.nn.Module):
    """I3D-style 3D ResNet-50 obtained by inflating a 2D ResNet-50.

    The stem, the four residual stages, the global pool and the classifier of
    ``RN50Model`` are each converted with the ``inflate`` helpers. The
    resulting attributes keep the 2D model's names (``conv1``, ``bn1``, ...,
    ``layer4``, ``avgpool``, ``fc``).

    NOTE(review): input is presumably ``(batch, channels, frames, H, W)``.
    The disabled smoke test at the bottom of this file feeds
    ``(4, 3, 600, 224, 224)``; confirm against the data pipeline.
    """

    def __init__(self, RN50Model: torch.nn.Module) -> None:
        super().__init__()

        # Stem: 7-frame temporal kernel on the first conv, temporal stride 2
        # on both the conv and the max-pool.
        self.conv1 = inflate.inflate_conv(
            RN50Model.conv1, time_dim=7, time_stride=2, time_padding=3, center=True
        )
        self.bn1 = inflate.inflate_batch_norm(RN50Model.bn1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = inflate.inflate_pool(
            RN50Model.maxpool, time_dim=3, time_padding=1, time_stride=2
        )

        # Residual stages, each rebuilt block-by-block as Bottleneck3d.
        self.layer1 = inflate_reslayer(RN50Model.layer1)
        self.layer2 = inflate_reslayer(RN50Model.layer2)
        self.layer3 = inflate_reslayer(RN50Model.layer3)
        self.layer4 = inflate_reslayer(RN50Model.layer4)

        # Head.
        self.avgpool = inflate.inflate_pool(RN50Model.avgpool, time_dim=1)
        # NOTE(review): inflate_linear(..., 1) presumably relies on avgpool
        # leaving a single time step, so the flattened features match the 2D
        # fc's input width — verify inside `inflate`.
        self.fc = inflate.inflate_linear(RN50Model.fc, 1)

    def forward(self, x):
        """Stem -> four residual stages -> global pool -> flatten -> fc."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
        
"""
if __name__ == "__main__":
    torch.set_float32_matmul_precision('medium')
    torch.backends.cudnn.deterministic = True
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    
    model = I3DResNet50(model).to(device) 
    model.eval()
       
      
    print("Model Test")
    
    domyTest = torch.rand(4, 3, 600, 224, 224).to(device)
    
    outPut = model(domyTest)
    
    print("out_shape", outPut.shape)
    outPut = torch.softmax(outPut, dim=1)
    print("out", outPut.argmax(dim=1))
"""