import torch.nn as nn
import torchvision.models as models
from torchvision.models import (
ResNet18_Weights,
ResNet34_Weights,
ResNet50_Weights,
ResNet101_Weights,
ResNet152_Weights,
)
class RegressionModel(nn.Module):
    """torchvision ResNet whose final ``fc`` layer is replaced by a linear regression head."""

    # Depths accepted by ``_init_resnet``; must match the keys of the builder
    # table defined there.
    SUPPORTED_DEPTHS = ("18", "34", "50", "101", "152")

    def __init__(
        self,
        resnet_depth: "str | int",
        pretrained: bool = True,
        num_outputs: int = 2,
    ):
        """
        Args:
            resnet_depth (str | int): "18", "34", "50", "101", or "152". Integers
                such as 50 are also accepted and converted with ``str``.
            pretrained (bool): True to use ImageNet (IMAGENET1K_V1) weights for the
                backbone, False to train from scratch (``weights=None``).
            num_outputs (int): Number of regression outputs produced by the new
                ``fc`` layer. Defaults to 2, the previously hard-coded width.

        Raises:
            ValueError: If ``num_outputs`` is less than 1 or ``resnet_depth`` is
                not one of ``SUPPORTED_DEPTHS``.
        """
        super().__init__()
        # Validate cheap arguments first so a bad value fails before the
        # backbone (and, when pretrained=True, its ImageNet weights) is built.
        if num_outputs < 1:
            raise ValueError(f"num_outputs must be >= 1, got {num_outputs!r}.")
        self.model = self._init_resnet(resnet_depth, pretrained)
        # Replace the final fully connected layer, keeping its input width.
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_outputs)

    def _init_resnet(self, depth: "str | int", pretrained: bool) -> nn.Module:
        """Build an unmodified torchvision ResNet of the requested depth.

        Args:
            depth: One of ``SUPPORTED_DEPTHS``; an int is converted with ``str``.
            pretrained: True => ``IMAGENET1K_V1`` weights, False => ``weights=None``.

        Returns:
            The ResNet as returned by torchvision, still with its original ``fc``.

        Raises:
            ValueError: If ``depth`` is not supported.
        """
        depth = str(depth)
        # Reject unknown depths before touching any torchvision symbol.
        if depth not in self.SUPPORTED_DEPTHS:
            raise ValueError(
                f"Invalid ResNet depth {depth!r}. "
                f"Choose from {', '.join(self.SUPPORTED_DEPTHS)}."
            )
        # depth -> (model constructor, weights enum): one table instead of five
        # near-identical if/else branches.
        builders = {
            "18": (models.resnet18, ResNet18_Weights),
            "34": (models.resnet34, ResNet34_Weights),
            "50": (models.resnet50, ResNet50_Weights),
            "101": (models.resnet101, ResNet101_Weights),
            "152": (models.resnet152, ResNet152_Weights),
        }
        build, weights_enum = builders[depth]
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return build(weights=weights)

    def forward(self, x):
        """Run the ResNet backbone followed by the regression head.

        Args:
            x: Input passed straight to the torchvision ResNet. Presumably a float
                image batch shaped (batch, 3, H, W); TODO confirm against the data
                pipeline.

        Returns:
            The ResNet output, whose last dimension is ``num_outputs`` (the width
            of the replaced ``fc`` layer).
        """
        return self.model(x)