# Demo-Maker / modules / posenet / models / mobilenet_v1.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


def _to_output_strided_layers(convolution_def, output_stride):
    current_stride = 1
    rate = 1
    block_id = 0
    buff = []
    for c in convolution_def:
        conv_type = c[0]
        inp = c[1]
        outp = c[2]
        stride = c[3]

        if current_stride == output_stride:
            layer_stride = 1
            layer_rate = rate
            rate *= stride
        else:
            layer_stride = stride
            layer_rate = 1
            current_stride *= stride

        buff.append({
            'block_id': block_id,
            'conv_type': conv_type,
            'inp': inp,
            'outp': outp,
            'stride': layer_stride,
            'rate': layer_rate,
            'output_stride': current_stride
        })
        block_id += 1

    return buff


def _get_padding(kernel_size, stride, dilation):
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


class InputConv(nn.Module):
    """A single 2-D convolution whose output is passed through ReLU6.

    Args:
        inp: number of input channels.
        outp: number of output channels.
        k: kernel size.
        stride: convolution stride.
        dilation: convolution dilation.
    """

    def __init__(self, inp, outp, k=3, stride=1, dilation=1):
        super(InputConv, self).__init__()
        # Padding is derived from kernel size, stride and dilation by the
        # module-level helper.
        pad = _get_padding(k, stride, dilation)
        self.conv = nn.Conv2d(
            in_channels=inp,
            out_channels=outp,
            kernel_size=k,
            stride=stride,
            padding=pad,
            dilation=dilation,
        )

    def forward(self, x):
        convolved = self.conv(x)
        return F.relu6(convolved)


class SeperableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    ReLU6 is applied after each of the two convolutions.

    Args:
        inp: number of input channels (also the depthwise group count).
        outp: number of output channels of the pointwise convolution.
        k: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        dilation: dilation of the depthwise convolution.
    """

    def __init__(self, inp, outp, k=3, stride=1, dilation=1):
        super(SeperableConv, self).__init__()
        pad = _get_padding(k, stride, dilation)
        # groups == in_channels: each input channel is filtered on its own.
        self.depthwise = nn.Conv2d(
            in_channels=inp,
            out_channels=inp,
            kernel_size=k,
            stride=stride,
            padding=pad,
            dilation=dilation,
            groups=inp,
        )
        # 1x1 convolution that mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(
            in_channels=inp, out_channels=outp, kernel_size=1, stride=1)

    def forward(self, x):
        per_channel = F.relu6(self.depthwise(x))
        return F.relu6(self.pointwise(per_channel))


# Accepted model ids mapped to a name string. Within this file only the keys
# are used: MobileNetV1.__init__ asserts that ``model_id`` is one of them.
# NOTE(review): the values look like checkpoint names consumed by a loader
# elsewhere -- confirm against the caller.
MOBILENET_V1_CHECKPOINTS = {
    50: 'mobilenet_v1_050',
    75: 'mobilenet_v1_075',
    100: 'mobilenet_v1_100',
    101: 'mobilenet_v1_101'
}

# Layer table selected by MobileNetV1 for every model id other than 50 and 75
# (so both 100 and 101 among the accepted ids).
# Each entry is (layer class, input channels, output channels, stride), as
# unpacked by _to_output_strided_layers.
MOBILE_NET_V1_100 = [
    (InputConv, 3, 32, 2),
    (SeperableConv, 32, 64, 1),
    (SeperableConv, 64, 128, 2),
    (SeperableConv, 128, 128, 1),
    (SeperableConv, 128, 256, 2),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 512, 2),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 512, 1),
    (SeperableConv, 512, 1024, 2),
    (SeperableConv, 1024, 1024, 1)
]

# Layer table selected by MobileNetV1 when model_id == 75.
# Each entry is (layer class, input channels, output channels, stride).
# Unlike MOBILE_NET_V1_100, the last two layers keep the channel count at 384
# and use stride 1, so the nominal total stride of this table is 16, not 32.
MOBILE_NET_V1_75 = [
    (InputConv, 3, 24, 2),
    (SeperableConv, 24, 48, 1),
    (SeperableConv, 48, 96, 2),
    (SeperableConv, 96, 96, 1),
    (SeperableConv, 96, 192, 2),
    (SeperableConv, 192, 192, 1),
    (SeperableConv, 192, 384, 2),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1),
    (SeperableConv, 384, 384, 1)
]

# Layer table selected by MobileNetV1 when model_id == 50.
# Each entry is (layer class, input channels, output channels, stride).
# Unlike MOBILE_NET_V1_100, the last two layers keep the channel count at 256
# and use stride 1, so the nominal total stride of this table is 16, not 32.
MOBILE_NET_V1_50 = [
    (InputConv, 3, 16, 2),
    (SeperableConv, 16, 32, 1),
    (SeperableConv, 32, 64, 2),
    (SeperableConv, 64, 64, 1),
    (SeperableConv, 64, 128, 2),
    (SeperableConv, 128, 128, 1),
    (SeperableConv, 128, 256, 2),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1),
    (SeperableConv, 256, 256, 1)
]


class MobileNetV1(nn.Module):
    """Convolutional feature stack with four 1x1 convolutional output heads.

    Args:
        model_id: one of the keys of ``MOBILENET_V1_CHECKPOINTS``; 50 and 75
            select their own layer tables, any other accepted id uses
            ``MOBILE_NET_V1_100``.
        output_stride: target stride passed to
            ``_to_output_strided_layers`` when expanding the layer table.

    ``forward`` returns ``(heatmap, offset, displacement_fwd,
    displacement_bwd)`` with 17, 34, 32 and 32 channels respectively; only
    the heatmap is passed through a sigmoid.
    """

    def __init__(self, model_id, output_stride=16):
        super(MobileNetV1, self).__init__()

        # NOTE(review): this check is an assert, so it is skipped when Python
        # runs with -O; an unknown id would then fall through to the default
        # layer table below.
        assert model_id in MOBILENET_V1_CHECKPOINTS.keys()
        self.output_stride = output_stride

        arch = MOBILE_NET_V1_100
        if model_id == 50:
            arch = MOBILE_NET_V1_50
        elif model_id == 75:
            arch = MOBILE_NET_V1_75

        conv_def = _to_output_strided_layers(arch, output_stride)

        # Build the layers in table order so that module names and parameter
        # creation order follow the block ids.
        named_layers = OrderedDict()
        for spec in conv_def:
            layer_cls = spec['conv_type']
            named_layers['conv%d' % spec['block_id']] = layer_cls(
                spec['inp'], spec['outp'], 3,
                stride=spec['stride'], dilation=spec['rate'])
        last_depth = conv_def[-1]['outp']

        self.features = nn.Sequential(named_layers)
        # All heads are 1x1 convolutions over the final feature map.
        self.heatmap = nn.Conv2d(last_depth, 17, 1, 1)
        self.offset = nn.Conv2d(last_depth, 34, 1, 1)
        self.displacement_fwd = nn.Conv2d(last_depth, 32, 1, 1)
        self.displacement_bwd = nn.Conv2d(last_depth, 32, 1, 1)

    def forward(self, x):
        feats = self.features(x)
        return (
            torch.sigmoid(self.heatmap(feats)),
            self.offset(feats),
            self.displacement_fwd(feats),
            self.displacement_bwd(feats),
        )