import random

import lightning.pytorch as pl
import numpy as np
import optuna
import torch
import torchvision
from optuna.integration import PyTorchLightningPruningCallback
from packaging import version
from torch.utils.data import DataLoader
from torchvision import transforms

import defs
import Loaders
import Models as M

if version.parse(pl.__version__) < version.parse("1.6.0"):
    raise RuntimeError("PyTorch Lightning>=1.6.0 is required for this example.")
else:
    print("OK")
    
    
def setup_seed(seed):
    """Seed every relevant RNG so runs are reproducible.

    Covers torch (CPU and all CUDA devices), NumPy, and Python's built-in
    ``random`` module, and forces cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all random number generators.
    """
    random.seed(seed)  # fix: Python's own RNG was previously left unseeded
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # fix: benchmark autotuning may select non-deterministic kernels, which
    # defeats the deterministic flag above — disable it explicitly.
    torch.backends.cudnn.benchmark = False

# Build the 5-fold cross-validation split of the cropped RARP dataset.
Dataset = Loaders.RARP_DatasetCreator(
    "./DataSet_Ando_All_Crop",
    FoldSeed=505,
    createFile=True,
    SavePath="./DataSet_Ando_Crop",
    Fold=5,
    removeBlackBar=False,
)
Dataset.CreateFolds()

# Index of the fold this run trains/evaluates on.
Fold = 0

# Per-channel normalization statistics consumed by the Normalize transforms
# below (presumably precomputed over the training images — TODO confirm).
Dataset.mean = [30.38144216, 42.03988769, 97.8896116]
Dataset.std = [40.63141752, 44.26910074, 50.29294373]

setup_seed(2023)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Images are first resized to InitResize, then (for training) random-cropped
# down to ImgResize.
InitResize = (256, 256)
ImgResize = (224, 224)

# Root directory holding the train/val/test subfolders of the chosen fold.
rootFile = Dataset.CVS_File.parent.parent / f"fold_{Fold}"
print(rootFile)

# GPU-resident transform pipelines (nn.Sequential of scriptable transforms).
# Normalize is stateless, so a single instance is shared by all three.
_normalize = transforms.Normalize(Dataset.mean, Dataset.std)

# Training: resize, random crop, light affine jitter, horizontal flip.
traintransform = torch.nn.Sequential(
    transforms.Resize(InitResize, antialias=True),
    transforms.RandomCrop(ImgResize),
    transforms.RandomAffine(degrees=(-5, 5), scale=(0.9, 1.1), fill=5),
    transforms.RandomHorizontalFlip(0.9),
    _normalize,
).to(device)

# Validation and test: deterministic resize + normalize only.
valtransform = torch.nn.Sequential(
    transforms.Resize(ImgResize, antialias=True),
    _normalize,
).to(device)

testtransform = torch.nn.Sequential(
    transforms.Resize(ImgResize, antialias=True),
    _normalize,
).to(device)

def _npy_folder(split, tfm):
    """Build a DatasetFolder over the ``.npy`` files under rootFile/<split>."""
    return torchvision.datasets.DatasetFolder(
        str(rootFile / split),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=tfm,
    )


trainDataset = _npy_folder("train", traintransform)
valDataset = _npy_folder("val", valtransform)
testDataset = _npy_folder("test", testtransform)


# NOTE(review): currently not referenced by the Trainer below — presumably
# intended for limit_val_batches; confirm before wiring it in.
PERCENT_VALID_EXAMPLES = 0.1
EPOCHS = 50

def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective: train one model with sampled hyperparameters.

    Samples learning rate, L1 weight, head depth/widths, dropout and batch
    size, trains ``RARP_NVB_ResNet50_Deep_OPTuna`` for ``EPOCHS`` epochs
    with pruning on ``val_acc``, and reports the final validation accuracy.

    Args:
        trial: Optuna trial used both to sample hyperparameters and to
            prune unpromising runs via the Lightning callback.

    Returns:
        The last ``val_acc`` value logged during training.
    """
    # fix: suggest_loguniform is deprecated — use suggest_float(log=True).
    lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    l1 = trial.suggest_float("L1", 1e-5, 1e-3, log=True)
    n_layers = trial.suggest_int("n_layers", 1, 4)
    dropout = trial.suggest_float("dropout", 0.2, 0.5)
    output_Dims = [
        trial.suggest_int("n_layers_l{}".format(i), 8, 512, log=True)
        for i in range(n_layers)
    ]

    batchSize = trial.suggest_categorical("batch_size", [8, 16, 32])
    numWorkers = 8

    model = M.RARP_NVB_ResNet50_Deep_OPTuna(
        None,
        M.TypeLossFunction.BCEWithLogits,
        output_Dims,
        dropout,
        config={"lr": lr, "L1": l1},
    )

    Train_DataLoader = DataLoader(
        trainDataset,
        batch_size=batchSize,
        num_workers=numWorkers,
        shuffle=True,
        pin_memory=True,
        persistent_workers=True,
    )

    Val_DataLoader = DataLoader(
        valDataset,
        batch_size=batchSize,
        num_workers=numWorkers,
        shuffle=False,
        pin_memory=True,
        persistent_workers=True,
    )

    # fix: the test DataLoader built here previously was never used, yet it
    # spawned 8 persistent worker processes per trial — removed.

    trainer = pl.Trainer(
        logger=True,
        enable_checkpointing=False,
        max_epochs=EPOCHS,
        accelerator="auto",
        devices=1,
        # NOTE(review): PERCENT_VALID_EXAMPLES looks intended for
        # limit_val_batches here — confirm before enabling.
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
    )

    hyperparameters = dict(
        lr=lr,
        l1=l1,
        n_layers=n_layers,
        dropout=dropout,
        output_dims=output_Dims,
        batch_size=batchSize,
    )

    trainer.logger.log_hyperparams(hyperparameters)

    trainer.fit(model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)

    return trainer.callback_metrics["val_acc"].item()


if __name__ == "__main__":
    # Prune unpromising trials at the running median of intermediate scores.
    use_pruning = True
    pruner = (
        optuna.pruners.MedianPruner() if use_pruning else optuna.pruners.NopPruner()
    )

    # Maximize val_acc; stop after 100 trials or 600 seconds, whichever first.
    study = optuna.create_study(direction="maximize", pruner=pruner)
    study.optimize(objective, n_trials=100, timeout=600)

    print("Number of finished trials: {}".format(len(study.trials)))
    print("Best trial:")

    best = study.best_trial
    print("  Value: {}".format(best.value))
    print("  Params: ")
    for key, value in best.params.items():
        print("    {}: {}".format(key, value))