import torch
import torch.nn as nn
import torchvision.models as models


class ConvBNReLU(nn.Module):
    """Two stacked ``Conv3x3 -> BatchNorm -> ReLU`` stages.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes (``in_ch -> out_ch``).
    The convolutions have no bias because each is followed by BatchNorm.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        # Layers are created in the same order as a hand-written Sequential,
        # so state_dict keys (net.0 ... net.5) stay identical.
        for src_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(src_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        out = self.net(x)
        return out


class UpBlock(nn.Module):
    """Decoder stage: upsample ``x`` to the skip feature's size, concatenate, fuse.

    Args:
        in_ch: channels of the incoming (coarser) decoder feature ``x``.
        skip_ch: channels of the encoder skip feature ``skip``.
        out_ch: channels produced by the fusing ``ConvBNReLU``.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        # Parameter-free 2x bilinear upsampler used when the sizes line up exactly.
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = ConvBNReLU(in_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        """Return ``conv(cat([resize(x), skip], dim=1))``.

        ``resize(x)`` always has the same H, W as ``skip``, so the two tensors
        can be concatenated along the channel dimension.
        """
        target = skip.shape[-2:]
        if tuple(2 * d for d in x.shape[-2:]) == tuple(target):
            # Common case: the skip is exactly twice as large -> plain 2x upsample.
            x = self.up(x)
        else:
            # Size mismatch (e.g. the input is not a multiple of the total
            # encoder stride): resize straight to the skip size in a single
            # bilinear pass. Upsampling 2x first and then resizing again would
            # resample twice, costing extra work and blurring the features.
            x = nn.functional.interpolate(
                x, size=target, mode="bilinear", align_corners=False
            )
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)


class ResNet18UNet(nn.Module):
    """U-Net style segmentation network with a torchvision ResNet-18 encoder.

    The encoder is ResNet-18 up to ``layer4`` (its avgpool/fc head is unused).
    The decoder consists of four ``UpBlock`` stages that fuse encoder features
    at strides 16, 8, 4 and 2. The shallowest skip is the stem output, at
    stride 2. By default the logits are therefore at about half the input
    resolution. Set ``full_resolution=True`` to get logits at the input size.

    Args:
        out_channels: number of output logit channels.
        pretrained: if True, initialise the encoder with torchvision's
            ``ResNet18_Weights.DEFAULT`` (torchvision fetches these weights,
            which may need network access the first time).
        full_resolution: keyword-only. If True, the logits are bilinearly
            resized to the input's H x W. The default False keeps the
            original ~H/2 x ~W/2 output.
    """

    def __init__(self, out_channels=1, pretrained=True, *, full_resolution=False):
        super().__init__()
        base = models.resnet18(
            weights=models.ResNet18_Weights.DEFAULT if pretrained else None
        )

        self.stem = nn.Sequential(base.conv1, base.bn1, base.relu)  # /2
        self.pool = base.maxpool  # /4
        self.enc1 = base.layer1  # /4
        self.enc2 = base.layer2  # /8
        self.enc3 = base.layer3  # /16
        self.enc4 = base.layer4  # /32

        self.center = ConvBNReLU(512, 512)

        self.up4 = UpBlock(512, 256, 256)
        self.up3 = UpBlock(256, 128, 128)
        self.up2 = UpBlock(128, 64, 64)
        self.up1 = UpBlock(64, 64, 64)

        self.head = nn.Conv2d(64, out_channels, kernel_size=1)

        # Plain attribute (not a parameter/buffer), so state_dict is unchanged.
        self.full_resolution = full_resolution

    def forward(self, x):
        """Return raw (pre-activation) logits for a batch ``x``.

        ``x`` is assumed to be (N, 3, H, W): ResNet-18's first conv takes 3
        channels. The output has ``out_channels`` channels at ~H/2 x ~W/2, or
        exactly H x W when ``full_resolution`` is enabled.
        """
        in_size = x.shape[-2:]

        s1 = self.stem(x)  # (64, H/2, W/2)
        s2 = self.pool(s1)  # (64, H/4, W/4)
        s2 = self.enc1(s2)  # (64, H/4, W/4)
        s3 = self.enc2(s2)  # (128, H/8, W/8)
        s4 = self.enc3(s3)  # (256, H/16, W/16)
        s5 = self.enc4(s4)  # (512, H/32, W/32)

        x = self.center(s5)
        x = self.up4(x, s4)
        x = self.up3(x, s3)
        x = self.up2(x, s2)
        x = self.up1(x, s1)  # stride 2: the last skip is the stem output

        logits = self.head(x)  # (out_channels, H/2, W/2)

        # getattr keeps instances pickled before this flag existed working.
        if getattr(self, "full_resolution", False):
            logits = nn.functional.interpolate(
                logits, size=in_size, mode="bilinear", align_corners=False
            )
        return logits
