# Demo-Maker/modules/EARSForDL/EfficientNet.py
import torch.nn as nn
from torchvision.models import (
    EfficientNet_B0_Weights,
    EfficientNet_B1_Weights,
    EfficientNet_B2_Weights,
    EfficientNet_B3_Weights,
    EfficientNet_B4_Weights,
    EfficientNet_B5_Weights,
    EfficientNet_B6_Weights,
    EfficientNet_B7_Weights,
    efficientnet_b0,
    efficientnet_b1,
    efficientnet_b2,
    efficientnet_b3,
    efficientnet_b4,
    efficientnet_b5,
    efficientnet_b6,
    efficientnet_b7,
)


class RegressionEfficientNet(nn.Module):
    """torchvision EfficientNet backbone with a small regression head.

    The stock classifier of the chosen EfficientNet variant is replaced by a
    ``Dropout -> Linear`` head that maps the backbone's features to
    ``num_outputs`` values (2 by default).

    Args:
        efficientnet_version: Variant name, ``"b0"`` through ``"b7"``
            (case-insensitive).
        num_outputs: Width of the final linear layer. Defaults to 2, the
            value this class always used.
        dropout: Dropout probability applied before the final linear layer.
            Defaults to 0.2 for backward compatibility. Note that torchvision's
            stock heads use higher rates for the larger variants (up to 0.5
            for b6/b7).
        pretrained: When True (default), build the backbone with the
            ``IMAGENET1K_V1`` weights; when False, use random initialisation.

    Raises:
        ValueError: If ``efficientnet_version`` is not one of ``"b0"``..``"b7"``.
    """

    def __init__(
        self,
        efficientnet_version: str,
        *,
        num_outputs: int = 2,
        dropout: float = 0.2,
        pretrained: bool = True,
    ) -> None:
        super().__init__()

        # Variant -> (torchvision builder, matching weights enum). Built here
        # rather than at class level so defining the class does not touch the
        # torchvision names.
        variants = {
            "b0": (efficientnet_b0, EfficientNet_B0_Weights),
            "b1": (efficientnet_b1, EfficientNet_B1_Weights),
            "b2": (efficientnet_b2, EfficientNet_B2_Weights),
            "b3": (efficientnet_b3, EfficientNet_B3_Weights),
            "b4": (efficientnet_b4, EfficientNet_B4_Weights),
            "b5": (efficientnet_b5, EfficientNet_B5_Weights),
            "b6": (efficientnet_b6, EfficientNet_B6_Weights),
            "b7": (efficientnet_b7, EfficientNet_B7_Weights),
        }

        # Non-string input (None, lists, ...) is reported as an invalid
        # version, not as a TypeError from the dict lookup.
        key = (
            efficientnet_version.lower()
            if isinstance(efficientnet_version, str)
            else None
        )
        if key not in variants:
            raise ValueError(
                f"Invalid EfficientNet version {efficientnet_version!r}. "
                "Choose from 'b0' to 'b7'."
            )

        builder, weights_enum = variants[key]
        self.model = builder(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

        # The stock head's Linear layer sits at classifier[1]. Reusing its
        # input width makes the new head fit every variant.
        num_features = self.model.classifier[1].in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(num_features, num_outputs),
        )

    def forward(self, x):
        """Run the backbone and regression head.

        ``x`` is presumably an image batch of shape (N, 3, H, W); confirm
        against the data pipeline. Returns the head's output of shape
        (N, num_outputs).
        """
        return self.model(x)