from .classification.model.build import EfficientViT_M5
import torch
import torch.nn as nn
class EfficientViT_GSViT(nn.Module):
    """EfficientViT-M5 backbone used as a GSViT feature extractor.

    An implementation based on the original paper and repository:
    https://github.com/SamuelSchmidgall/GSViT.git

    The backbone is an ``EfficientViT_M5`` with its last child module removed
    (presumably the classification head -- confirm against
    ``classification.model.build``). ``forward`` global-average-pools the
    backbone output into one feature vector per sample.

    Args:
        pre_trained: Path of a checkpoint to load into the truncated backbone.
            The file may hold either a bare state dict or a dict that stores
            one under the ``"state_dict"`` key. Any falsy value (``""``,
            ``None``) skips checkpoint loading.
        force_fp32: When true, CUDA inputs are cast to float32 and the
            backbone runs with autocast disabled. Inputs on other devices are
            passed through unchanged.
    """

    def __init__(self, pre_trained: str = "EfficientViT_GSViT.pth", force_fp32=False):
        super().__init__()
        self.force_fp32 = force_fp32

        backbone = EfficientViT_M5(pretrained="efficientvit_m5")
        # Keep every child module except the last one. Re-wrapping in
        # nn.Sequential renames parameters to "0.*", "1.*", ... so checkpoint
        # keys must follow that naming to be picked up.
        self.evit = nn.Sequential(*list(backbone.children())[:-1])

        if pre_trained:
            self._load_checkpoint(pre_trained)

    def _load_checkpoint(self, path):
        """Load backbone weights from ``path``, warning about key mismatches."""
        import warnings

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or pass weights_only=True where supported).
        # map_location="cpu" lets checkpoints saved from CUDA tensors load on
        # CPU-only hosts; load_state_dict copies values into the existing
        # parameters, so the module's own device is unaffected.
        ckpt = torch.load(path, map_location="cpu")
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]

        # strict=False keeps loading best-effort, but a mismatch should not go
        # unnoticed: it can mean the pretrained weights were not applied.
        result = self.evit.load_state_dict(ckpt, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {path!r} did not fully match the backbone: "
                f"{len(result.missing_keys)} missing key(s), "
                f"{len(result.unexpected_keys)} unexpected key(s).",
                stacklevel=3,
            )

    def forward(self, x):
        """Return globally average-pooled backbone features for ``x``.

        The result is ``adaptive_avg_pool2d(features, 1).flatten(1)``, i.e.
        one row per sample (assumes the backbone returns a 4-D feature map).
        """
        if self.force_fp32 and x.is_cuda:
            # Run the backbone in full precision even inside an outer
            # autocast region.
            with torch.autocast(device_type="cuda", enabled=False):
                x = self.evit(x.float())
        else:
            x = self.evit(x)
        return nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)