import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import van
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger
import Loaders
import numpy as np
import torchmetrics
import defs
import argparse
from pathlib import Path
from tqdm import tqdm
import pandas as pd

def setup_seed(seed):
    """Seed every random number generator used by training and force deterministic cuDNN.

    Seeds torch (CPU and all CUDA devices) and NumPy, then calls Lightning's
    ``seed_everything`` with ``workers=True`` so dataloader workers are seeded too.
    Finally sets ``torch.backends.cudnn.deterministic``.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
    
class RARP_NVB_Classification_Head(torch.nn.Module):
    def __init__(self, in_features:int, out_features:int, layer:list=[], activation_fn:torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        self.activation = activation_fn
        
        if len (layer) == 0:        
            self.head = torch.nn.Linear(in_features, out_features)
        else:
            temp_head = []
            next_input = in_features
            for num in layer:
                temp_head.append(torch.nn.Linear(next_input, num))
                temp_head.append(self.activation)
                temp_head.append(torch.nn.Dropout(0.4))
                next_input = num
                
            temp_head[-1] = torch.nn.Dropout(0.2)
            temp_head.append(torch.nn.Linear(next_input, out_features))
            
            self.head = torch.nn.Sequential(*temp_head)
            del temp_head
    
    def forward(self, x):
        return self.head(x)
    
class RARP_VAN(L.LightningModule):
    """Binary image classifier: a VAN-b2 encoder followed by an MLP head.

    The encoder output (``self.image_emb`` = 512 features) feeds a
    ``RARP_NVB_Classification_Head`` that emits one logit. Training uses
    ``BCEWithLogitsLoss``. An optional L1 penalty on the head parameters is added
    to the training loss only, never to the validation or test loss.

    Args:
        van_model: Path to a VAN-b2 ``state_dict`` used to initialise the encoder.
            Empty string (or ``None``) means ImageNet-pretrained weights from ``van``.
        lr: Adam learning rate.
        clasiffier_layers: Hidden widths of the classification head. ``None`` or
            empty gives a single linear layer.
        lambda_L1: Weight of the L1 penalty on the head parameters. ``None``
            disables the penalty.
    """

    def __init__(
        self,
        van_model: str = "",
        lr: float = 1e-4,
        clasiffier_layers=None,
        lambda_L1: float | None = 1.31E-04
    ):
        # Normalise before save_hyperparameters so the stored hparam is a fresh list
        # (the old `[]` default was a single list shared between instances).
        clasiffier_layers = [] if clasiffier_layers is None else clasiffier_layers
        super().__init__()

        # van_model is a weights file path, not a hyper-parameter, so it is not stored.
        self.save_hyperparameters(ignore=["van_model"])

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

        if not van_model:
            self.van_encoder = van.van_b2(pretrained=True, num_classes=0)
            print("pre-train ImageNet")
        else:
            self.van_encoder = van.van_b2(pretrained=False, num_classes=0)
            # map_location="cpu" lets GPU-saved weights load on any host;
            # load_state_dict then copies them onto the encoder's own device.
            self.van_encoder.load_state_dict(torch.load(van_model, map_location="cpu"))

        self.image_emb = 512
        self.clasiffier = RARP_NVB_Classification_Head(self.image_emb, 1, clasiffier_layers, torch.nn.SiLU(True))

        self.lossFN_clasiffier = torch.nn.BCEWithLogitsLoss()

    def _calc_L1(self, params):
        """Return ``lambda_L1 * sum(|p|)`` summed over every tensor in ``params``."""
        return self.hparams.lambda_L1 * sum(torch.sum(torch.abs(p)) for p in params)

    def forward(self, data):
        """Cast ``data`` to float32 and return the head's raw logits."""
        data = data.float()

        img_features = self.van_encoder(data)
        return self.clasiffier(img_features)

    def _shared_step(self, batch, val_step: bool = False):
        """Run the model on one ``(data, label)`` batch.

        Returns:
            ``(loss, labels, probabilities)``. ``labels`` are cast to float and
            ``probabilities`` are the sigmoid of the flattened logits. The L1 penalty
            is added only when ``val_step`` is False and ``lambda_L1`` is not None.
        """
        data, label = batch
        label = label.float()

        logits = self(data).flatten()
        probabilities = torch.sigmoid(logits)

        # The loss takes raw logits; BCEWithLogitsLoss applies the sigmoid itself.
        loss = self.lossFN_clasiffier(logits, label)
        if not val_step and self.hparams.lambda_L1 is not None:
            loss = loss + self._calc_L1(self.clasiffier.parameters())

        return loss, label, probabilities

    def training_step(self, batch, batch_idx):
        """Compute the (L1-regularised) training loss and log loss/accuracy."""
        loss, true_labels, pred_labels = self._shared_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(pred_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the unregularised validation loss and accuracy per epoch."""
        loss, true_labels, pred_labels = self._shared_step(batch, True)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(pred_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Accumulate and log test accuracy and F1 (threshold 0.5)."""
        _, true_labels, predicted_labels = self._shared_step(batch, True)

        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Adam over all parameters, so the encoder is fine-tuned along with the head."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

        return [optimizer]
    
class RARP_ConvNext(RARP_VAN):
    """``RARP_VAN`` variant whose encoder is torchvision's ConvNeXt-Small.

    The last layer of ConvNeXt's ``classifier`` is replaced by ``Identity``, so the
    encoder emits 768 features. The classification head is rebuilt for that width.
    The steps, metrics, loss and optimizer are inherited unchanged.

    Args:
        model: Path to a ConvNeXt-Small ``state_dict`` that matches the architecture
            with the last classifier layer replaced by ``Identity``. Empty string (or
            ``None``) means torchvision's default pretrained weights.
        lr, clasiffier_layers, lambda_L1: Same meaning as in ``RARP_VAN``.
    """

    def __init__(self, model="", lr=0.0001, clasiffier_layers=None, lambda_L1=0.000131):
        # Replace the old shared `[]` default before the parent records hyper-parameters.
        clasiffier_layers = [] if clasiffier_layers is None else clasiffier_layers
        # NOTE(review): the parent also builds an ImageNet-pretrained VAN encoder and a
        # 512-d head, both discarded below; kept so the shared setup lives in one place.
        super().__init__("", lr, clasiffier_layers, lambda_L1)

        if not model:
            self.van_encoder = torchvision.models.convnext_small(weights=torchvision.models.ConvNeXt_Small_Weights.DEFAULT)
        else:
            self.van_encoder = torchvision.models.convnext_small()
        # Drop the final ImageNet projection so the encoder returns 768-d features.
        self.van_encoder.classifier[-1] = torch.nn.Identity()
        if model:
            # map_location="cpu" lets GPU-saved weights load on any host.
            self.van_encoder.load_state_dict(torch.load(model, map_location="cpu"))

        self.image_emb = 768
        self.clasiffier = RARP_NVB_Classification_Head(self.image_emb, 1, clasiffier_layers, torch.nn.SiLU(True))
    
if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("--Workers", type=int, default=5)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("-p", "--Pre_train", type=str, default="RARP")
    parser.add_argument("-w", "--Weigth", type=str, default="")
    parser.add_argument("-lv", "--Log_version", type=int)
    parser.add_argument("-e", "--Encoder", type=str, default="VAN")
    
    args = parser.parse_args()
    
    Dataset = Loaders.RARP_DatasetCreator(
        "./DataSet_Ando_All_no20Crop",
        FoldSeed=505,
        createFile=True,
        SavePath="./DataSet_AndoAll20_crop",
        Fold=5,
        removeBlackBar=False,
    )
    Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    Dataset.CreateFolds()
            
    setup_seed(2023)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
    Fold = args.Fold
    InitResize=(512, 512)

    batchSize = 8
    numWorkers = args.Workers
    MaxEpochs = 100
    LogFileName = args.Log_Name

    rootFile = Dataset.CVS_File.parent.parent/f"fold_{Fold}"
    checkPtCallback = [
        callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'), 
        callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ]
    
    traintransform = torch.nn.Sequential(
        transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC), 
        transforms.RandomRotation(
            degrees=(-15, 15), 
            fill=5
        ),
        transforms.RandomResizedCrop(
                224, 
                scale=(0.4, 1), 
                antialias=True,
                interpolation=transforms.InterpolationMode.BICUBIC
            ),
        transforms.RandomApply([
            transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
        ], 0.3),        
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)

    valtransform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)
    
    testtransform =  torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)
    
    trainDataset = torchvision.datasets.DatasetFolder(
        str (rootFile/"train"),
        loader = defs.load_file_tensor,
        extensions = "npy",
        transform = traintransform
    )

    valDataset = torchvision.datasets.DatasetFolder(
        str (rootFile/"val"),
        loader = defs.load_file_tensor,
        extensions = "npy",
        transform = valtransform
    )
    
    testDataset = torchvision.datasets.DatasetFolder(
        str (rootFile/"test"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=testtransform
    )
        
    Train_DataLoader = DataLoader(
        trainDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=True, 
        drop_last=True,
        pin_memory=True,
        persistent_workers=numWorkers>0
    )
    Val_DataLoader = DataLoader(
        valDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=False, 
        pin_memory=True,
        persistent_workers=numWorkers>0
    )
    Test_DataLoader = DataLoader(
        testDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=False, 
        pin_memory=True,
        persistent_workers=numWorkers>0
    )
    
    match(args.Phase):
        case "train":
            Model = None
            match (args.Encoder):
                case "VAN":
                    Model = RARP_VAN(args.Weigth if args.Pre_train == "RARP" else "", clasiffier_layers=[], lr=1e-4, lambda_L1=None)#lambda_L1=2.22e-6
                case "ConvNext":
                    Model = RARP_ConvNext(args.Weigth if args.Pre_train == "RARP" else "", clasiffier_layers=[], lr=1e-4, lambda_L1=None)
                        
            trainer = L.Trainer(
                deterministic=True,
                accelerator='gpu', 
                devices=1, 
                logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
                log_every_n_steps=5,   
                callbacks=checkPtCallback,
                max_epochs=MaxEpochs,
            )
            print("Train Phase")
            
            trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
            trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")
        case "eval_all":
            print("Evaluation Phase")
            rows = []
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in pathCkptFile.glob("*.ckpt"):
                print(ckpFile.name)
                Model = None
                match (args.Encoder):
                    case "VAN":
                        Model = RARP_VAN.load_from_checkpoint(ckpFile)
                    case "ConvNext":
                        Model = RARP_ConvNext.load_from_checkpoint(ckpFile)
                        
                Model.to(device)
                Model.eval()
                
                Predictions = []
                Labels = []
                
                with torch.no_grad():
                    for data, label in tqdm(iter(Test_DataLoader)):
                        data = data.float().to(device)
                        label = label.to(device)
                        
                        pred = Model(data).flatten()
                        
                        Predictions.append(torch.sigmoid(pred))
                        Labels.append(label)
                        
                Predictions = torch.cat(Predictions)
                Labels = torch.cat(Labels)
                
                print(Predictions, Labels)

                acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
                precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
                recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
                auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
                f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
                specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
                    
                table = [
                    ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
                ]
                
                for i in range(2):
                    aucCurve = torchmetrics.ROC("binary").to(device)
                    fpr, tpr, thhols = aucCurve(Predictions, Labels)
                    index = torch.argmax(tpr - fpr)
                    th2 = (recall + specificty - 1).item()
                    th2 = 0.5 if th2 <= 0 else th2
                    th1 = thhols[index].item() if i == 0 else th2
                    accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
                    precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
                    recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
                    specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
                    f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
                    #cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
                    #cm2.update(Predictions, Labels)
                    #_, ax = cm2.plot()
                    #ax.set_title(f"NVB Classifier (th={th1:.4f})")
                    table.append([
                        f"{th1:.4f}", 
                        f"{accY.item():.4f}", 
                        f"{precisionY.item():.4f}",
                        f"{recallY.item():.4f}", 
                        f"{f1ScoreY.item():.4f}", 
                        f"{auc.item():.4f}", 
                        f"{specifictyY.item():.4f}", 
                        ckpFile.name
                    ])
                
                rows += table
                
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])        
            df.style.highlight_max(color="red", axis=0)
            print(df)