# Demo-Maker / modules / EARSNet / model.py
import torch.nn as nn
import torchvision.models as models
from torchvision.models import (
    EfficientNet_B0_Weights,
    EfficientNet_B1_Weights,
    EfficientNet_B2_Weights,
    EfficientNet_B3_Weights,
    EfficientNet_B4_Weights,
    EfficientNet_B5_Weights,
    EfficientNet_B6_Weights,
    EfficientNet_B7_Weights,
    ResNet18_Weights,
    ResNet34_Weights,
    ResNet50_Weights,
    ResNet101_Weights,
    ResNet152_Weights,
)


class RegressionModel(nn.Module):
    """Torchvision backbone whose output layer is replaced by a 2-unit linear head.

    The backbone is a ResNet or an EfficientNet created with torchvision's
    ``IMAGENET1K_V1`` weights. After construction, ``forward`` returns whatever
    the wrapped backbone returns, which ends in ``nn.Linear(num_features, 2)``.

    Args:
        model_name: Variant of the chosen family. For ``"resnet"`` one of
            ``18``, ``34``, ``50``, ``101``, ``152`` (int or str); for
            ``"efficientnet"`` one of ``"b0"`` .. ``"b7"``. Matching ignores
            case and surrounding whitespace.
        model_type: ``"resnet"`` (default) or ``"efficientnet"``; matching
            ignores case and surrounding whitespace.

    Raises:
        ValueError: If ``model_type`` or ``model_name`` is not one of the
            supported values.
    """

    # Zero-argument factories keyed by normalized variant name. The lambdas
    # defer every torchvision lookup until a variant has actually been chosen,
    # so validating a name never builds a network.
    _RESNET_FACTORIES = {
        "18": lambda: models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1),
        "34": lambda: models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1),
        "50": lambda: models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1),
        "101": lambda: models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1),
        "152": lambda: models.resnet152(weights=ResNet152_Weights.IMAGENET1K_V1),
    }

    _EFFICIENTNET_FACTORIES = {
        "b0": lambda: models.efficientnet_b0(
            weights=EfficientNet_B0_Weights.IMAGENET1K_V1
        ),
        "b1": lambda: models.efficientnet_b1(
            weights=EfficientNet_B1_Weights.IMAGENET1K_V1
        ),
        "b2": lambda: models.efficientnet_b2(
            weights=EfficientNet_B2_Weights.IMAGENET1K_V1
        ),
        "b3": lambda: models.efficientnet_b3(
            weights=EfficientNet_B3_Weights.IMAGENET1K_V1
        ),
        "b4": lambda: models.efficientnet_b4(
            weights=EfficientNet_B4_Weights.IMAGENET1K_V1
        ),
        "b5": lambda: models.efficientnet_b5(
            weights=EfficientNet_B5_Weights.IMAGENET1K_V1
        ),
        "b6": lambda: models.efficientnet_b6(
            weights=EfficientNet_B6_Weights.IMAGENET1K_V1
        ),
        "b7": lambda: models.efficientnet_b7(
            weights=EfficientNet_B7_Weights.IMAGENET1K_V1
        ),
    }

    def __init__(self, model_name, model_type="resnet"):
        super().__init__()

        self.model_type = self._normalize_name(model_type)

        if self.model_type == "resnet":
            self.model = self._init_resnet(model_name)
        elif self.model_type == "efficientnet":
            self.model = self._init_efficientnet(model_name)
        else:
            raise ValueError(
                "Invalid model type. Choose from 'resnet' or 'efficientnet'."
            )

        # Swap the backbone's output layer for a 2-output linear layer.
        num_features = self._get_num_features()
        if self.model_type == "resnet":
            self.model.fc = nn.Linear(num_features, 2)
        else:  # efficientnet
            # The whole ``classifier`` attribute is replaced by one Linear (not
            # just its linear sub-layer). Kept that way on purpose: wrapping it
            # differently would rename the head's parameters in state dicts.
            self.model.classifier = nn.Linear(num_features, 2)

    @staticmethod
    def _normalize_name(name):
        """Return ``name`` as a stripped, lower-cased string (ints allowed)."""
        return str(name).strip().lower()

    def _init_resnet(self, depth):
        """Build the ResNet of the given depth with ``IMAGENET1K_V1`` weights.

        Args:
            depth: ``18``, ``34``, ``50``, ``101`` or ``152`` as int or str.

        Raises:
            ValueError: If ``depth`` is not a supported value.
        """
        factory = self._RESNET_FACTORIES.get(self._normalize_name(depth))
        if factory is None:
            raise ValueError("Invalid ResNet depth. Choose from 18, 34, 50, 101, 152.")
        return factory()

    def _init_efficientnet(self, version):
        """Build the EfficientNet of the given version with ``IMAGENET1K_V1`` weights.

        Args:
            version: ``"b0"`` .. ``"b7"``, case-insensitive.

        Raises:
            ValueError: If ``version`` is not a supported value.
        """
        factory = self._EFFICIENTNET_FACTORIES.get(self._normalize_name(version))
        if factory is None:
            raise ValueError("Invalid EfficientNet version. Choose from 'b0' to 'b7'.")
        return factory()

    def _get_num_features(self):
        """Return ``in_features`` of the backbone's current output layer.

        Works both before and after ``__init__`` has replaced the head.
        """
        if self.model_type == "resnet":
            return self.model.fc.in_features

        # efficientnet
        classifier = self.model.classifier
        if isinstance(classifier, nn.Linear):
            # Head already replaced by __init__: a bare Linear is not indexable.
            return classifier.in_features
        # Otherwise the classifier is indexed; position 1 is expected to be the
        # linear layer of torchvision's stock EfficientNet classifier.
        return classifier[1].in_features

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        return self.model(x)