import torch
from torch.nn import Parameter


def inflate_conv(
        conv2d,
        time_dim=3,
        time_padding=0,
        time_stride=1,
        time_dilation=1,
        center=False
    ):
    """Build a ``torch.nn.Conv3d`` whose kernel is an inflated copy of ``conv2d``.

    The spatial hyper-parameters (kernel size, padding, stride, dilation,
    groups) are taken from ``conv2d``; the temporal ones come from the
    ``time_*`` arguments.

    Note: to preserve activations on a clip made of identical frames, the
    time dimension should either be left unpadded or be padded by continuity
    (replicating frames) rather than with zeros.

    Args:
        conv2d: source ``torch.nn.Conv2d`` module.
        time_dim: kernel extent along the time axis.
        time_padding: padding along the time axis.
        time_stride: stride along the time axis.
        time_dilation: dilation along the time axis.
        center: if True, the 2D filter is written only into the middle
            temporal slice (index ``time_dim // 2``) and every other slice
            is zero. If False, the 2D filter is repeated ``time_dim`` times
            along time and divided by ``time_dim``.

    Returns:
        A new ``torch.nn.Conv3d``. Its weight is a fresh ``Parameter``; its
        bias is the very same object as ``conv2d.bias`` (shared, not copied),
        or ``None`` when ``conv2d`` has no bias.
    """
    kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
    padding = (time_padding, conv2d.padding[0], conv2d.padding[1])
    # Height and width strides are taken independently so that non-square
    # strides such as (1, 2) survive the inflation.
    stride = (time_stride, conv2d.stride[0], conv2d.stride[1])
    dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1])
    conv3d = torch.nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_dim,
        padding=padding,
        dilation=dilation,
        stride=stride,
        # Forward the grouping so the inflated weight shape
        # (out, in / groups, T, kH, kW) matches the module configuration.
        groups=conv2d.groups)

    weight_2d = conv2d.weight.data
    if center:
        # zeros_like keeps the dtype and device of the source weights.
        weight_3d = torch.zeros_like(weight_2d)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
    else:
        # Repeat the filter time_dim times along the time dimension and
        # rescale so the summed temporal response equals the 2D response.
        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        weight_3d = weight_3d / time_dim

    # Assign new params
    conv3d.weight = Parameter(weight_3d)
    conv3d.bias = conv2d.bias
    return conv3d

def inflate_linear(linear2d, time_dim):
    """Build a Linear layer that accepts ``time_dim`` times more input features.

    The 2D weight matrix is tiled ``time_dim`` times along the input-feature
    axis and divided by ``time_dim``.

    Args:
        linear2d: source ``torch.nn.Linear`` module.
        time_dim: final time dimension of the features

    Returns:
        A new ``torch.nn.Linear`` with ``in_features * time_dim`` inputs.
        Its weight is a fresh ``Parameter``; its bias is the same object as
        ``linear2d.bias`` (shared, not copied).
    """
    inflated_in_features = linear2d.in_features * time_dim
    linear3d = torch.nn.Linear(inflated_in_features, linear2d.out_features)

    # Tile along the column (input-feature) axis, then average the copies.
    tiled_weight = linear2d.weight.data.repeat(1, time_dim) / time_dim

    linear3d.weight = Parameter(tiled_weight)
    linear3d.bias = linear2d.bias
    return linear3d

def inflate_batch_norm(batch2d):
    """Patch a 2D batch-norm module in place so it accepts 3D-style input.

    Rationale kept from the original author: in pytorch 0.2.0 the 2d and 3d
    batch-norm modules work identically except for the check that verifies
    the input dimensions.

    Args:
        batch2d: batch-norm module to patch; it is modified in place.

    Returns:
        The same ``batch2d`` object, with its ``_check_input_dim`` replaced
        by the one from a ``torch.nn.BatchNorm3d``.
    """
    # A throwaway BatchNorm3d is created only to borrow its bound
    # dimension-check method.
    check_3d = torch.nn.BatchNorm3d(batch2d.num_features)._check_input_dim
    batch2d._check_input_dim = check_3d
    return batch2d

def inflate_pool2d_det(pool2d: torch.nn.MaxPool2d,
                       time_kernel=1, time_stride=1, time_pad=0):
    """Build an ``AvgPool3d`` mirroring the spatial geometry of a 2D pool.

    Args:
        pool2d: source pooling module; its ``kernel_size``, ``stride``,
            ``padding`` and ``ceil_mode`` are read.
        time_kernel: kernel extent along the time axis.
        time_stride: stride along the time axis.
        time_pad: padding along the time axis.

    Returns:
        A ``torch.nn.AvgPool3d`` with ``count_include_pad=False``.

    NOTE: the result is an *average* pool even though the annotated input is
    a max pool, and AvgPool3d has no dilation argument, so any dilation set
    on ``pool2d`` is ignored here.
    """
    def _as_pair(value):
        # Scalars stand for a square setting; tuples are used as given.
        if isinstance(value, tuple):
            return value
        return (value, value)

    kernel_h, kernel_w = _as_pair(pool2d.kernel_size)
    stride_h, stride_w = _as_pair(pool2d.stride)
    pad_h, pad_w = _as_pair(pool2d.padding)

    kernel_3d = (time_kernel, kernel_h, kernel_w)
    stride_3d = (time_stride, stride_h, stride_w)
    padding_3d = (time_pad, pad_h, pad_w)
    return torch.nn.AvgPool3d(
        kernel_size=kernel_3d,
        stride=stride_3d,
        padding=padding_3d,
        ceil_mode=pool2d.ceil_mode,
        count_include_pad=False,
    )

def inflate_pool(
    pool2d,
    time_dim=1,
    time_padding=0,
    time_stride=None,
    time_dilation=1
):
    """Build the 3D counterpart of a 2D pooling module.

    Args:
        pool2d: a ``torch.nn.AdaptiveAvgPool2d``, ``torch.nn.MaxPool2d`` or
            ``torch.nn.AvgPool2d``. Spatial settings may be ints or
            (height, width) pairs.
        time_dim: kernel extent along the time axis.
        time_padding: padding along the time axis.
        time_stride: stride along the time axis; defaults to ``time_dim``
            when None.
        time_dilation: dilation along the time axis (max pooling only).

    Returns:
        ``AdaptiveAvgPool3d((1, 1, 1))`` for an adaptive average pool
        (the 2D module's own output size is not consulted), otherwise a
        ``MaxPool3d`` / ``AvgPool3d`` carrying over the 2D module's spatial
        kernel, stride, padding and ``ceil_mode`` (plus dilation for max
        pooling and ``count_include_pad`` for average pooling).

    Raises:
        ValueError: if ``pool2d`` is not one of the supported classes.
    """
    def _pair(value):
        # Pooling modules keep each setting either as a single int or as a
        # (height, width) sequence; normalise both forms to a 2-tuple.
        if isinstance(value, (tuple, list)):
            height, width = value
            return (height, width)
        return (value, value)

    if isinstance(pool2d, torch.nn.AdaptiveAvgPool2d):
        return torch.nn.AdaptiveAvgPool3d((1, 1, 1))

    # Reject unsupported modules before touching their attributes so the
    # caller always gets the documented ValueError.
    is_max_pool = isinstance(pool2d, torch.nn.MaxPool2d)
    if not is_max_pool and not isinstance(pool2d, torch.nn.AvgPool2d):
        raise ValueError('{} is not among known pooling classes'.format(type(pool2d)))

    if time_stride is None:
        time_stride = time_dim
    kernel_dim = (time_dim,) + _pair(pool2d.kernel_size)
    padding = (time_padding,) + _pair(pool2d.padding)
    stride = (time_stride,) + _pair(pool2d.stride)

    if is_max_pool:
        dilation = (time_dilation,) + _pair(pool2d.dilation)
        return torch.nn.MaxPool3d(
            kernel_dim,
            padding=padding,
            dilation=dilation,
            stride=stride,
            ceil_mode=pool2d.ceil_mode)

    # Average pooling: carry over padding and the rounding / padding-count
    # flags so the spatial output matches the 2D module.
    return torch.nn.AvgPool3d(
        kernel_dim,
        stride=stride,
        padding=padding,
        ceil_mode=pool2d.ceil_mode,
        count_include_pad=pool2d.count_include_pad)