import warnings
from collections import OrderedDict

import torch
import torch.nn as nn

from .classification.model.build import EfficientViT_M5
class EfficientViT_GSViT(nn.Module):
    """EfficientViT-M5 backbone with ``head`` replaced by ``nn.Identity``,
    optionally initialised from a GSViT checkpoint.

    An implementation based on the original paper and repo:
    https://github.com/SamuelSchmidgall/GSViT.git

    Args:
        pre_trained: Path of a checkpoint passed to ``torch.load``. A falsy
            value (e.g. ``""`` or ``None``) skips checkpoint loading.
        force_oo_model: When true, checkpoint keys whose first dotted
            component is an integer index (e.g. ``"0.xxx"``) are rewritten to
            the backbone's child-module names before loading
            (see ``_remap_sequential_keys``).
        device: Used as ``map_location`` for ``torch.load``; ``None`` means
            ``"cpu"``. NOTE(review): the module itself is not moved to this
            device here.
    """

    def __init__(self, pre_trained: str = "EfficientViT_GSViT.pth", force_oo_model=True, device=None):
        super().__init__()
        self._force = force_oo_model
        self.evit = EfficientViT_M5(pretrained="efficientvit_m5")
        # Replace the `head` submodule with a no-op so forward() does not apply it.
        self.evit.head = nn.Identity()
        self.device = device
        # Positional child index -> child attribute name, in registration order.
        # Built unconditionally so _remap_sequential_keys works on every path.
        self.index_to_name = {
            idx: name for idx, (name, _) in enumerate(self.evit.named_children())
        }
        if pre_trained:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            ckpt = torch.load(
                pre_trained,
                map_location="cpu" if self.device is None else self.device,
            )
            # Unwrap checkpoints stored as {"state_dict": ...}.
            if isinstance(ckpt, dict) and "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            if force_oo_model:
                ckpt = self._remap_sequential_keys(ckpt)
            # strict=False tolerates partial checkpoints; report any mismatch
            # instead of silently keeping the backbone's previous weights.
            result = self.evit.load_state_dict(ckpt, strict=False)
            self._warn_on_incompatible_keys(result, pre_trained)

    def _remap_sequential_keys(self, sd):
        """Rewrite positional state-dict keys to this backbone's child names.

        A leading ``"module."`` prefix is stripped from every key. Keys whose
        first dotted component is an integer ``i`` have that component
        replaced by ``self.index_to_name[i]``; such keys whose index is not
        in the mapping are dropped. All other keys are kept unchanged.

        Args:
            sd: Mapping of state-dict names to values.

        Returns:
            A new ``OrderedDict`` with the rewritten keys, in input order.
        """
        prefix = "module."
        out = OrderedDict()
        for k, v in sd.items():
            if k.startswith(prefix):
                k = k[len(prefix):]
            parts = k.split(".")
            # isdecimal (not isdigit): every decimal string is accepted by int().
            if parts[0].isdecimal():
                idx = int(parts[0])
                if idx not in self.index_to_name:
                    continue  # no matching child module; drop this entry
                parts[0] = self.index_to_name[idx]
                k = ".".join(parts)
            out[k] = v
        return out

    @staticmethod
    def _warn_on_incompatible_keys(result, source=None):
        """Emit a ``UserWarning`` if a non-strict load left keys unmatched.

        Args:
            result: Value returned by ``nn.Module.load_state_dict`` (exposes
                ``missing_keys`` and ``unexpected_keys``).
            source: Optional checkpoint identifier included in the message.
        """
        missing = list(result.missing_keys)
        unexpected = list(result.unexpected_keys)
        if not (missing or unexpected):
            return
        where = f" from {source!r}" if source is not None else ""
        warnings.warn(
            f"EfficientViT_GSViT: checkpoint{where} did not fully match the backbone "
            f"({len(missing)} missing, {len(unexpected)} unexpected keys); "
            f"missing[:5]={missing[:5]}, unexpected[:5]={unexpected[:5]}",
            stacklevel=3,
        )

    def forward(self, x):
        """Return ``self.evit(x)`` (backbone output with ``head`` as identity)."""
        return self.evit(x)