import os
import warnings
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import Loaders
import Models
import defs
import argparse
# Allow TF32 tensor cores for float32 matmuls on Ampere+ GPUs (speed over
# bit-exact float32 precision).
torch.set_float32_matmul_precision('high')
# Force deterministic cuDNN kernels for reproducibility (set again inside
# setup_seed; duplication is harmless).
torch.backends.cudnn.deterministic = True
# ImageNet per-channel RGB normalization statistics, used by all transforms below.
Mean = [0.485, 0.456, 0.406]
Std = [0.229, 0.224, 0.225]
def setup_seed(seed):
    """Make training reproducible by seeding every RNG source in use.

    Seeds numpy and torch (CPU plus all CUDA devices), then delegates to
    Lightning's ``seed_everything``, which additionally covers Python's
    ``random`` module and dataloader worker processes.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    seed_everything(seed, workers=True)
    # Prefer reproducible cuDNN kernels over autotuned (non-deterministic) ones.
    torch.backends.cudnn.deterministic = True
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
parser.add_argument("-me", "--maxEpochs", type=int, default=None)
parser.add_argument("-w", "--Workers", type=int, default=0)
parser.add_argument("-b", "--BatchSize", type=int, default=32)
parser.add_argument("-m", "--Model", type=str, default="")
parser.add_argument("-t", "--train_type", type=str, default="")
args = parser.parse_args()
setup_seed(2023)
MaxEpochs = 150
if args.maxEpochs is not None:
MaxEpochs = args.maxEpochs
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batchSize = args.BatchSize #17 #8, 32
numWorkers = args.Workers
ckpLossBest = [
callbk.ModelCheckpoint(monitor="val_loss", filename="MAE-{epoch}-{val_loss:.4f}", save_top_k=5, mode='min'),
callbk.ModelCheckpoint(monitor="val_acc", filename="MAE-{epoch}-{val_acc:.4f}", save_top_k=5, mode='max'),
callbk.EarlyStopping(monitor="val_loss", mode="min", patience=10)
]
original_crop = torch.nn.Sequential(
transforms.CenterCrop(300),
transforms.Resize(224, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Normalize(Mean, Std)
).to(device)
match (args.train_type):
case "MAE":
Model = Models.RARP_MAE(backbone=args.Model, lr=0.0005 * (batchSize/256))
TrainDINOTransforms = Loaders.RARP_MAE_Augmentation(
Init_Resize=(300, 420),
GloblaCropsScale = (0.4, 1),
Size = 224,
device = device,
mean = Mean,
std = Std
)
case "DINO":
Model = Models.RARP_Encoder_DINO(max_epochs=MaxEpochs, lr=0.0005 * (batchSize/256))
TrainDINOTransforms = Loaders.RARP_DINO_AugmentationV2(
Init_Resize=(300, 420),
GloblaCropsScale = (0.4, 1),
LocalCropsScale = (0.05, 0.4),
NumLocalCrops = 6,
Size = 224,
device = device,
mean = Mean,
std = Std,
Tranform_0 = original_crop
)
case _:
raise Exception("No implemented")
trainDataset = torchvision.datasets.DatasetFolder(
"./Dataset_video/",
loader=defs.load_Img_RBG_tensor_norm,
extensions="webp",
transform=TrainDINOTransforms
)
valDataset = torchvision.datasets.DatasetFolder(
"./Dataset_Video_Val/",
loader=defs.load_Img_RBG_tensor_norm,
extensions="webp",
transform=original_crop
)
Train_DataLoader = DataLoader(
trainDataset,
batch_size=batchSize,
num_workers=numWorkers,
shuffle=True,
drop_last=True,
pin_memory=True,
persistent_workers=numWorkers>0,
#prefetch_factor=1
)
Val_DataLoader = DataLoader(
valDataset,
batch_size=batchSize,
num_workers=numWorkers,
shuffle=False,
pin_memory=True,
persistent_workers=numWorkers>0
) if args.train_type != "DINO" else None
print(f"Model Used: {type(Model).__name__}")
LogFileName = f"{args.Log_Name}"
print("Train Phase")
trainer = L.Trainer(
deterministic=True,
accelerator='gpu',
devices=1,
logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
callbacks=ckpLossBest,
max_epochs=MaxEpochs,
precision=16,
log_every_n_steps = 100,
)
trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)