import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import (
ResNet18_Weights,
ResNet34_Weights,
ResNet50_Weights,
ResNet101_Weights,
ResNet152_Weights,
)
class RegressionResNet(nn.Module):
    """Torchvision ResNet backbone whose final layer is a linear regression head.

    The network's ``fc`` layer is replaced by a freshly initialised
    ``nn.Linear`` mapping the backbone features to ``num_outputs`` values.

    Args:
        resnet_depth: Which ResNet variant to build; one of 18, 34, 50,
            101 or 152.
        num_outputs: Width of the regression head. Defaults to 2, the
            value that was previously hard-coded.
        pretrained: When true (the default), start from the
            ``IMAGENET1K_V1`` weights. Pass ``False`` to build the backbone
            with ``weights=None``, e.g. when a checkpoint is going to be
            loaded afterwards and fetching the ImageNet weights is
            unnecessary.

    Raises:
        ValueError: If ``resnet_depth`` is not a supported depth.
    """

    _SUPPORTED_DEPTHS = (18, 34, 50, 101, 152)

    def __init__(self, resnet_depth: int, num_outputs: int = 2, *, pretrained: bool = True) -> None:
        super().__init__()

        # Validate first so an unsupported depth fails fast, before any
        # torchvision builder or weight enum is touched.
        if resnet_depth not in self._SUPPORTED_DEPTHS:
            raise ValueError("Invalid ResNet depth. Choose from 18, 34, 50, 101, 152.")

        # depth -> (builder function, pretrained ImageNet weights)
        architectures = {
            18: (models.resnet18, ResNet18_Weights.IMAGENET1K_V1),
            34: (models.resnet34, ResNet34_Weights.IMAGENET1K_V1),
            50: (models.resnet50, ResNet50_Weights.IMAGENET1K_V1),
            101: (models.resnet101, ResNet101_Weights.IMAGENET1K_V1),
            152: (models.resnet152, ResNet152_Weights.IMAGENET1K_V1),
        }
        build, imagenet_weights = architectures[resnet_depth]
        self.model = build(weights=imagenet_weights if pretrained else None)

        # Swap the classification layer for a regression head of the
        # requested width; the new layer is randomly initialised.
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped ResNet and return its output.

        NOTE(review): ``x`` is presumably an image batch shaped
        ``(N, 3, H, W)`` as torchvision ResNets expect -- confirm against
        the caller.
        """
        return self.model(x)