# Newer
# Older
# RARP_server / EfficientViT / GSViT_RARP.py
from .classification.model.build import EfficientViT_M5
import torch
import torch.nn as nn
from collections import OrderedDict

class EfficientViT_GSViT(nn.Module):
    """EfficientViT-M5 wrapper with the classification head replaced by identity.

    An implementation based on the original paper and repository:
    https://github.com/SamuelSchmidgall/GSViT.git

    The wrapped model is created with
    ``EfficientViT_M5(pretrained="efficientvit_m5")``, its ``head`` attribute
    is replaced by ``nn.Identity()`` and, when ``pre_trained`` is truthy, the
    weights stored in that checkpoint file are loaded on top (non-strictly).
    """

    def _remap_sequential_keys(self, sd):
        """Rename index-prefixed checkpoint keys to child-module names.

        For every key in ``sd``:

        * a leading ``"module."`` prefix is stripped (presumably left by a
          parallel wrapper at save time -- confirm against the checkpoint);
        * if the first dotted component is an integer, it is replaced by
          ``self.index_to_name[index]``; keys whose index is not in that
          mapping are dropped;
        * every other key is kept unchanged.

        Args:
            sd: Mapping of parameter names to values (a state dict).

        Returns:
            A new ``OrderedDict`` with the renamed keys, in the iteration
            order of ``sd``. ``sd`` itself is not modified.
        """
        prefix = "module."
        index_to_name = self.index_to_name
        remapped = OrderedDict()
        for key, value in sd.items():
            if key.startswith(prefix):
                key = key[len(prefix):]
            head, sep, tail = key.partition(".")
            if head.isdigit():
                idx = int(head)
                if idx not in index_to_name:
                    # No child module at this position: drop the entry.
                    continue
                key = index_to_name[idx] + sep + tail
            remapped[key] = value
        return remapped

    def __init__(self, pre_trained: str = "EfficientViT_GSViT.pth", force_oo_model=True):
        """Build the backbone and optionally load checkpoint weights.

        Args:
            pre_trained: Path of a checkpoint readable by ``torch.load``.
                A falsy value skips checkpoint loading entirely.
            force_oo_model: When true, checkpoint keys are passed through
                ``_remap_sequential_keys`` before loading, so that keys
                starting with a child index are mapped to the child names of
                the wrapped model.
        """
        super().__init__()

        self._force = force_oo_model
        self.evit = EfficientViT_M5(pretrained="efficientvit_m5")
        self.evit.head = nn.Identity()

        # Built unconditionally so that _remap_sequential_keys is usable on
        # any instance, not only on those that loaded a remapped checkpoint.
        self.index_to_name = {
            idx: name for idx, (name, _) in enumerate(self.evit.named_children())
        }

        if not pre_trained:
            return

        # Local import: keeps the module's top-level import block untouched.
        import warnings

        # NOTE(security): torch.load unpickles the file; only pass trusted
        # checkpoints (consider ``weights_only=True`` where the installed
        # torch version supports it).
        ckpt = torch.load(pre_trained, map_location="cpu")
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        if force_oo_model:
            ckpt = self._remap_sequential_keys(ckpt)

        # strict=False tolerates mismatching keys; report them instead of
        # silently ending up with a partially (or not at all) loaded model.
        result = self.evit.load_state_dict(ckpt, strict=False)
        missing = list(result.missing_keys)
        unexpected = list(result.unexpected_keys)
        if missing or unexpected:
            warnings.warn(
                f"Checkpoint {pre_trained!r} did not match the model exactly: "
                f"{len(missing)} missing key(s) {missing[:5]}, "
                f"{len(unexpected)} unexpected key(s) {unexpected[:5]}.",
                stacklevel=2,
            )

    def forward(self, x):
        """Return the output of the wrapped model (head is identity) for ``x``."""
        return self.evit(x)