import os
import warnings
# Must be set before albumentations is imported anywhere in the process:
# presumably disables its online version/update check — TODO confirm.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"


import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import Loaders
import Models
import defs
import argparse


# Allow TF32/faster float32 matmuls on supporting GPUs, while forcing
# deterministic cuDNN kernels for reproducibility (re-asserted in setup_seed).
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.deterministic = True

# Standard ImageNet channel-wise (RGB) normalization statistics.
Mean = [0.485, 0.456, 0.406]
Std = [0.229, 0.224, 0.225]

def setup_seed(seed):
    """Seed all RNG sources (torch CPU/CUDA, numpy, Lightning) for a reproducible run.

    Args:
        seed: the integer seed applied to every library.
    """
    # Seed each library-level generator with the same value.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)
    # Let Lightning propagate the seed, including to dataloader workers.
    seed_everything(seed, workers=True)
    # Restrict cuDNN to deterministic kernels (also set at module level).
    torch.backends.cudnn.deterministic = True

if __name__ == "__main__":    
    
    parser = argparse.ArgumentParser()
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-w", "--Workers", type=int, default=0)
    parser.add_argument("-b", "--BatchSize", type=int, default=32)
    parser.add_argument("-m", "--Model", type=str, default="")
    parser.add_argument("-t", "--train_type", type=str, default="")
    
    args = parser.parse_args()
    
    setup_seed(2023)
    MaxEpochs = 150
        
    if args.maxEpochs is not None:
        MaxEpochs = args.maxEpochs
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batchSize = args.BatchSize #17 #8, 32
    numWorkers = args.Workers
    
    ckpLossBest = [
        callbk.ModelCheckpoint(monitor="val_loss", filename="MAE-{epoch}-{val_loss:.4f}", save_top_k=5, mode='min'),
        callbk.ModelCheckpoint(monitor="val_acc", filename="MAE-{epoch}-{val_acc:.4f}", save_top_k=5, mode='max'),
        callbk.EarlyStopping(monitor="val_loss", mode="min", patience=10)
    ]
    
    original_crop = torch.nn.Sequential(
        transforms.CenterCrop(300),
        transforms.Resize(224, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.Normalize(Mean, Std)
    ).to(device)
    
    match (args.train_type):
        case "MAE":
            Model = Models.RARP_MAE(backbone=args.Model, lr=0.0005 * (batchSize/256))
            TrainDINOTransforms = Loaders.RARP_MAE_Augmentation(
                Init_Resize=(300, 420),
                GloblaCropsScale = (0.4, 1),
                Size = 224, 
                device = device,
                mean = Mean,
                std = Std
            )
        case "DINO":
            Model = Models.RARP_Encoder_DINO(max_epochs=MaxEpochs, lr=0.0005 * (batchSize/256))
            TrainDINOTransforms = Loaders.RARP_DINO_AugmentationV2(
                Init_Resize=(300, 420),
                GloblaCropsScale = (0.4, 1),
                LocalCropsScale = (0.05, 0.4),
                NumLocalCrops = 6,
                Size = 224, 
                device = device,
                mean = Mean,
                std = Std,
                Tranform_0 = original_crop
            )
        case _:
            raise Exception("No implemented")
        
    
    trainDataset = torchvision.datasets.DatasetFolder(
        "./Dataset_video/",
        loader=defs.load_Img_RBG_tensor_norm,
        extensions="webp",
        transform=TrainDINOTransforms
    )
    
    valDataset = torchvision.datasets.DatasetFolder(
        "./Dataset_Video_Val/",
        loader=defs.load_Img_RBG_tensor_norm,
        extensions="webp",
        transform=original_crop
    )
    
    Train_DataLoader = DataLoader(
        trainDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=True, 
        drop_last=True,
        pin_memory=True,
        persistent_workers=numWorkers>0,
        #prefetch_factor=1
    )
    
    Val_DataLoader = DataLoader(
        valDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=False, 
        pin_memory=True,
        persistent_workers=numWorkers>0
    ) if args.train_type != "DINO" else None
    
    print(f"Model Used: {type(Model).__name__}")
    LogFileName = f"{args.Log_Name}" 
    print("Train Phase")        
    trainer = L.Trainer(
        deterministic=True,
        accelerator='gpu', 
        devices=1,
        logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
        callbacks=ckpLossBest,
        max_epochs=MaxEpochs,
        precision=16,
        log_every_n_steps = 100,
    )
    
    trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
    