import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from transformers import AutoTokenizer, AutoModel
import van
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger
import Models as M
from pathlib import Path
import pandas as pd
import Loaders
import numpy as np
import torchmetrics
import defs
import argparse
from tqdm import tqdm

from softadapt import LossWeightedSoftAdapt, NormalizedSoftAdapt


# Hugging Face id used for both the tokenizer and the (frozen) text encoder.
LLM = "emilyalsentzer/Bio_ClinicalBERT"

# Single prompt template, filled per image with ``str.format(**row)`` from the
# clinical-metadata row whose ``raw_name`` matches the image file name.
# NOTE(review): ``{NVB}`` looks like the same quantity the classifier predicts;
# confirm it is intended to appear in the text that the image is aligned with.
PROMPT = [
    "Post-prostatectomy robotic laparoscopic view of the pelvic surgical bed "
    "in a {age}-year-old patient (BMI {BMI}, PSA {PSA} ng/mL) "
    "with clinical stage {cT}, pathologic stage {pT} and Gleason score {GS}; "
    "prostate size {prostate_size} mm; "
    "operating time was {surgery_time} min (console time {console_time} min); "
    "blood loss {blood_loss} mL. "
    "Neurovascular bundle preserved: {NVB}"
]


def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices) and NumPy, then Lightning, and pin cuDNN.

    ``seed_everything`` is called with ``workers=True`` so that DataLoader
    workers are seeded as well. Nothing is returned.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)
    seed_everything(seed, workers=True)
    # Prefer reproducible cuDNN kernels over the fastest ones.
    torch.backends.cudnn.deterministic = True

class RARP_DatasetFolder_LLM(torchvision.datasets.DatasetFolder):
    """``DatasetFolder`` that pairs every image with a tokenized clinical prompt.

    The file name of each sample (``Path(path).name``) is matched against the
    ``raw_name`` column of ``Extra_Data``. That row, with NaNs replaced by
    ``''``, fills the placeholders of ``PROMPT[0]``, which is then tokenized to
    exactly 128 tokens (padded / truncated).

    Items are returned as ``((image, text_inputs), target)``, where
    ``text_inputs`` holds the tokenizer outputs with the leading size-1 batch
    dimension squeezed away.
    """

    def __init__(
        self,
        root: str,
        loader,
        Extra_Data: pd.DataFrame,
        tokenizer=None,
        extensions=None,
        transform=None,
        target_transform=None,
        is_valid_file=None,
    ) -> None:
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)
        self.Extra_Data = Extra_Data
        self.tokenizer = tokenizer
        # Index the metadata once by file name instead of scanning the whole
        # DataFrame on every __getitem__. setdefault keeps the FIRST row of a
        # duplicated raw_name, matching the previous ``...to_dict("records")[0]``.
        # NOTE: built at construction; reassigning self.Extra_Data later does
        # not refresh this index.
        self._extra_by_name = {}
        for record in Extra_Data.fillna('').to_dict("records"):
            self._extra_by_name.setdefault(record["raw_name"], record)

    def __getitem__(self, index: int):
        """Return ``((image, text_inputs), target)`` for sample ``index``.

        Raises:
            ValueError: if no tokenizer was given.
            KeyError: if ``Extra_Data`` has no row for this sample's file name.
        """
        path, target = self.samples[index]

        if self.tokenizer is None:
            raise ValueError("RARP_DatasetFolder_LLM requires a tokenizer to build the text inputs")

        name = Path(path).name
        try:
            extra_data = self._extra_by_name[name]
        except KeyError:
            # A descriptive KeyError instead of a bare IndexError: an IndexError
            # raised from __getitem__ would silently end plain
            # ``for item in dataset`` iteration (Python's sequence protocol).
            raise KeyError(f"no row with raw_name == {name!r} in Extra_Data (sample {path})") from None

        prompt_text = PROMPT[0].format(**extra_data)

        text_data = self.tokenizer(prompt_text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

        # return_tensors="pt" adds a batch dim of 1; drop it so the DataLoader
        # can stack samples itself.
        for k in text_data:
            text_data[k] = text_data[k].squeeze(0)

        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (sample, text_data), target

class RARP_CLIP_loss(nn.Module):
    """Symmetric contrastive (CLIP-style) loss over a pair of logit matrices.

    Row ``i`` of the image logits is expected to match column/row ``i`` of the
    text logits, so the target for row ``i`` is class ``i``. The result is the
    mean of the image->text and text->image cross-entropies.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lossFN = torch.nn.CrossEntropyLoss()

    def forward(self, clip_logits: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Compute the loss.

        Args:
            clip_logits: ``(image_logits, text_logits)`` as returned by
                ``RARP_CLIP.forward``; ``text_logits`` is the transpose of
                ``image_logits`` there.

        Returns:
            Scalar loss tensor.
        """
        image_logits, text_logits = clip_logits
        # Matching pairs sit on the diagonal: target of row i is index i.
        label = torch.arange(image_logits.size(0), device=image_logits.device)

        loss = (self.lossFN(image_logits, label) + self.lossFN(text_logits, label)) / 2

        return loss

class RARP_CLIP(nn.Module):
    """Project image and text embeddings into a shared space and return scaled cosine logits.

    Args:
        text_output_feat_dim: width of the incoming text embedding.
        img_output_feat_dim: width of the incoming image embedding.
        embed_dim: currently unused; kept for signature compatibility.
        latent_space_dim: width of the shared latent space produced by both
            projection MLPs.
        max_logit_scale: upper bound on ``exp(logit_scale)``. CLIP clips the
            learned temperature at 100 because an unbounded scale destabilises
            training. Pass ``None`` to disable the bound.
    """

    def __init__(
        self,
        text_output_feat_dim: int,
        img_output_feat_dim: int,
        embed_dim: int,
        latent_space_dim: int,
        max_logit_scale: float | None = 100.0,
    ):
        super().__init__()

        self.image_latent_space = M.RARP_NVB_MLP(img_output_feat_dim, latent_space_dim, n_layers=2)
        self.text_latent_space = M.RARP_NVB_MLP(text_output_feat_dim, latent_space_dim, n_layers=2)

        # Learnable log-scale, initialised so that exp(logit_scale) == 1 / 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))
        self.max_logit_scale = max_logit_scale

    def forward(self, image_emb, text_emb):
        """Return ``(logits_img, logits_text)``.

        ``logits_img[i, j]`` is the scaled cosine similarity between image ``i``
        and text ``j``; ``logits_text`` is its transpose.
        """
        x_img = F.normalize(self.image_latent_space(image_emb), dim=-1)
        x_text = F.normalize(self.text_latent_space(text_emb), dim=-1)

        scale = self.logit_scale.exp()
        if self.max_logit_scale is not None:
            # Clamping stops the gradient once the bound is hit, so the scale cannot run away.
            scale = scale.clamp(max=self.max_logit_scale)

        logits_img = scale * x_img @ x_text.t()
        logits_text = logits_img.t()

        return logits_img, logits_text
        
class RARP_VAN_BERT(L.LightningModule):
    """NVB classifier on VAN-B2 image features with an auxiliary CLIP alignment loss.

    Images go through ``van_encoder``. Its features feed both the binary
    classification head and the image projection of ``RARP_CLIP``. Prompts go
    through a frozen BERT, and its [CLS] embedding feeds the text projection.
    The classification loss (BCE, plus an optional L1 penalty on the head
    during training) and the CLIP loss are combined with weights that SoftAdapt
    re-estimates every two epochs.
    """

    def __init__(
        self,
        bert_model_name: str = LLM,
        van_model: str = "",
        lr: float = 1e-4,
        latent_space_dim: int = 512,
        hiden_dim: int = 256,
        clasiffier_layers=None,
        softAdptAlgo: int = 0,
        softAdptBeta: float = 0.1,
        lambda_L1: float | None = 1.31E-04,
    ):
        """
        Args:
            bert_model_name: Hugging Face id/path of the text encoder. It is not
                saved in the hparams, so checkpoints reload with the default.
            van_model: path to a ``state_dict`` for ``van.van_b2``. An empty
                string skips loading, which is the case in
                ``load_from_checkpoint``: the argument is not saved in the
                hparams, and the checkpoint supplies the weights anyway.
            lr: AdamW learning rate.
            latent_space_dim: width of the shared CLIP latent space.
            hiden_dim: passed to ``RARP_CLIP`` as ``embed_dim``.
            clasiffier_layers: layer spec for the classification head;
                ``None`` means ``[]``.
            softAdptAlgo: ``1`` selects NormalizedSoftAdapt; any other value
                selects LossWeightedSoftAdapt.
            softAdptBeta: SoftAdapt beta.
            lambda_L1: L1 weight on the head parameters during training;
                ``None`` disables the penalty.
        """
        super().__init__()

        # Normalise before save_hyperparameters (it reads the __init__ locals),
        # so the stored hparam is still [] without a shared mutable default.
        if clasiffier_layers is None:
            clasiffier_layers = []

        self.save_hyperparameters(ignore=["bert_model_name", "van_model"])
        self.loss_weights = [1, 1]

        self.softAdapt = NormalizedSoftAdapt(softAdptBeta) if softAdptAlgo == 1 else LossWeightedSoftAdapt(softAdptBeta)
        # Per-step training loss values accumulated between SoftAdapt updates.
        self.loss_history = {
            0: [],  # 'loss_Binary' (BCE + L1)
            1: [],  # 'loss_CLIP'
        }

        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')

        self.bert_llm = AutoModel.from_pretrained(bert_model_name)
        # Width of the [CLS] embedding (768 for Bio_ClinicalBERT, the value previously hard-coded).
        self.text_emb = self.bert_llm.config.hidden_size

        # The text encoder is frozen; only the projection heads learn.
        # NOTE(review): Lightning's train() also puts bert_llm in train mode, so
        # its dropout stays active on frozen features. Confirm this is intended.
        for parms in self.bert_llm.parameters():
            parms.requires_grad = False

        self.van_encoder = van.van_b2(pretrained=False, num_classes=0)
        if van_model:
            self.van_encoder.load_state_dict(torch.load(van_model, map_location="cpu"))
        self.image_emb = 512

        self.clip = RARP_CLIP(self.text_emb, self.image_emb, hiden_dim, latent_space_dim)

        self.clasiffier = M.RARP_NVB_Classification_Head(self.image_emb, 1, clasiffier_layers, torch.nn.SiLU(True))

        self.lossFN_clasiffier = torch.nn.BCEWithLogitsLoss()
        self.lossFN_CLIP = RARP_CLIP_loss()

    def forward(self, data):
        """Run both branches.

        Args:
            data: ``(img_data, text_data)``. ``text_data`` is a mapping with
                ``"input_ids"`` and ``"attention_mask"`` tensors.

        Returns:
            ``(pred, (logits_img, logits_text))``. ``pred`` is the raw output
            of the classification head (built with 1 output); the pair holds
            the CLIP logits.
        """
        img_data, text_data = data
        img_data = img_data.float()

        ids_tokens = text_data["input_ids"]
        att_mask = text_data["attention_mask"]

        img_features = self.van_encoder(img_data)

        x_text = self.bert_llm(
            input_ids=ids_tokens,
            attention_mask=att_mask
        )
        x_text = x_text.last_hidden_state[:, 0]  # [CLS] token

        logits_img, logits_text = self.clip(img_features, x_text)

        pred = self.clasiffier(img_features)

        return pred, (logits_img, logits_text)

    def _calc_L1(self, params):
        """Return ``lambda_L1 * sum(|p|)`` over ``params``."""
        l1 = 0
        for p in params:
            l1 += torch.sum(torch.abs(p))
        return self.hparams.lambda_L1 * l1

    def _calc_weights(self, log_weights: bool = True):
        """Recompute ``loss_weights`` with SoftAdapt from the collected history, then reset it.

        When a history has an even number of points, the newest point is
        dropped so that an odd count is passed (presumably a requirement of
        SoftAdapt's finite-difference slope estimate; confirm).
        """
        self.loss_weights = self.softAdapt.get_component_weights(
            torch.tensor(self.loss_history[0][:-1] if len(self.loss_history[0]) % 2 == 0 else self.loss_history[0]),
            torch.tensor(self.loss_history[1][:-1] if len(self.loss_history[1]) % 2 == 0 else self.loss_history[1]),
            verbose=False
        )

        if log_weights:
            self.log("W_loss_img", self.loss_weights[0], on_epoch=True, on_step=False)
            self.log("W_loss_CLIP", self.loss_weights[1], on_epoch=True, on_step=False)

        self.loss_history = {
            0: [],
            1: [],
        }

    def _shared_step(self, batch, val_step: bool = False):
        """Forward a batch and build the weighted total loss.

        Training steps (``val_step=False``) add the L1 penalty to the
        classification loss and record both components for SoftAdapt.

        Returns:
            ``(loss, label, predicted_probabilities, [loss_cls, loss_clip])``.
        """
        data, label = batch
        label = label.float()

        prediction, clip_logits = self(data)

        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)

        loss_list = [
            self.lossFN_clasiffier(prediction, label),
            self.lossFN_CLIP(clip_logits)
        ]

        if not val_step:
            if self.hparams.lambda_L1 is not None:
                # Out-of-place add: do not mutate the BCE output tensor in place.
                loss_list[0] = loss_list[0] + self._calc_L1(self.clasiffier.parameters())

            for i, l in enumerate(loss_list):
                self.loss_history[i].append(l.item())

        # NOTE(review): validation uses the current SoftAdapt weights as well,
        # so val_loss (monitored by EarlyStopping) is not on a fixed scale
        # across epochs.
        loss = sum(w * l for w, l in zip(self.loss_weights, loss_list))

        return loss, label, predicted_labels, loss_list

    def on_train_epoch_start(self):
        # Refresh the loss weights every second epoch (skipping epoch 0, which has no history).
        if self.current_epoch % 2 == 0 and self.current_epoch != 0:
            self._calc_weights()

    def training_step(self, batch, batch_idx):
        loss, true_labels, pred_labels, losses = self._shared_step(batch, False)

        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(pred_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)

        self.log("train_loss_img", losses[0], on_epoch=True, on_step=False)
        self.log("train_loss_clip", losses[1], on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, pred_labels, losses = self._shared_step(batch, True)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(pred_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

        self.log("val_loss_img", losses[0], on_epoch=True, on_step=False)
        self.log("val_loss_clip", losses[1], on_epoch=True, on_step=False)

    def test_step(self, batch, batch_idx):
        _, true_labels, predicted_labels, _ = self._shared_step(batch, True)

        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        # All parameters are registered (frozen BERT params never receive a grad),
        # which keeps the optimizer state layout compatible with existing checkpoints.
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

        return [optimizer]
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("-lv", "--Log_version", type=int)
    
    args = parser.parse_args()

    Dataset = Loaders.RARP_DatasetCreator(
        "./DataSet_Ando_All_no20Crop",
        FoldSeed=505,
        createFile=True,
        SavePath="./DataSet_AndoAll20_crop",
        Fold=5,
        removeBlackBar=False,
    )

    Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])

    Dataset.CreateFolds()
        
    setup_seed(2023)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    DumpCSV = pd.read_csv(Dataset.CVS_File)
    Extradata = pd.read_csv(Path("./DataSet_Ando_All_no20Crop/data_source_prompt.csv"))

    DumpCSV["raw_name"] = "Img0-" + DumpCSV["id"].astype(str) + ".npy"
    DumpCSV = DumpCSV.drop(columns=["id", "path", "mean_1", "mean_2", "mean_3", "std_1", "std_2", "std_3"])

    NewData = pd.merge(Extradata, DumpCSV, on="name")
        
    Fold = args.Fold
    InitResize=(512, 512)

    batchSize = 16
    numWorkers = 5
    MaxEpochs = 100
    LogFileName = "log_X22"

    rootFile = Dataset.CVS_File.parent.parent/f"fold_{Fold}"
    checkPtCallback = [
        callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'), 
        callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ]
    
    traintransform = torch.nn.Sequential(
        transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC), 
        transforms.RandomRotation(
            degrees=(-15, 15), 
            fill=5
        ),
        transforms.RandomResizedCrop(
                224, 
                scale=(0.4, 1), 
                antialias=True,
                interpolation=transforms.InterpolationMode.BICUBIC
            ),
        transforms.RandomApply([
            transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
        ], 0.3),        
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)

    valtransform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)
    
    testtransform =  torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)
    
    bert_tokenizer = AutoTokenizer.from_pretrained(LLM)

    trainDataset = RARP_DatasetFolder_LLM(
        str (rootFile/"train"),
        loader = defs.load_file_tensor,
        Extra_Data = NewData,
        tokenizer = bert_tokenizer,
        extensions = "npy",
        transform = traintransform
    )

    valDataset = RARP_DatasetFolder_LLM(
        str (rootFile/"val"),
        loader = defs.load_file_tensor,
        Extra_Data = NewData,
        tokenizer = bert_tokenizer,
        extensions = "npy",
        transform = valtransform
    )
    
    testDataset = RARP_DatasetFolder_LLM(
        str (rootFile/"test"),
        loader=defs.load_file_tensor,
        Extra_Data = NewData,
        tokenizer = bert_tokenizer,
        extensions="npy",
        transform=testtransform
    )
        
    Train_DataLoader = DataLoader(
        trainDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=True, 
        drop_last=True,
        pin_memory=True,
        persistent_workers=numWorkers>0
    )
    Val_DataLoader = DataLoader(
        valDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=False, 
        pin_memory=True,
        persistent_workers=numWorkers>0
    )
    Test_DataLoader = DataLoader(
        testDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=False, 
        pin_memory=True,
        persistent_workers=numWorkers>0
    )

    match(args.Phase):
        case "train":
            Model = RARP_VAN_BERT(LLM, "van_b2_teacher_90_D2.pth", clasiffier_layers=[])
            
            trainer = L.Trainer(
                deterministic=True,
                accelerator='gpu', 
                devices=1, 
                logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
                log_every_n_steps=5,   
                callbacks=checkPtCallback,
                max_epochs=MaxEpochs,
            )
            print("Train Phase")
            
            trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
            trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")
        case "eval_all":
            print("Evaluation Phase")
            rows = []
            pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
            for ckpFile in pathCkptFile.glob("*.ckpt"):
                print(ckpFile.name)
                Model = RARP_VAN_BERT.load_from_checkpoint(ckpFile)
                Model.to(device)
                Model.eval()
                
                Predictions = []
                Labels = []
                
                with torch.no_grad():
                    for data, label in tqdm(iter(Test_DataLoader)):
                        data = data.float().to(device)
                        label = label.to(device)
                        
                        pred, _= Model(data)
                        pred = pred.flatten()
                        
                        Predictions.append(torch.sigmoid(pred))
                        Labels.append(label)
                        
                Predictions = torch.cat(Predictions)
                Labels = torch.cat(Labels)
                
                print(Predictions, Labels)

                acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
                precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
                recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
                auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
                f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
                specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
                    
                table = [
                    ["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
                ]
                
                for i in range(2):
                    aucCurve = torchmetrics.ROC("binary").to(device)
                    fpr, tpr, thhols = aucCurve(Predictions, Labels)
                    index = torch.argmax(tpr - fpr)
                    th2 = (recall + specificty - 1).item()
                    th2 = 0.5 if th2 <= 0 else th2
                    th1 = thhols[index].item() if i == 0 else th2
                    accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
                    precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
                    recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
                    specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
                    f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
                    #cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
                    #cm2.update(Predictions, Labels)
                    #_, ax = cm2.plot()
                    #ax.set_title(f"NVB Classifier (th={th1:.4f})")
                    table.append([
                        f"{th1:.4f}", 
                        f"{accY.item():.4f}", 
                        f"{precisionY.item():.4f}",
                        f"{recallY.item():.4f}", 
                        f"{f1ScoreY.item():.4f}", 
                        f"{auc.item():.4f}", 
                        f"{specifictyY.item():.4f}", 
                        ckpFile.name
                    ])
                
                rows += table
                
            df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])        
            df.style.highlight_max(color="red", axis=0)
            print(df)
    
    