import torch.nn as nn
import torchvision.models as models
from torchvision.models import MobileNet_V2_Weights


class RegressionMobileNetV2(nn.Module):
    """MobileNetV2 backbone whose classifier is replaced by a regression head.

    The torchvision MobileNetV2 classifier is swapped for
    ``Dropout -> Linear(num_features, num_outputs)``; the rest of the network
    is left as torchvision builds it.
    """

    def __init__(self, pretrained: bool = True, *, num_outputs: int = 2,
                 dropout: float = 0.2) -> None:
        """Build the backbone and attach the regression head.

        Args:
            pretrained: If true, load the ``IMAGENET1K_V1`` weights;
                otherwise the backbone is created without pretrained weights.
            num_outputs: Number of values predicted per sample. Defaults to
                2, the size that was previously hard-coded.
            dropout: Dropout probability applied before the final linear
                layer. Defaults to 0.2, the value previously hard-coded.

        Raises:
            ValueError: If ``num_outputs`` is smaller than 1.
        """
        super().__init__()

        # Validate before building the backbone so a bad argument fails
        # without first constructing the network or fetching weights.
        if num_outputs < 1:
            raise ValueError(
                f"num_outputs must be a positive integer, got {num_outputs!r}"
            )

        weights = MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        self.model = models.mobilenet_v2(weights=weights)

        # Index 1 of the stock classifier is its Linear layer; its input
        # width is what the replacement head has to accept.
        num_features = self.model.classifier[1].in_features

        # Replace the classifier with a new head for regression.
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(num_features, num_outputs),
        )

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its raw output.

        No activation is applied after the final linear layer.
        NOTE(review): ``x`` is presumably an image batch in the layout
        torchvision's MobileNetV2 expects -- confirm against the caller.
        """
        return self.model(x)
