Newer
Older
Demo-Maker / modules / EARSNet / model.py
@mikado-4410 mikado-4410 on 22 Jan 2025 2 KB [fix]各手法のFPSを計算
import torch.nn as nn
import torchvision.models as models
from torchvision.models import (
    ResNet18_Weights,
    ResNet34_Weights,
    ResNet50_Weights,
    ResNet101_Weights,
    ResNet152_Weights,
)


class RegressionModel(nn.Module):
    """torchvision ResNet whose final ``fc`` layer is replaced by a linear regression head.

    The backbone is built by ``_init_resnet`` and stored as ``self.model``. Its
    ``fc`` layer is swapped for ``nn.Linear(in_features, num_outputs)``, and
    ``forward`` simply delegates to the wrapped ResNet.
    """

    # Depth strings accepted by ``_init_resnet``.
    _SUPPORTED_DEPTHS = ("18", "34", "50", "101", "152")

    def __init__(self, resnet_depth: str, pretrained: bool = True, num_outputs: int = 2):
        """
        Args:
            resnet_depth (str): "18", "34", "50", "101", or "152". An int such as
                50 is also accepted; it is converted with ``str()``.
            pretrained (bool): True to load the ImageNet ``IMAGENET1K_V1`` weights,
                False to build the backbone with ``weights=None`` (from scratch).
            num_outputs (int): Number of outputs of the regression head. Defaults
                to 2.

        Raises:
            ValueError: If ``resnet_depth`` is not one of the supported depths.
        """
        super().__init__()
        self.model = self._init_resnet(resnet_depth, pretrained)

        # Replace the classification head with a regression head that keeps the
        # backbone's feature width as its input size.
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_outputs)

    def _init_resnet(self, depth: str, pretrained: bool):
        """Build the torchvision ResNet backbone of the requested depth.

        Args:
            depth (str): One of ``_SUPPORTED_DEPTHS``. Non-str values are
                converted with ``str()`` first.
            pretrained (bool): True => ``<Weights>.IMAGENET1K_V1``,
                False => ``weights=None`` (scratch).

        Returns:
            The constructed torchvision ResNet model.

        Raises:
            ValueError: If ``depth`` is not supported.
        """
        depth = str(depth)
        # Validate before touching any torchvision names, so a bad depth always
        # surfaces as a ValueError.
        if depth not in self._SUPPORTED_DEPTHS:
            raise ValueError("Invalid ResNet depth. Choose from 18, 34, 50, 101, 152.")

        # depth -> (model constructor, matching weight enum)
        factories = {
            "18": (models.resnet18, ResNet18_Weights),
            "34": (models.resnet34, ResNet34_Weights),
            "50": (models.resnet50, ResNet50_Weights),
            "101": (models.resnet101, ResNet101_Weights),
            "152": (models.resnet152, ResNet152_Weights),
        }
        builder, weights_enum = factories[depth]
        return builder(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        """Run the wrapped ResNet on ``x`` and return the regression output.

        NOTE(review): ``x`` is presumably an image batch shaped (N, 3, H, W),
        giving an output of shape (N, num_outputs). Confirm against callers.
        """
        return self.model(x)