import torch.nn as nn
from torchvision.models import (
EfficientNet_B0_Weights,
EfficientNet_B1_Weights,
EfficientNet_B2_Weights,
EfficientNet_B3_Weights,
EfficientNet_B4_Weights,
EfficientNet_B5_Weights,
EfficientNet_B6_Weights,
EfficientNet_B7_Weights,
efficientnet_b0,
efficientnet_b1,
efficientnet_b2,
efficientnet_b3,
efficientnet_b4,
efficientnet_b5,
efficientnet_b6,
efficientnet_b7,
)
class RegressionEfficientNet(nn.Module):
    """EfficientNet backbone with a small regression head.

    Builds one of torchvision's EfficientNet variants ``b0``-``b7`` and
    replaces its ``classifier`` with ``Dropout -> Linear`` producing
    ``num_outputs`` raw values (no activation is applied to the output).

    Args:
        efficientnet_version: Variant name, one of ``SUPPORTED_VERSIONS``
            (case-sensitive, e.g. ``"b0"``).
        num_outputs: Width of the final ``nn.Linear``. Defaults to 2.
        dropout: Dropout probability applied before the final layer.
            Defaults to 0.2 for every variant. NOTE: torchvision's stock
            heads use variant-specific rates; 0.2 is kept as the default
            for backward compatibility.
        pretrained: If True (default), build the backbone with its
            ``IMAGENET1K_V1`` weights; if False, pass ``weights=None``
            (random initialisation, no pretrained weights loaded).

    Raises:
        ValueError: If ``efficientnet_version`` is not supported.
    """

    # Variant names accepted by __init__.
    SUPPORTED_VERSIONS = ("b0", "b1", "b2", "b3", "b4", "b5", "b6", "b7")

    def __init__(
        self,
        efficientnet_version: str,
        *,
        num_outputs: int = 2,
        dropout: float = 0.2,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        # Validate before touching any torchvision builder so a bad name
        # fails fast and never triggers model construction.
        if efficientnet_version not in self.SUPPORTED_VERSIONS:
            raise ValueError(
                "Invalid EfficientNet version. Choose from 'b0' to 'b7'. "
                f"Got {efficientnet_version!r}."
            )

        builder, weights_enum = self._backbone_factories()[efficientnet_version]
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        self.model = builder(weights=weights)

        # Replace the stock classifier. Its Linear layer sits at index 1, and
        # that layer's in_features is the backbone's output feature width.
        num_features = self.model.classifier[1].in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(num_features, num_outputs),
        )

    @staticmethod
    def _backbone_factories():
        """Return a mapping of version name -> (model builder, weights enum)."""
        # Built lazily (not as a class attribute) so the torchvision symbols
        # are only resolved once a valid version has been requested.
        return {
            "b0": (efficientnet_b0, EfficientNet_B0_Weights),
            "b1": (efficientnet_b1, EfficientNet_B1_Weights),
            "b2": (efficientnet_b2, EfficientNet_B2_Weights),
            "b3": (efficientnet_b3, EfficientNet_B3_Weights),
            "b4": (efficientnet_b4, EfficientNet_B4_Weights),
            "b5": (efficientnet_b5, EfficientNet_B5_Weights),
            "b6": (efficientnet_b6, EfficientNet_B6_Weights),
            "b7": (efficientnet_b7, EfficientNet_B7_Weights),
        }

    def forward(self, x):
        """Run ``x`` through the backbone and regression head.

        Returns the raw ``num_outputs``-wide Linear output. Presumably ``x``
        is an image batch shaped as torchvision's EfficientNet expects
        (``(batch, 3, H, W)``) — TODO confirm against callers.
        """
        return self.model(x)