import time
import torch
def str2bool(s):
    """Interpret a command-line style string as a boolean.

    Matching is case-insensitive. Only "true" and "1" count as True; every
    other string (including "yes", "on", or strings with surrounding
    whitespace) yields False.
    """
    normalized = s.lower()
    return normalized == "true" or normalized == "1"
class Timer:
    """Collection of named stopwatches.

    Call ``start(key)`` to begin timing and ``end(key)`` to get the elapsed
    seconds since that start. Each key can be ended once per start.
    """

    def __init__(self):
        # Maps timer key -> start timestamp taken from time.perf_counter().
        self.clock = {}

    def start(self, key="default"):
        """Start (or restart) the stopwatch named *key*."""
        # perf_counter is monotonic and high-resolution, so intervals are not
        # affected by wall-clock adjustments (unlike time.time()).
        self.clock[key] = time.perf_counter()

    def end(self, key="default"):
        """Stop the stopwatch named *key* and return the elapsed seconds.

        Returns:
            Elapsed time in seconds (float) since the matching ``start(key)``.

        Raises:
            Exception: if *key* was never started or has already been ended.
        """
        # Read the clock first so the lookup itself is not counted.
        now = time.perf_counter()
        try:
            started = self.clock.pop(key)
        except KeyError:
            raise Exception(f"{key} is not in the clock.") from None
        return now - started
def save_checkpoint(epoch, net_state_dict, optimizer_state_dict, best_score, checkpoint_path, model_path):
    """Persist training state to two files with torch.save.

    Args:
        epoch: Epoch number to record in the checkpoint.
        net_state_dict: Model state to store (written to both files).
        optimizer_state_dict: Optimizer state to store in the checkpoint.
        best_score: Best score so far, recorded in the checkpoint.
        checkpoint_path: Destination for the full checkpoint dict with keys
            "epoch", "model", "optimizer" and "best_score".
        model_path: Destination for the model state alone.
    """
    checkpoint = dict(
        epoch=epoch,
        model=net_state_dict,
        optimizer=optimizer_state_dict,
        best_score=best_score,
    )
    torch.save(checkpoint, checkpoint_path)
    # Also store the bare model state so it can be loaded without the
    # optimizer/bookkeeping entries.
    torch.save(net_state_dict, model_path)
def load_checkpoint(checkpoint_path, map_location=None):
    """Load an object previously written with torch.save.

    Args:
        checkpoint_path: Path (or file-like object) accepted by torch.load.
        map_location: Optional device remapping forwarded to torch.load,
            e.g. "cpu" to load a checkpoint saved on GPU onto a CPU-only
            machine. The default None keeps torch.load's default behavior.

    Returns:
        The stored object, e.g. the dict written by save_checkpoint.
    """
    # NOTE(security): torch.load deserializes pickled data; only load
    # checkpoints from trusted sources.
    return torch.load(checkpoint_path, map_location=map_location)
def freeze_net_layers(net):
    """Disable gradient tracking for every parameter yielded by ``net.parameters()``.

    The parameters are modified in place; nothing is returned.
    """
    for weight in net.parameters():
        weight.requires_grad_(False)
def store_labels(path, labels):
    """Write *labels* to *path*, one per line, UTF-8 encoded.

    Lines are joined with "\\n" and no trailing newline is appended.
    An empty *labels* produces an empty file.

    Args:
        path: Destination file path (overwritten if it exists).
        labels: Iterable of strings.
    """
    # Explicit encoding: the platform default may not be UTF-8 and would
    # fail on (or mis-encode) non-ASCII label names.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(labels))