import torch.nn as nn
import torchvision.models as models
from torchvision.models import MobileNet_V2_Weights
class RegressionMobileNetV2(nn.Module):
    """MobileNetV2 backbone whose classification head is replaced by a regression head.

    The torchvision MobileNetV2 model is kept as-is except for its final
    ``classifier``. That classifier is swapped for ``Dropout -> Linear``, which
    emits ``num_outputs`` raw real values per sample. No activation is applied.

    Args:
        pretrained: If True, build the backbone with torchvision's
            ``MobileNet_V2_Weights.IMAGENET1K_V1`` weights. Otherwise build it
            with ``weights=None`` (no pretrained weights loaded).
        num_outputs: Number of regression targets per sample. The default of 2
            matches the previously hard-coded head.
        dropout: Dropout probability applied before the final linear layer.
            The default of 0.2 matches the previously hard-coded value.
    """

    def __init__(
        self,
        pretrained: bool = True,
        *,
        num_outputs: int = 2,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        weights = MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        self.model = models.mobilenet_v2(weights=weights)

        # Read the input width of the existing final Linear layer instead of
        # hard-coding it, so the new head always matches the backbone's
        # feature size.
        num_features = self.model.classifier[1].in_features

        # Replace the classification head with a regression head. No
        # activation follows the Linear layer, so outputs are unbounded.
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(num_features, num_outputs),
        )

    def forward(self, x):
        """Run the wrapped MobileNetV2 on ``x`` and return its regression output.

        Args:
            x: Input batch passed straight to the wrapped model. This presumably
                means an image tensor shaped (N, 3, H, W), normalised the way
                the backbone expects; confirm against callers.

        Returns:
            Whatever the wrapped model returns. With the head installed in
            ``__init__``, this is expected to be one row of ``num_outputs``
            values per sample.
        """
        return self.model(x)