import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import SqueezeNet1_0_Weights, SqueezeNet1_1_Weights
class RegressionSqueezeNet(nn.Module):
    """SqueezeNet backbone with its classifier replaced by a regression head.

    The torchvision SqueezeNet model is built and its ``classifier`` is
    swapped for ``Dropout(0.5) -> Conv2d(512, num_outputs, 1) -> [ReLU] ->
    AdaptiveAvgPool2d(1)``; ``forward`` flattens the result to
    ``(batch, num_outputs)``.

    Args:
        version: SqueezeNet variant, ``"1_0"`` or ``"1_1"``.
        num_outputs: Number of regression targets. Defaults to ``2``, the
            value previously hard-coded.
        pretrained: If True (default, original behavior) load the
            ``IMAGENET1K_V1`` weights for the backbone; if False build a
            randomly initialized backbone without downloading anything.
        nonnegative_outputs: If True (default, original behavior) keep the
            ReLU after the final conv. NOTE: that ReLU clamps every prediction
            to ``>= 0`` and zeroes gradients for negative pre-activations;
            pass False when regression targets can be negative.

    Raises:
        ValueError: If ``version`` is not ``"1_0"``/``"1_1"`` or
            ``num_outputs < 1``.
    """

    _VERSIONS = ("1_0", "1_1")

    def __init__(
        self,
        version: str = "1_0",
        *,
        num_outputs: int = 2,
        pretrained: bool = True,
        nonnegative_outputs: bool = True,
    ) -> None:
        super().__init__()
        # Validate before building (and possibly downloading) the backbone.
        if version not in self._VERSIONS:
            raise ValueError("Invalid SqueezeNet version. Choose from '1_0' or '1_1'.")
        if num_outputs < 1:
            raise ValueError(f"num_outputs must be >= 1, got {num_outputs}")

        if version == "1_0":
            weights = SqueezeNet1_0_Weights.IMAGENET1K_V1 if pretrained else None
            self.model = models.squeezenet1_0(weights=weights)
        else:
            weights = SqueezeNet1_1_Weights.IMAGENET1K_V1 if pretrained else None
            self.model = models.squeezenet1_1(weights=weights)

        # Replace the original classifier. 512 is the channel count the
        # backbone's feature extractor feeds into the classifier for both
        # variants. Layer order (Dropout, Conv, ReLU, Pool) matches the
        # original head so existing checkpoints ("classifier.1.*") still load.
        head = [nn.Dropout(p=0.5), nn.Conv2d(512, num_outputs, kernel_size=1)]
        if nonnegative_outputs:
            head.append(nn.ReLU(inplace=True))
        head.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.model.classifier = nn.Sequential(*head)

        # Initialize the new head's conv: Kaiming-uniform weights, zero bias.
        for module in self.model.classifier:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and return predictions flattened to ``(batch, -1)``."""
        # flatten(…, 1) matches the old view(x.size(0), -1) but also works for
        # an empty batch and for non-contiguous tensors.
        return torch.flatten(self.model(x), 1)