Newer
Older
RARP / defs.py
@delAguila delAguila on 22 Nov 2024 832 bytes initial commit
import numpy as np
import torch
import cv2

def load_file(path):
    """Read an array saved with ``np.save`` and return it as float64.

    Parameters
    ----------
    path : str or path-like
        Location of the ``.npy`` file.

    Returns
    -------
    np.ndarray
        A new array with the file's contents converted to ``float``.
    """
    raw = np.load(path)
    as_float = raw.astype(float)
    return as_float

def load_file_tensor(path):
    """Load a 3-D ``.npy`` array as a float64 torch tensor with its last axis first.

    The axes are permuted ``(A, B, C) -> (C, A, B)``. This is presumably an
    HWC -> CHW image layout change (TODO confirm against the data producer).
    The returned tensor shares memory with the intermediate NumPy array and
    is a strided (non-contiguous) view.

    Parameters
    ----------
    path : str or path-like
        Location of the ``.npy`` file; its array must have exactly 3 dims.
    """
    loaded = np.load(path).astype(float)
    channels_first = np.transpose(loaded, (2, 0, 1))
    return torch.from_numpy(channels_first)

def load_file_tensor_norm(path):
    """Load a 3-D ``.npy`` array as a float64 tensor (last axis first), divided by 255.

    Same layout handling as ``load_file_tensor``: axes are permuted
    ``(A, B, C) -> (C, A, B)``. The division by 255.0 presumably maps 8-bit
    pixel values into ``[0, 1]`` (assumption; the code does not check the
    input range).

    Parameters
    ----------
    path : str or path-like
        Location of the ``.npy`` file; its array must have exactly 3 dims.
    """
    as_float = np.load(path).astype(float)
    tensor = torch.from_numpy(as_float.transpose((2, 0, 1)))
    scaled = tensor / 255.0
    return scaled

def load_Img(path):
    """Read an image from disk with OpenCV in colour mode.

    Parameters
    ----------
    path : str or path-like
        Image file location; converted with ``str()`` because ``cv2.imread``
        expects a string path.

    Returns
    -------
    np.ndarray or None
        The decoded image as returned by ``cv2.imread`` with
        ``cv2.IMREAD_COLOR``, or ``None`` if OpenCV could not read or decode
        the file (``cv2.imread`` does not raise in that case).
    """
    # The original version dropped this value and always returned None.
    return cv2.imread(str(path), cv2.IMREAD_COLOR)
    
def clip_gradients(model, clip=2.0, norm_type=2.0, eps=1e-6):
    """Clip each parameter's gradient independently to a maximum norm.

    Unlike a global-norm clip, which rescales all gradients by one shared
    factor, this inspects every parameter on its own. Any gradient whose
    ``norm_type``-norm exceeds ``clip`` is scaled down in place so its norm
    becomes approximately ``clip``. Gradients already within the limit are
    left untouched.

    Parameters
    ----------
    model : nn.Module
        Module whose parameters' ``.grad`` tensors are modified in place.
        Parameters with no gradient (``grad is None``) are skipped.
    clip : float
        Maximum allowed norm for each parameter's gradient.
    norm_type : float
        Order of the norm used to measure each gradient (default 2,
        the Euclidean norm).
    eps : float
        Small constant added to the norm to avoid division by zero.
    """
    # Work under no_grad instead of touching the discouraged ``.data``
    # attribute; the in-place rescale must not be recorded by autograd.
    with torch.no_grad():
        for p in model.parameters():
            grad = p.grad
            if grad is None:
                continue
            grad_norm = grad.norm(norm_type)
            clip_coef = clip / (grad_norm + eps)
            # Only ever shrink a gradient; never scale a small one up.
            if clip_coef < 1:
                grad.mul_(clip_coef)