# Source: RARP/EfficientViT/GSViT.py (GSViT)
# Last commit by @delAguila on 8 Jan.
import warnings

import torch
import torch.nn as nn

from .classification.model.build import EfficientViT_M5

class EfficientViT_GSViT(nn.Module):
    """EfficientViT-M5 backbone with GSViT weights, used as a pooled feature extractor.

    Implementation based on the original GSViT paper and repository:
    https://github.com/SamuelSchmidgall/GSViT.git

    An EfficientViT-M5 model is built and its last child module is dropped
    (presumably the classification head -- confirm against the EfficientViT
    builder). The remaining children are wrapped in an ``nn.Sequential`` stored
    as ``self.evit``. ``forward`` global-average-pools the backbone output into
    one feature vector per sample.

    Args:
        pre_trained: Path to a checkpoint whose weights are loaded into
            ``self.evit`` non-strictly. A falsy value (``""``/``None``) skips
            loading.
        force_fp32: If True, ``forward`` casts the input to float32 and runs
            the backbone with autocast disabled. This holds even when called
            inside an outer autocast region.
    """

    def __init__(self, pre_trained: str = "EfficientViT_GSViT.pth", force_fp32: bool = False):
        super().__init__()

        self.force_fp32 = force_fp32
        # The "efficientvit_m5" value is passed through to the project's builder.
        # Presumably it initializes from the upstream EfficientViT-M5 weights
        # before any GSViT checkpoint below overrides them -- confirm in build.py.
        backbone = EfficientViT_M5(pretrained="efficientvit_m5")
        # Keep every child except the last one. This is the same truncation
        # GSViT uses, so checkpoint keys are "0.*", "1.*", ...
        self.evit = nn.Sequential(*list(backbone.children())[:-1])

        if pre_trained:
            self._load_checkpoint(pre_trained)

    @staticmethod
    def _extract_state_dict(ckpt):
        """Return a plain state dict from a loaded checkpoint object.

        Unwraps a ``{"state_dict": ...}`` container. Strips a ``module.`` prefix
        when every key carries one (the layout produced by saving a
        ``DataParallel``/``DistributedDataParallel`` model). Anything else is
        returned unchanged.
        """
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        prefix = "module."
        if (
            isinstance(ckpt, dict)
            and ckpt
            and all(isinstance(k, str) and k.startswith(prefix) for k in ckpt)
        ):
            ckpt = {k[len(prefix):]: v for k, v in ckpt.items()}
        return ckpt

    def _load_checkpoint(self, path):
        """Load weights from ``path`` into ``self.evit`` non-strictly, warning on key mismatches."""
        # SECURITY: torch.load unpickles the file, so only load checkpoints from
        # trusted sources. Consider weights_only=True (torch >= 1.13) if the file
        # holds only tensors.
        ckpt = torch.load(path, map_location="cpu")
        result = self.evit.load_state_dict(self._extract_state_dict(ckpt), strict=False)
        # strict=False is kept for backward compatibility. Mismatches are still
        # reported, so a checkpoint with the wrong key layout cannot leave the
        # backbone silently unloaded.
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"{type(self).__name__}: checkpoint {path!r} loaded non-strictly: "
                f"{len(result.missing_keys)} missing key(s), "
                f"{len(result.unexpected_keys)} unexpected key(s).",
                stacklevel=3,
            )

    def forward(self, x):
        """Return globally average-pooled backbone features.

        Args:
            x: Input batch for the backbone. Presumably (batch, 3, H, W) images
                -- TODO confirm against callers.

        Returns:
            The backbone output average-pooled to 1x1 and flattened from dim 1.
            For a 4-D (batch, channels, h, w) feature map this is
            (batch, channels). The dtype is float32 when ``force_fp32`` is set.
        """
        if self.force_fp32:
            # Disable autocast for the input's device so the backbone really runs
            # in float32, even when called from inside an autocast region.
            device_type = "cuda" if x.is_cuda else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                feats = self.evit(x.float())
        else:
            feats = self.evit(x)
        return nn.functional.adaptive_avg_pool2d(feats, 1).flatten(1)