# RARP_server / RARP_Encoder_train.py
import os
import warnings
# Must be set before albumentations is imported anywhere: suppresses its
# "new version available" update check.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
# NOTE(review): this silences ALL warnings process-wide, not just
# albumentations' — deprecation warnings from torch/lightning are hidden too.
warnings.simplefilter("ignore")

import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import Loaders
import Models
import defs
import argparse

# Allow TF32 tensor cores for matmul (Ampere+ GPUs): faster at slightly
# reduced float32 precision.
torch.backends.cuda.matmul.allow_tf32 = True  
torch.set_float32_matmul_precision('high')
# Force deterministic cuDNN kernels for reproducibility (also re-set in setup_seed).
torch.backends.cudnn.deterministic = True

# ImageNet channel mean/std used by transforms.Normalize below.
Mean = [0.485, 0.456, 0.406]
Std = [0.229, 0.224, 0.225]

def setup_seed(seed):
    """Seed every RNG source (numpy, torch CPU/CUDA, Lightning) for a reproducible run."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Lightning's seeding also covers dataloader worker processes.
    seed_everything(seed, workers=True)
    # Pin cuDNN to deterministic kernel selection.
    torch.backends.cudnn.deterministic = True

if __name__ == "__main__":    
    
    parser = argparse.ArgumentParser()
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-w", "--Workers", type=int, default=0)
    parser.add_argument("-b", "--BatchSize", type=int, default=16)
    
    args = parser.parse_args()
    
    setup_seed(2023)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batchSize = args.BatchSize #17 #8, 32
    numWorkers = args.Workers
    
    ckpLossBest = [
        callbk.ModelCheckpoint(monitor="train_loss", filename="DINO-{epoch}-{train_loss:.4f}", save_top_k=5, mode='min'),
        #callbk.ModelCheckpoint(monitor="val_silhouette_teacher", filename="DINO_T-{epoch}-{val_silhouette_teacher:.4f}", save_top_k=5, mode='max'),
        #callbk.ModelCheckpoint(monitor="val_silhouette_student", filename="DINO_S-{epoch}-{val_silhouette_student:.4f}", save_top_k=5, mode='max'),
        callbk.ModelCheckpoint(monitor="val_acc", filename="DINO_S-{epoch}-{val_acc:.4f}", save_top_k=5, mode='max'),
    ]
    
    default_transform = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize(224, transforms.InterpolationMode.BICUBIC),
        transforms.Normalize(Mean, Std)
    ).to(device)
    
    TrainDINOTransforms = Loaders.RARP_DINO_Augmentation(
        Init_Resize=(300, 420),
        GloblaCropsScale = (0.4, 1),
        LocalCropsScale = (0.05, 0.4),
        NumLocalCrops = 6,
        Size = 224, 
        device = device,
        mean = Mean,
        std = Std,
        Tranform_0 = default_transform
    )
    
    trainDataset = torchvision.datasets.DatasetFolder(
        "./Dataset_video/",
        loader=defs.load_Img_RBG_tensor_norm,
        extensions="webp",
        transform=TrainDINOTransforms
    )
    
    valDataset = torchvision.datasets.DatasetFolder(
        "./Dataset_video_Val/",
        loader=defs.load_Img_RBG_tensor_norm,
        extensions="webp",
        transform=default_transform
    )
    
    Train_DataLoader = DataLoader(
        trainDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=True, 
        drop_last=True,
        pin_memory=False,
        persistent_workers=numWorkers>0,
        prefetch_factor=1
    )
    
    Val_DataLoader = DataLoader(
        valDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=False, 
        pin_memory=True,
        persistent_workers=numWorkers>0
    )
    
    MaxEpochs = 150
        
    if args.maxEpochs is not None:
        MaxEpochs = args.maxEpochs
    
    #Model = Models.RARP_Encoder_DINO(max_epochs=MaxEpochs, lr=0.0005 * (batchSize/256), total_steps=MaxEpochs*2966 )
    Model = Models.RARP_Encoder_DINO_AUX_task(lr=0.0005 * (batchSize/256), aux_lambda=0.4)
    
    print(f"Model Used: {type(Model).__name__}")
    LogFileName = f"{args.Log_Name}" 
    print("Train Phase")        
    trainer = L.Trainer(
        deterministic=True,
        accelerator='gpu', 
        devices=2,
        strategy="ddp",
        logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
        callbacks=ckpLossBest,
        max_epochs=MaxEpochs,
        log_every_n_steps = 100,
        precision=16,
        #gradient_clip_val=0.3,
        #gradient_clip_algorithm="norm" 
    )
    
    trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)