import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import (
    ResNet18_Weights,
    ResNet34_Weights,
    ResNet50_Weights,
    ResNet101_Weights,
    ResNet152_Weights,
)


class RegressionResNet(nn.Module):
    """torchvision ResNet whose final ``fc`` layer is swapped for a linear head.

    The backbone is one of torchvision's ResNet variants. Its original
    fully connected layer is replaced by ``nn.Linear(in_features, num_outputs)``
    with no activation applied afterwards, so the module emits raw values.

    Args:
        resnet_depth: Which ResNet variant to build. Must be one of
            18, 34, 50, 101 or 152.
        num_outputs: Width of the replacement output layer. Defaults to 2,
            the value this class previously hard-coded.
        pretrained: When True (the default), the backbone is created with
            the ``IMAGENET1K_V1`` weights. When False, ``weights=None`` is
            passed and the backbone is randomly initialised, which is useful
            when a checkpoint is loaded right after construction or no
            network access is available.

    Raises:
        ValueError: If ``resnet_depth`` is not a supported depth, or if
            ``num_outputs`` is smaller than 1.
    """

    # Supported depths mapped to (torchvision constructor, ImageNet weights).
    _ARCHITECTURES = {
        18: (models.resnet18, ResNet18_Weights.IMAGENET1K_V1),
        34: (models.resnet34, ResNet34_Weights.IMAGENET1K_V1),
        50: (models.resnet50, ResNet50_Weights.IMAGENET1K_V1),
        101: (models.resnet101, ResNet101_Weights.IMAGENET1K_V1),
        152: (models.resnet152, ResNet152_Weights.IMAGENET1K_V1),
    }

    def __init__(
        self,
        resnet_depth: int,
        *,
        num_outputs: int = 2,
        pretrained: bool = True,
    ) -> None:
        super().__init__()

        try:
            build_backbone, imagenet_weights = self._ARCHITECTURES[resnet_depth]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments so that every bad depth
            # surfaces as the same ValueError callers already expect.
            raise ValueError(
                "Invalid ResNet depth. Choose from 18, 34, 50, 101, 152."
            ) from None

        if num_outputs < 1:
            raise ValueError(f"num_outputs must be >= 1, got {num_outputs}.")

        self.model = build_backbone(
            weights=imagenet_weights if pretrained else None
        )

        # Replace the ImageNet classifier with a freshly initialised linear
        # head; in_features differs between the shallow and deep variants,
        # so it is read from the existing layer rather than hard-coded.
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and the linear head.

        Args:
            x: Input batch forwarded unchanged to the wrapped ResNet
                (presumably image tensors shaped ``(N, 3, H, W)`` as
                torchvision ResNets expect; verify against the caller).

        Returns:
            The wrapped model's output; its last dimension has size
            ``num_outputs``.
        """
        return self.model(x)
