import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import Loaders
import torchmetrics
import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
import Models as M
from pathlib import Path
import numpy as np
from tqdm import tqdm
import argparse


# Ask cuDNN for deterministic kernels so repeated runs are comparable.
torch.backends.cudnn.deterministic = True

# Allow TF32 in CUDA float32 matmuls and pick the matching "high"
# float32 matmul precision preset.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

def setup_seed(seed):
    """Seed all random number generators used by the script.

    Applies ``seed`` to torch (CPU and every CUDA device), NumPy and, through
    Lightning's ``seed_everything`` (called with ``workers=True``), the
    remaining generators Lightning manages. cuDNN is also switched to
    deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    explicit_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in explicit_seeders:
        apply_seed(seed)

    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
    
def _binary_metrics_at(predictions, labels, threshold, device):
    """Return (accuracy, precision, recall, F1, specificity) at ``threshold``."""
    metric_classes = (
        torchmetrics.Accuracy,
        torchmetrics.Precision,
        torchmetrics.Recall,
        torchmetrics.F1Score,
        torchmetrics.Specificity,
    )
    return tuple(
        metric_cls("binary", threshold=threshold).to(device)(predictions, labels)
        for metric_cls in metric_classes
    )


def _format_eval_row(threshold, acc, precision, recall, f1, auc, specificity, name):
    """Format one result row as 4-decimal strings followed by the model name."""
    metric_values = (acc, precision, recall, f1, auc, specificity)
    return [
        f"{threshold:.4f}",
        *(f"{value.item():.4f}" for value in metric_values),
        name,
    ]


def Calc_Eval_table(
    TrainModel:M.RARP_NVB_Model,
    TestDataLoadre:DataLoader,
    Youden=False,
    modelName="",
    device=None,
):
    """Evaluate a binary classifier on a data loader and tabulate its metrics.

    The model is run over every batch without gradients; the first element of
    its output is flattened, passed through a sigmoid and compared with the
    labels.

    Args:
        TrainModel: Model to evaluate. Its forward output must be unpackable,
            with the logits as first element.
        TestDataLoadre: Loader yielding ``(data, label)`` batches.
            Labels are assumed to line up with the flattened logits
            (one label per sample) -- TODO confirm against the dataset.
        Youden: When true, append two extra rows computed at alternative
            decision thresholds (see below).
        modelName: Text stored in the last column of the extra rows.
        device: Device used for the model, the batches and the metrics.
            When ``None`` the module-level ``device`` set by the script entry
            point is used if it exists, otherwise the device of the model's
            first parameter (CPU for a parameter-less model).

    Returns:
        A list of rows, each a list of strings ordered as: threshold,
        accuracy, precision, recall, F1, AUROC, specificity, model name.
        The first row uses the 0.5 threshold and an empty model name. With
        ``Youden`` two more rows follow: one at the ROC threshold maximising
        ``tpr - fpr`` and one at the fallback threshold described in the code.

    Raises:
        RuntimeError: If the loader yields no batches.
    """
    if device is None:
        # Backward compatibility: earlier callers relied on a module-level
        # ``device`` that only exists when the file is executed as a script.
        device = globals().get("device")
    if device is None:
        first_param = next(TrainModel.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    TrainModel.to(device)
    TrainModel.eval()

    Predictions = []
    Labels = []

    with torch.no_grad():
        # Hand tqdm the loader itself (not ``iter(loader)``) so it can read
        # the number of batches and display a total.
        for data, label in tqdm(TestDataLoadre):
            data = data.float().to(device)
            label = label.to(device)

            pred, *_ = TrainModel(data)

            Predictions.append(torch.sigmoid(pred.flatten()))
            Labels.append(label)

    if not Predictions:
        raise RuntimeError("Calc_Eval_table: the data loader yielded no batches to evaluate")

    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels)

    # AUROC does not depend on the decision threshold, so it is computed once
    # and repeated in every row.
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    acc, precision, recall, f1Score, specificty = _binary_metrics_at(
        Predictions, Labels, 0.5, device
    )

    table = [
        _format_eval_row(0.5, acc, precision, recall, f1Score, auc, specificty, "")
    ]

    if Youden:
        # The ROC curve and both candidate thresholds depend only on the
        # predictions, so they are computed once.
        fpr, tpr, thhols = torchmetrics.ROC("binary").to(device)(Predictions, Labels)
        youden_threshold = thhols[torch.argmax(tpr - fpr)].item()

        # NOTE(review): the second threshold is the *value* of Youden's J
        # statistic (recall + specificity - 1) measured at 0.5, reused as a
        # decision threshold (0.5 when J <= 0). Kept as in the original
        # implementation -- confirm this is intended.
        j_at_half = (recall + specificty - 1).item()
        fallback_threshold = 0.5 if j_at_half <= 0 else j_at_half

        for th in (youden_threshold, fallback_threshold):
            accY, precisionY, recallY, f1ScoreY, specifictyY = _binary_metrics_at(
                Predictions, Labels, th, device
            )
            table.append(
                _format_eval_row(th, accY, precisionY, recallY, f1ScoreY, auc, specifictyY, modelName)
            )

    return table
        
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
    parser.add_argument("--Fold", type=int, default=0) 
    parser.add_argument("-lv","--Log_version", type=int, default=None)
    parser.add_argument("--Workers", type=int, default=0)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("--Head", type=int, default=None)
    parser.add_argument("-me", "--maxEpochs", type=int, default=None)
    parser.add_argument("-b", "--Batch_size", type=int, default=8)
    parser.add_argument("--GPU", type=int, default=0)
    
    
    args = parser.parse_args()
    
    setup_seed(2023)
    device = torch.device(f"cuda:{args.GPU}" if torch.cuda.is_available() else "cpu")

    df = pd.read_csv("./Dataset_RARP_video/dataset_videos_frames_folds.csv")
    df = df.loc[df["type"] == "f"]

    FOLD = args.Fold
    WORKERS = args.Workers
    BATCH_SIZE = args.Batch_size
    MAX_EPOCHS = 50 if args.maxEpochs is None else args.maxEpochs
    
    Mean = [30.38144216, 42.03988769, 97.8896116]
    Std = [40.63141752, 44.26910074, 50.29294373]
        
    print(f"Fold_{FOLD}")

    train_set = df.loc[df[f"Fold_{FOLD}"] == "train"].sort_values(by=["label", "case"]).to_dict(orient="records")
    val_set = df.loc[df[f"Fold_{FOLD}"] == "val"].sort_values(by=["label", "case"]).to_dict(orient="records")
    test_set = df.loc[df[f"Fold_{FOLD}"] == "test"].sort_values(by=["label", "case"]).to_dict(orient="records")
    
    valtransform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Mean, Std)
    ).to(device)

    testtransform =  torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Mean, Std)
    ).to(device) 
    
    TrainDINOTransforms = Loaders.RARP_DINO_Augmentation(
        GloblaCropsScale = (0.4, 1),
        LocalCropsScale = (0.05, 0.4),
        NumLocalCrops = 4,
        Size = 224, 
        device = device,
        mean = Mean,
        std = Std,
        Tranform_0 = testtransform
    )

    train_dataset = Loaders.RARP_Video_Frames_Dataset(train_set, TrainDINOTransforms, True)
    val_dataset = Loaders.RARP_Video_Frames_Dataset(val_set, valtransform, True)
    test_dataset = Loaders.RARP_Video_Frames_Dataset(test_set, testtransform, True)
    

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=WORKERS,
        persistent_workers=WORKERS>0
    )
    
    Model = M.RARP_NVB_DINO_MultiTask_A5_MAE(
        M.TypeLossFunction.BCEWithLogits,
        std=Std,
        mean=Mean,
        L1= 1.31E-04,
        L2= 0,
        lr= 1e-4,
        SoftAdptAlgo=0
    )
    
    print(f"Model Used: {type(Model).__name__}")
    LogFileName = f"{args.Log_Name}" 
    
    checkPtCallback = [
        callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
        callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ]
    
    match(args.Phase):
        case "train":
            trainer = L.Trainer(
                deterministic=True,
                accelerator="gpu",
                devices=[args.GPU],
                logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
                log_every_n_steps=5,  
                callbacks=checkPtCallback,
                max_epochs=MAX_EPOCHS
            )
            
            print("Train Phase")
            trainer.fit(Model, train_dataloaders=train_loader, val_dataloaders=val_loader)
            trainer.test(Model, dataloaders=test_loader, ckpt_path="best")
        case "eval_all":
            print("Evaluation Phase")
            rows = []
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in pathCkptFile.glob("*.ckpt"):
                print(ckpFile.name)
                Model = M.RARP_NVB_DINO_MultiTask_A5_MAE.load_from_checkpoint(ckpFile)
                
                temp = Calc_Eval_table(
                        Model, 
                        test_loader, 
                        True, 
                        ckpFile.name
                    )
                rows += temp
            
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])        
            df.style.highlight_max(color="red", axis=0)
            print(df)