import torch.nn as nn
import torchvision.models as models
from torchvision.models import (
EfficientNet_B0_Weights,
EfficientNet_B1_Weights,
EfficientNet_B2_Weights,
EfficientNet_B3_Weights,
EfficientNet_B4_Weights,
EfficientNet_B5_Weights,
EfficientNet_B6_Weights,
EfficientNet_B7_Weights,
ResNet18_Weights,
ResNet34_Weights,
ResNet50_Weights,
ResNet101_Weights,
ResNet152_Weights,
)
class RegressionModel(nn.Module):
    """ImageNet-pretrained torchvision backbone with a linear regression head.

    The backbone's classification head is replaced by a single
    ``nn.Linear(num_features, num_outputs)`` layer.

    Args:
        model_name: Backbone variant. For ``model_type="resnet"`` this is the
            depth: "18", "34", "50", "101" or "152" (ints are also accepted).
            For ``model_type="efficientnet"`` it is the version "b0".."b7"
            (case-insensitive).
        model_type: "resnet" or "efficientnet" (case-insensitive).
        num_outputs: Width of the regression head. Defaults to 2, the value
            the head used to be hard-coded to.

    Raises:
        ValueError: If ``model_type`` or ``model_name`` is not supported.
    """

    def __init__(self, model_name, model_type="resnet", *, num_outputs=2):
        super().__init__()
        self.model_type = model_type.lower()
        if self.model_type == "resnet":
            self.model = self._init_resnet(model_name)
        elif self.model_type == "efficientnet":
            self.model = self._init_efficientnet(model_name)
        else:
            raise ValueError(
                "Invalid model type. Choose from 'resnet' or 'efficientnet'."
            )

        # Swap the pretrained classification head for a linear regression head.
        num_features = self._get_num_features()
        head = nn.Linear(num_features, num_outputs)
        if self.model_type == "resnet":
            self.model.fc = head
        else:  # efficientnet
            # Replaces the whole ``classifier`` module (torchvision builds it as
            # Sequential(Dropout, Linear)), so the pretrained dropout is not
            # kept. The attribute is assigned directly, rather than
            # ``classifier[1]``, so state_dict keys stay
            # ``model.classifier.weight``/``.bias`` and existing checkpoints
            # still load.
            self.model.classifier = head

    def _init_resnet(self, depth):
        """Return an ImageNet-pretrained torchvision ResNet of ``depth`` layers.

        ``depth`` may be given as a string ("50") or an int (50).

        Raises:
            ValueError: If ``depth`` is not one of 18, 34, 50, 101, 152.
        """
        # Lambdas defer the torchvision lookups (and the weight download)
        # until a valid depth has been selected.
        builders = {
            "18": lambda: models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1),
            "34": lambda: models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1),
            "50": lambda: models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1),
            "101": lambda: models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1),
            "152": lambda: models.resnet152(weights=ResNet152_Weights.IMAGENET1K_V1),
        }
        try:
            build = builders[str(depth)]
        except KeyError:
            raise ValueError(
                "Invalid ResNet depth. Choose from 18, 34, 50, 101, 152."
            ) from None
        return build()

    def _init_efficientnet(self, version):
        """Return an ImageNet-pretrained torchvision EfficientNet ``version``.

        ``version`` is "b0".."b7" and is matched case-insensitively.

        Raises:
            ValueError: If ``version`` is not one of "b0".."b7".
        """
        # Lambdas defer the torchvision lookups until a valid version is chosen.
        builders = {
            "b0": lambda: models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1),
            "b1": lambda: models.efficientnet_b1(weights=EfficientNet_B1_Weights.IMAGENET1K_V1),
            "b2": lambda: models.efficientnet_b2(weights=EfficientNet_B2_Weights.IMAGENET1K_V1),
            "b3": lambda: models.efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1),
            "b4": lambda: models.efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1),
            "b5": lambda: models.efficientnet_b5(weights=EfficientNet_B5_Weights.IMAGENET1K_V1),
            "b6": lambda: models.efficientnet_b6(weights=EfficientNet_B6_Weights.IMAGENET1K_V1),
            "b7": lambda: models.efficientnet_b7(weights=EfficientNet_B7_Weights.IMAGENET1K_V1),
        }
        try:
            build = builders[str(version).lower()]
        except KeyError:
            raise ValueError(
                "Invalid EfficientNet version. Choose from 'b0' to 'b7'."
            ) from None
        return build()

    def _get_num_features(self):
        """Return the input width of the backbone's current output head.

        Works both before ``__init__`` swaps the head and afterwards.
        """
        if self.model_type == "resnet":
            return self.model.fc.in_features
        head = self.model.classifier
        # The pretrained EfficientNet head is Sequential(Dropout, Linear);
        # once __init__ replaces it, it is a bare Linear (not subscriptable).
        if isinstance(head, nn.Linear):
            return head.in_features
        return head[1].in_features

    def forward(self, x):
        """Run ``x`` through the backbone and regression head."""
        return self.model(x)