Newer
Older
Demo-Maker / modules / PytorchSSD / utils / misc.py
@mikado-4410 mikado-4410 on 10 Oct 2024 1 KB 最初のコミット
import time
import torch


def str2bool(s):
    """Interpret a command-line style flag value as a boolean.

    String comparison is case-insensitive: only ``"true"`` and ``"1"`` map to
    ``True``; every other string (including ``"yes"``) maps to ``False``.

    Args:
        s: The value to interpret. A ``bool`` is returned unchanged, so the
            function can be safely re-applied to an already-converted value.

    Returns:
        bool: The interpreted flag.
    """
    # A bool has no .lower(); previously this raised AttributeError.
    if isinstance(s, bool):
        return s
    return s.lower() in ("true", "1")


class Timer:
    """Measure elapsed time for any number of named, independent intervals.

    ``start(key)`` records a reference point and ``end(key)`` returns the
    seconds elapsed since then, forgetting the key. Each ``start`` therefore
    pairs with exactly one ``end``.
    """

    def __init__(self):
        # Maps interval name -> time.perf_counter() reading taken in start().
        self.clock = {}

    def start(self, key="default"):
        """Begin (or restart) timing the interval named ``key``."""
        # perf_counter is monotonic and high-resolution. time.time() follows the
        # wall clock, which can be stepped (NTP, manual changes) and yield
        # negative or wrong intervals.
        self.clock[key] = time.perf_counter()

    def end(self, key="default"):
        """Stop timing ``key`` and return the elapsed seconds as a float.

        Raises:
            Exception: if ``start(key)`` was never called, or ``key`` was
                already ended.
        """
        # Read the clock first so the lookup below is not counted.
        now = time.perf_counter()
        try:
            started = self.clock.pop(key)
        except KeyError:
            raise Exception(f"{key} is not in the clock.") from None
        return now - started


def save_checkpoint(epoch, net_state_dict, optimizer_state_dict, best_score, checkpoint_path, model_path):
    """Write a resumable training checkpoint plus a standalone weights file.

    Args:
        epoch: Epoch value stored under the ``"epoch"`` key.
        net_state_dict: Model state; stored under ``"model"`` in the
            checkpoint and also saved by itself to ``model_path``.
        optimizer_state_dict: Optimizer state stored under ``"optimizer"``.
        best_score: Value stored under ``"best_score"``.
        checkpoint_path: Destination passed to ``torch.save`` for the
            combined dict.
        model_path: Destination passed to ``torch.save`` for the bare
            model state.
    """
    resume_state = {
        "epoch": epoch,
        "model": net_state_dict,
        "optimizer": optimizer_state_dict,
        "best_score": best_score,
    }
    torch.save(resume_state, checkpoint_path)
    # The model state is also written on its own, presumably so code that only
    # needs weights can skip the optimizer state — confirm with callers.
    torch.save(net_state_dict, model_path)


def load_checkpoint(checkpoint_path, map_location=None):
    """Load an object previously written with ``torch.save``.

    Args:
        checkpoint_path: Path (or file-like object) to read.
        map_location: Forwarded to ``torch.load``. ``None`` (the default, and
            torch's own default) keeps the previous behavior of restoring
            tensors to the device they were saved from. Pass e.g. ``"cpu"`` to
            load a GPU-saved checkpoint on a machine without CUDA.

    Returns:
        The deserialized object, exactly as ``torch.load`` returns it.

    Note:
        ``torch.load`` is pickle-based. Depending on the installed torch
        version and its ``weights_only`` default, it may execute arbitrary code
        from the file, so only load checkpoints from trusted sources.
    """
    return torch.load(checkpoint_path, map_location=map_location)


def freeze_net_layers(net):
    """Turn off gradient tracking for every parameter yielded by ``net``.

    Args:
        net: Any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
            Its parameters are modified in place; nothing is returned.
    """
    for weight in net.parameters():
        weight.requires_grad_(False)


def store_labels(path, labels):
    """Write ``labels`` to ``path`` as UTF-8 text, one label per line.

    Labels are newline-separated with no trailing newline, and an empty
    ``labels`` produces an empty file. Any existing file is overwritten.

    Args:
        path: Destination file path.
        labels: Iterable of strings to write.
    """
    # Explicit encoding: open()'s default is platform-dependent (e.g. cp932 or
    # cp1252 on Windows), which can fail on or corrupt non-ASCII label names.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(labels))