# Demo-Maker / modules / EARSForDL / SqueezeNet.py
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import SqueezeNet1_0_Weights, SqueezeNet1_1_Weights


class RegressionSqueezeNet(nn.Module):
    """SqueezeNet backbone whose ImageNet classifier is replaced by a regression head.

    The torchvision SqueezeNet feature extractor is kept unchanged; its
    classifier is swapped for ``Dropout -> 1x1 Conv2d -> [ReLU] -> global
    average pool``, yielding ``num_outputs`` values per sample.

    Args:
        version: SqueezeNet variant, ``"1_0"`` or ``"1_1"``.
        num_outputs: Number of regression outputs (output channels of the
            1x1 conv). Defaults to 2, the previously hard-coded value.
        pretrained: If True (default), load torchvision's ``IMAGENET1K_V1``
            weights for the backbone (may require a download). If False, the
            backbone is randomly initialised, which allows offline use.
        output_relu: If True (default, previous behaviour), apply ReLU after
            the 1x1 conv, so every prediction is >= 0. Set to False when
            targets can be negative: with ReLU, outputs cannot go below zero
            and receive no gradient once their pre-activation is negative.

    Raises:
        ValueError: If ``version`` is not ``"1_0"`` or ``"1_1"``.
    """

    def __init__(self, version="1_0", *, num_outputs=2, pretrained=True, output_relu=True):
        super().__init__()
        # Resolve the variant before touching torchvision so bad input fails fast.
        if version == "1_0":
            builder, weights = models.squeezenet1_0, SqueezeNet1_0_Weights.IMAGENET1K_V1
        elif version == "1_1":
            builder, weights = models.squeezenet1_1, SqueezeNet1_1_Weights.IMAGENET1K_V1
        else:
            raise ValueError("Invalid SqueezeNet version. Choose from '1_0' or '1_1'.")
        self.model = builder(weights=weights if pretrained else None)

        # Replace the ImageNet classifier with a regression head. The input
        # channel count is read from the existing final 1x1 conv (index 1 of
        # torchvision's classifier) rather than hard-coding 512.
        in_channels = self.model.classifier[1].in_channels
        head = [nn.Dropout(p=0.5), nn.Conv2d(in_channels, num_outputs, kernel_size=1)]
        if output_relu:
            head.append(nn.ReLU(inplace=True))
        head.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.model.classifier = nn.Sequential(*head)

        # Re-initialise the new conv layer (Kaiming-uniform weights, zero bias).
        for m in self.model.classifier:
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the wrapped SqueezeNet and return a 2-D tensor ``(batch, num_outputs)``.

        ``x`` is passed straight to the backbone; presumably an image batch of
        shape (N, 3, H, W) normalised like ImageNet — TODO confirm with callers.
        """
        x = self.model(x)
        # Collapse everything after the batch dim. Unlike view(), this also
        # works for an empty batch and for non-contiguous tensors.
        return torch.flatten(x, 1)