import os
import warnings
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
warnings.simplefilter("ignore")
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import Loaders
import Models
import defs
import argparse
# Allow TF32 tensor cores for float32 matmuls (faster on Ampere+ GPUs,
# slightly reduced precision) and pin cuDNN to deterministic kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.deterministic = True
# ImageNet channel statistics used by Normalize in the transforms below.
Mean = [0.485, 0.456, 0.406]
Std = [0.229, 0.224, 0.225]
def setup_seed(seed):
    """Seed every RNG in use (NumPy, torch CPU/CUDA, Lightning) for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Lightning's helper also seeds python's `random` and dataloader workers.
    seed_everything(seed, workers=True)
    # Pin cuDNN to deterministic kernels (repeated here in case the module-level
    # flag was toggled elsewhere before this call).
    torch.backends.cudnn.deterministic = True
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--Log_Name", type=str, default="logs_debug",
                        help="the name of the directory of the log chkp")
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-w", "--Workers", type=int, default=0)
    parser.add_argument("-b", "--BatchSize", type=int, default=16)
    args = parser.parse_args()

    setup_seed(2023)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batchSize = args.BatchSize
    numWorkers = args.Workers

    # Keep the 5 best checkpoints by training loss (min) and by validation
    # accuracy (max); filenames embed epoch and the monitored metric.
    ckpLossBest = [
        callbk.ModelCheckpoint(
            monitor="train_loss",
            filename="DINO-{epoch}-{train_loss:.4f}",
            save_top_k=5,
            mode='min',
        ),
        callbk.ModelCheckpoint(
            monitor="val_acc",
            filename="DINO_S-{epoch}-{val_acc:.4f}",
            save_top_k=5,
            mode='max',
        ),
    ]

    # Deterministic eval-style preprocessing: center crop -> bicubic resize to
    # 224 -> ImageNet normalization. Runs on `device` as a single nn.Sequential.
    default_transform = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize(224, transforms.InterpolationMode.BICUBIC),
        transforms.Normalize(Mean, Std),
    ).to(device)

    # DINO multi-crop augmentation: 2 global crops (scale 0.4-1.0) plus 6 local
    # crops (scale 0.05-0.4), all at 224px. `Tranform_0` supplies the plain
    # (un-augmented) view. NOTE(review): keyword spellings `GloblaCropsScale`
    # and `Tranform_0` match the Loaders API as-is — renaming would break it.
    TrainDINOTransforms = Loaders.RARP_DINO_Augmentation(
        Init_Resize=(300, 420),
        GloblaCropsScale=(0.4, 1),
        LocalCropsScale=(0.05, 0.4),
        NumLocalCrops=6,
        Size=224,
        device=device,
        mean=Mean,
        std=Std,
        Tranform_0=default_transform,
    )

    trainDataset = torchvision.datasets.DatasetFolder(
        "./Dataset_video/",
        loader=defs.load_Img_RBG_tensor_norm,
        extensions="webp",
        transform=TrainDINOTransforms,
    )
    valDataset = torchvision.datasets.DatasetFolder(
        "./Dataset_video_Val/",
        loader=defs.load_Img_RBG_tensor_norm,
        extensions="webp",
        transform=default_transform,
    )

    # BUG FIX: `prefetch_factor` may only be specified when num_workers > 0;
    # with the default --Workers 0 the original call raised ValueError.
    Train_DataLoader = DataLoader(
        trainDataset,
        batch_size=batchSize,
        num_workers=numWorkers,
        shuffle=True,
        drop_last=True,
        pin_memory=False,
        persistent_workers=numWorkers > 0,
        prefetch_factor=1 if numWorkers > 0 else None,
    )
    Val_DataLoader = DataLoader(
        valDataset,
        batch_size=batchSize,
        num_workers=numWorkers,
        shuffle=False,
        pin_memory=True,
        persistent_workers=numWorkers > 0,
    )

    # --maxEpochs overrides the default of 150 when provided.
    MaxEpochs = args.maxEpochs if args.maxEpochs is not None else 150

    # DINO encoder with auxiliary classification head; lr follows the standard
    # linear scaling rule (base 5e-4 per 256 samples).
    Model = Models.RARP_Encoder_DINO_AUX_task(
        lr=0.0005 * (batchSize / 256),
        aux_lambda=0.4,
    )
    print(f"Model Used: {type(Model).__name__}")

    LogFileName = f"{args.Log_Name}"
    print("Train Phase")
    trainer = L.Trainer(
        deterministic=True,
        accelerator='gpu',
        devices=2,               # 2-GPU data-parallel training
        strategy="ddp",
        logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
        callbacks=ckpLossBest,
        max_epochs=MaxEpochs,
        log_every_n_steps=100,
        precision=16,            # mixed precision; Lightning 2.x maps this to "16-mixed"
    )
    trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)