import torch
import torch.nn as nn
import torchvision.models as models
class ConvBNReLU(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages.

    First conv maps in_ch -> out_ch, second keeps out_ch. Both convs use
    padding=1 so spatial size is preserved, and bias=False because the
    following BatchNorm supplies the affine shift.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        channels = in_ch
        for _ in range(2):
            stages.append(
                nn.Conv2d(channels, out_ch, kernel_size=3, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
            channels = out_ch  # second stage is out_ch -> out_ch
        # Same six-layer layout as an explicit Sequential, so state_dict
        # keys (net.0 ... net.5) are unchanged.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages; spatial dims are unchanged."""
        return self.net(x)
class UpBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample, concat with encoder skip, double conv.

    in_ch is the channel count of the incoming decoder feature, skip_ch that
    of the encoder skip connection; the fused tensor (in_ch + skip_ch) is
    reduced to out_ch by a ConvBNReLU pair.
    """

    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = ConvBNReLU(in_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        """Upsample x to skip's resolution, concatenate, and convolve."""
        upsampled = self.up(x)
        # Safety net for spatial mismatches (e.g. odd input sizes): force
        # the upsampled map onto the skip's exact height/width before concat.
        target = skip.shape[-2:]
        if upsampled.shape[-2:] != target:
            upsampled = nn.functional.interpolate(
                upsampled, size=target, mode="bilinear", align_corners=False
            )
        fused = torch.cat([upsampled, skip], dim=1)
        return self.conv(fused)
class ResNet18UNet(nn.Module):
    """U-Net-style segmentation head on a ResNet-18 encoder.

    The torchvision ResNet-18 backbone provides the encoder; skip features
    are taken after the stem and after each residual stage. Note that the
    stem itself downsamples by 2, so the returned logits are at H/2 x W/2,
    not full input resolution.

    Args:
        out_channels: number of output logit channels (default 1).
        pretrained: load ImageNet weights into the backbone when True.
    """

    def __init__(self, out_channels=1, pretrained=True):
        super().__init__()
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        base = models.resnet18(weights=weights)
        # Encoder pieces, annotated with their cumulative downsample factor.
        self.stem = nn.Sequential(base.conv1, base.bn1, base.relu)  # /2
        self.pool = base.maxpool  # /4
        self.enc1 = base.layer1   # /4,  64 ch
        self.enc2 = base.layer2   # /8,  128 ch
        self.enc3 = base.layer3   # /16, 256 ch
        self.enc4 = base.layer4   # /32, 512 ch
        self.center = ConvBNReLU(512, 512)
        # Decoder: each UpBlock halves the stride and fuses one skip.
        self.up4 = UpBlock(512, 256, 256)
        self.up3 = UpBlock(256, 128, 128)
        self.up2 = UpBlock(128, 64, 64)
        self.up1 = UpBlock(64, 64, 64)
        self.head = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run encoder, bottleneck, and decoder; return logits at H/2 x W/2."""
        skip_half = self.stem(x)          # 64 ch,  H/2
        pooled = self.pool(skip_half)     # 64 ch,  H/4
        skip_quarter = self.enc1(pooled)  # 64 ch,  H/4
        skip_eighth = self.enc2(skip_quarter)   # 128 ch, H/8
        skip_16th = self.enc3(skip_eighth)      # 256 ch, H/16
        bottom = self.enc4(skip_16th)           # 512 ch, H/32
        out = self.center(bottom)
        out = self.up4(out, skip_16th)
        out = self.up3(out, skip_eighth)
        out = self.up2(out, skip_quarter)
        out = self.up1(out, skip_half)
        # Logits at half the input resolution (the stem's /2 is never undone).
        return self.head(out)