import os
import torch
from torchvision.datasets.folder import make_dataset
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, Subset
from lightning.pytorch import LightningDataModule
import numpy as np
import cv2
from typing import Optional, Sequence, Tuple, Union, List, Any, Callable
from sklearn.model_selection import KFold, StratifiedKFold
import torchvision
import math
import defs
import csv
import pandas as pd
from enum import Enum
import shutil
from ultralytics import YOLO
import itertools

def _find_classes(dir):
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

def get_samples(root, extensions=(".mp4", ".avi")):
    _, class_to_idx = _find_classes(root)
    return make_dataset(root, class_to_idx, extensions=extensions)
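# Example: for root/NOT_NVB/a.avi and root/NVB/b.mp4 this returns
# [("root/NOT_NVB/a.avi", 0), ("root/NVB/b.mp4", 1)], since class folders are
# indexed in sorted order.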

class RARP_DatasetCreator():
    def _removeBlackBorder(self, image):
        image = np.array(image)
        
        copyImg = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV)
        h = copyImg[:,:,0]
        mask = np.ones(h.shape, dtype=np.uint8) * 255
        th = (25, 175)
        mask[(h > th[0]) & (h < th[1])] = 0
        copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)
        resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)
            
        image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        crop = image[y : y + h, x : x + w]
        return crop
    
    def _crop(self, image: np.ndarray, x: tuple, y: tuple):
        return image[y[0]:y[1], x[0]:x[1], :]
    
    def _sliceImage(self, img: np.ndarray, size: float):
        h, w, _ = img.shape
        xw = math.trunc(w*size)
        xh = math.trunc(h*size)
        ul = self._crop(img, (0, xw), (0, xh))
        ur = self._crop(img, (w-xw, w), (0, xh))
        dl = self._crop(img, (0, xw), (h-xh, h))
        dr = self._crop(img, (w-xw, w), (h-xh, h))
        
        return [ul, ur, dl, dr]
    
    def _split(self, a, n):
        k, m = divmod(len(a), n)
        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
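    # e.g. list(self._split(list(range(10)), 3)) -> [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]:
    # n near-equal chunks whose lengths differ by at most one.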
    
    def ROI_Extract_YOLO(self, YoloModel:YOLO, image, threshold=0.75):
        listROIs = []
        res = YoloModel(image, stream=True)
        for r in res:
            pred = r.boxes.conf.cpu().numpy()
            for i, box in enumerate(r.boxes.xyxy.cpu().numpy()):
                if pred[i] >= threshold:
                    box = box.astype(int)
                    x, y = (box[0], box[1])
                    xw, yh = (box[2], box[3])
                    listROIs.append(image[y:yh, x:xw])
                        
        return listROIs
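    # Example (a sketch; "roi_weights.pt" is a hypothetical weights file):
    #     model = YOLO("roi_weights.pt")
    #     rois = creator.ROI_Extract_YOLO(model, cv2.imread("frame.tiff"), threshold=0.8)
    # Each entry of rois is the input image cropped to one confident detection.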
    
    def __init__(self,
                 RootPath = "", 
                 extension="tiff", 
                 FoldSeed=505, 
                 createFile=False, 
                 SavePath="", 
                 Fold:int=None, 
                 preResize=None, 
                 RGBGama=False, 
                 removeBlackBar=True,
                 SegImage:float=None,
                 Num_Img_Slices:int = 4,
                 SegmentClass:int = None,
                 colorSpace:int=None,
                 ROI_Yolo:YOLO=None,
                 thresholdYolo_Accuracy:float=0.75) -> None:
        
        root = Path(RootPath)
        lMean = []
        lStd = []
        NO_NVB = []
        NVB = []
                        
        if createFile:
            if len(SavePath) == 0:
                raise ValueError("If createFile is True, SavePath must have a value")
        
        sPath = Path(SavePath + f"_seed_{FoldSeed}")
        dumpImgs = sPath/"dump"
        dumpImgs.mkdir(parents=True, exist_ok=True)
        self.CVS_File = dumpImgs/"dataset.csv"
        if not self.CVS_File.exists():
            with open(self.CVS_File, "x", newline='') as csvfile: 
                writerOBJ = csv.writer(csvfile)
                writerOBJ.writerow(["id", "label", "class", "path", "name", "mean_1", "mean_2", "mean_3", "std_1", "std_2", "std_3"])
                id = 0
                for i, file in enumerate(root.glob(f"**/*.{extension}")):
                    tempImg = cv2.imread(str(file), cv2.IMREAD_COLOR)
                    if removeBlackBar and ROI_Yolo is None:
                        tempImg = self._removeBlackBorder(tempImg)
                    if RGBGama:
                        tempImg = cv2.cvtColor(tempImg, cv2.COLOR_BGR2RGB)
                    if colorSpace is not None:
                        tempImg = cv2.cvtColor(tempImg, colorSpace)
                    if preResize is not None:
                        tempImg = cv2.resize(tempImg, preResize, interpolation=cv2.INTER_AREA)  # resize to the target resolution
                    
                    tempImgList = [tempImg]
                    
                    #ROI w/ YOLOv8
                    if ROI_Yolo is not None:
                        tempImgList = self.ROI_Extract_YOLO(ROI_Yolo, tempImg, thresholdYolo_Accuracy)
                    
                    for k, tempImg in enumerate(tempImgList):
                        imgPath = dumpImgs/f"Img{k}-{i}.npy"
                        label = 0 if file.parent.name == "NOT_NVB" else 1
                        chMean = np.mean(tempImg, axis=tuple(range(tempImg.ndim-1)))
                        chStd = np.std(tempImg, axis=tuple(range(tempImg.ndim-1)))
                        lMean.append(chMean)
                        lStd.append(chStd)
                        lineaCSV = [id, label, file.parent.name, imgPath.absolute(), file.name]
                        lineaCSV += chMean.tolist() + chStd.tolist()
                        writerOBJ.writerow(lineaCSV)
                        np.save(imgPath, tempImg)
                        if label == 0:
                            NO_NVB.append(id)
                        else:
                            NVB.append(id)
                        
                        id += 1
                        
                        if SegImage is not None:
                            if (SegmentClass is None) or (label == SegmentClass):
                                for j, newImg in enumerate(self._sliceImage(tempImg, SegImage)):
                                    if j == Num_Img_Slices:
                                        break
                                    imgPath = dumpImgs/f"Img{k}-{i}-{j}.npy"
                                    np.save(imgPath, newImg)
                                    lineaCSV = [id, 0 if file.parent.name == "NOT_NVB" else 1, file.parent.name, (dumpImgs/f"Img{k}-{i}-{j}.npy").absolute(), file.name] 
                                    lineaCSV += np.mean(newImg, axis=tuple(range(newImg.ndim-1))).tolist()
                                    lineaCSV += np.std(newImg, axis=tuple(range(newImg.ndim-1))).tolist()
                                    writerOBJ.writerow(lineaCSV)
                                    if lineaCSV[1] == 0:
                                        NO_NVB.append(id)
                                    else:
                                        NVB.append(id)
                                    
                                    id += 1
                    
                
            temp = np.asarray(lMean)
            self.mean = list(np.mean(temp, axis=tuple(range(temp.ndim-1))))
            temp = np.asarray(lStd)
            self.std = list(np.mean(temp, axis=tuple(range(temp.ndim-1))))
        else:
            data = pd.read_csv(self.CVS_File)
            self.mean = data[["mean_1", "mean_2", "mean_3"]].mean().to_list()
            self.std = data[["std_1", "std_2", "std_3"]].mean().to_list()
            
            NO_NVB = data.loc[data["label"] == 0]["id"].to_list()
            NVB = data.loc[data["label"] == 1]["id"].to_list()
            
        if Fold is not None:
            self.Splits = []      
            splitsToSave = []      
            tempFoldsOrder = list(range(Fold))
            for _ in range(Fold):
                self.Splits.append(tempFoldsOrder)
                tempFoldsOrder = tempFoldsOrder[1:] + tempFoldsOrder[:1]
                
            if FoldSeed is not None:
                np.random.seed(FoldSeed)
                np.random.shuffle(NO_NVB)
                np.random.shuffle(NVB)

            NO_NVB_Folds = list(self._split(NO_NVB, Fold))
            NVB_Folds = list(self._split(NVB, Fold))
            
            # Train/val/test take 60%/20%/20% of the folds; the truncated
            # sizes only cover every fold when Fold is a multiple of 5.
            setst = [
                math.trunc(0.60 * Fold), math.trunc(0.20 * Fold), math.trunc(0.20 * Fold)
            ]
            
            for data in self.Splits:
                ultimo = 0
                for s in setst:
                    tempArr = []
                    for fold in data[ultimo: ultimo + s]:
                        tempArr += NO_NVB_Folds[fold] + NVB_Folds[fold]
                    splitsToSave.append(tempArr)
                    ultimo += s
            self.Folds_File = dumpImgs/"Folds.npy"
            np.save(self.Folds_File, np.asarray(splitsToSave, dtype=object))
            
    def CreateClases(self):
        if not (self.CVS_File.parent/"FOLDS_CREATED").exists():
            root = self.CVS_File.parent.parent
            database = pd.read_csv(self.CVS_File, usecols=["id", "label", "class", "path"])
            
            for _, row in database.iterrows():
                pathOriginal = Path(row["path"])
                pathNuevo = root/"dataset"/("NO_NVB" if row["label"] == 0 else "NVB")/pathOriginal.name
                pathNuevo.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(pathOriginal, pathNuevo)
                
            with open(self.CVS_File.parent/"FOLDS_CREATED", "x") as file:
                file.close()
                
            for f in self.CVS_File.parent.glob("*.npy"):
                f.unlink()
                
    def ExtractROI_YOLO(self, YOLOModel:YOLO, thresholdYolo:float = 0.70):
        if (self.CVS_File.parent/"FOLDS_CREATED").exists() and not (self.CVS_File.parent/"YOLO_ROI_CREATED").exists():
            root = self.CVS_File.parent.parent
            for img in root.glob("**/*.npy"):
                if img.name == "Folds.npy":
                    continue
                tempImgList = self.ROI_Extract_YOLO(YOLOModel, np.load(img), thresholdYolo)
                for k, tempImg in enumerate(tempImgList):
                    if k == 0:
                        np.save(img, tempImg)
                    #else:
                    #    newPath = img.parent/(img.name.replace(".npy", "") +f"-{k}.npy")
                    #    np.save(newPath, tempImg) 
                        
            with open(self.CVS_File.parent/"YOLO_ROI_CREATED", "x") as file:
                file.close()  
                        
    def CreateFolds(self):
        if not (self.CVS_File.parent/"FOLDS_CREATED").exists():
            root = self.CVS_File.parent.parent
            database = pd.read_csv(self.CVS_File, usecols=["id", "label", "class", "path"])
            arrFolds = np.load(self.Folds_File, allow_pickle=True)

            for foldNum, splits in enumerate(np.array_split(arrFolds, len(arrFolds)//3)):
                foldPath = root/f"fold_{foldNum}"
                for datasetType, subSet in enumerate(splits):
                    SubfoldPath = foldPath/f"{RARP_DataSetType(datasetType).name}"
                    for _, row in database.loc[database["id"].isin(subSet)].iterrows():
                        PathOri = Path(row["path"])
                        PathImg = SubfoldPath/("NO_NVB" if row["label"] == 0 else "NVB")/PathOri.name
                        PathImg.parent.mkdir(parents=True, exist_ok=True)
                        shutil.copy(PathOri, PathImg)
                        
            with open(self.CVS_File.parent/"FOLDS_CREATED", "x") as file:
                file.close()
                
            for f in self.CVS_File.parent.glob("*.npy"):
                f.unlink()
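
# Example end-to-end workflow (a minimal sketch; the paths are hypothetical):
#     creator = RARP_DatasetCreator(RootPath="data/raw", extension="tiff",
#                                   createFile=True, SavePath="data/rarp",
#                                   Fold=5, preResize=(512, 512))
#     creator.CreateFolds()  # materialises fold_0 .. fold_4 on disk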
                                
class RARP_DatasetFolder_ExtraData(torchvision.datasets.DatasetFolder):
    def __init__(self, 
                 root: str, 
                 loader: Callable[[str], Any], 
                 Extra_Data: pd.DataFrame,
                 Extra_Data_leg:int = 4,
                 extensions: Tuple[str, ...] | None = None, 
                 transform: Callable[..., Any] | None = None, 
                 target_transform: Callable[..., Any] | None = None, 
                 is_valid_file: Callable[[str], bool] | None = None
                ) -> None:
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)
        self.Extra_Data = Extra_Data
        self.Extra_Data_leg = Extra_Data_leg
        
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]
        
        name = Path(path).name
        Extra_data = self.Extra_Data[self.Extra_Data["raw_name"] == name].values.flatten().tolist()[:self.Extra_Data_leg]
        Extra_data = torch.tensor(Extra_data)
        
        # load the raw sample from disk
        sample = self.loader(path) 
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (sample, Extra_data), target
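
# Example (a sketch; "extra_features.csv" is hypothetical, but the DataFrame
# must provide a "raw_name" column matching the original file names):
#     extra = pd.read_csv("extra_features.csv")
#     ds = RARP_DatasetFolder_ExtraData("folds/fold_0/train", loader=np.load,
#                                       Extra_Data=extra, extensions=(".npy",))
#     (img, extra_feats), label = ds[0]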

class RARP_DatasetFolder_DobleTransform(torchvision.datasets.DatasetFolder):
    def __init__(self, 
                 root: str, 
                 loader: Callable[[str], Any], 
                 extensions: Tuple[str, ...] | None = None, 
                 transformT1: Callable[..., Any] | None = None, 
                 transformT2: Callable[..., Any] | None = None, 
                 target_transform: Callable[..., Any] | None = None, 
                 is_valid_file: Callable[[str], bool] | None = None,
                 passOriginal: Callable[..., Any] | None = None, 
                ) -> None:
        super().__init__(root, loader, extensions, None, target_transform, is_valid_file)
        self.T1 = transformT1
        self.T2 = transformT2 if transformT2 is not None else transformT1
        self.PassOriginalImage = passOriginal
        
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]

        sample = self.loader(path)
        # Fall back to the untransformed sample so neither view is undefined
        # when a transform is missing.
        sample1 = self.T1(sample) if self.T1 is not None else sample
        sample2 = self.T2(sample) if self.T2 is not None else sample
        if self.target_transform is not None:
            target = self.target_transform(target)

        if self.PassOriginalImage is not None:
            return (sample1, sample2, self.PassOriginalImage(sample)), target
        else:
            return (sample1, sample2), target
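
# Example (a sketch; t1 and t2 are assumed transform pipelines): the dataset
# yields two independently augmented views of each image, as used for
# contrastive / self-distillation training.
#     ds = RARP_DatasetFolder_DobleTransform("folds/fold_0/train", loader=np.load,
#                                            extensions=(".npy",),
#                                            transformT1=t1, transformT2=t2)
#     (view1, view2), label = ds[0]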
            
 
class RARP_PreprocessCreator():
    def removeBlackBorder(self, image):
        image = np.array(image)
        image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        crop = image[y : y + h, x : x + w]
        return crop
    
    def _crop(self, image: np.ndarray, x: tuple, y: tuple):
        return image[y[0]:y[1], x[0]:x[1], :]

    def _sliceImage(self, img: np.ndarray, size: float):
        h, w, _ = img.shape
        ul = self._crop(img, (0, math.trunc(w*size)), (0, math.trunc(h*size)))
        ur = self._crop(img, (math.trunc(w*size), w), (0, math.trunc(h*size)))
        dl = self._crop(img, (0, math.trunc(w*size)), (math.trunc(h*size), h))
        dr = self._crop(img, (math.trunc(w*size), w), (math.trunc(h*size), h))
        
        return [ul, ur, dl, dr]
    
    def __init__(self, 
                 RootPath = "", 
                 extension="tiff", 
                 FoldSeed=505, 
                 createFile=False, 
                 SavePath="", 
                 Fold=False, 
                 preResize=None, 
                 RGBGama=False, 
                 removeBlackBar=True,
                 SegImage:float=None) -> None:
        super().__init__()

        self.root = Path(RootPath)
        self.mean = 0.0
        self.std = 0.0
        self.test_DataSet = []
        self.val_DataSet = []
        self.train_DataSet = []

        self.nNVB = []
        self.yNVB = []

        lista = []

        dist = {
            'test':  [5, 2], #[4, 3], [10, 4],
            'val':   [8, 3], #[4, 3], [10, 4],
            'train': [28, 10] #[8, 6], [21, 7] 
        }
        
        sPath = Path(SavePath + f"_seed_{FoldSeed}")
        dumpImgs = sPath/"dump"
        dumpImgs.mkdir(parents=True, exist_ok=True)
        for i, file in enumerate(self.root.glob(f"**/*.{extension}")):
            tempImg = cv2.imread(str(file), cv2.IMREAD_COLOR)
            if removeBlackBar:
                tempImg = self.removeBlackBorder(tempImg)
            if RGBGama:
                tempImg = cv2.cvtColor(tempImg, cv2.COLOR_BGR2RGB)
            if preResize is not None:
                tempImg = cv2.resize(tempImg, preResize, interpolation=cv2.INTER_AREA)  # resize to the target resolution

            lista.append(np.mean(tempImg, axis=tuple(range(tempImg.ndim-1))))
            
            if not (sPath/"train").exists():
                imgPath = dumpImgs/f"Img{i}.npy"
                np.save(imgPath, tempImg)
                #parImgClass = (tempImg.astype(float), 0 if file.parent.name == "NOT_NVB" else 1)
                parImgClass = (imgPath, 0 if file.parent.name == "NOT_NVB" else 1)

                if file.parent.name == "NOT_NVB":
                    self.nNVB.append(parImgClass)
                else:
                    self.yNVB.append(parImgClass)
                    
                if SegImage is not None:
                    for j, newImg in enumerate(self._sliceImage(tempImg, SegImage)):
                        imgPath = dumpImgs/f"Img{i}-{j}.npy"
                        np.save(imgPath, newImg)
                        #parImgClass = (tempImg.astype(float), 0 if file.parent.name == "NOT_NVB" else 1)
                        parImgClass = (imgPath, 0 if file.parent.name == "NOT_NVB" else 1)

                        if file.parent.name == "NOT_NVB":
                            self.nNVB.append(parImgClass)
                        else:
                            self.yNVB.append(parImgClass)

        temp = np.asarray(lista)
        self.mean = list(np.mean(temp, axis=tuple(range(temp.ndim-1))))
        self.std = list(np.std(temp, axis=tuple(range(temp.ndim-1))))
        
        if (sPath / "train").exists():
            self.test_DataSet = None
            self.train_DataSet = None
            self.val_DataSet = None
            self.nNVB = None
            self.yNVB = None
            self.SaveFold = sPath
            return

        if FoldSeed is not None:
            np.random.seed(FoldSeed)

        # Shuffle each class list so the sequential train/val/test slices
        # below draw random samples.
        shuffle_nNVB = np.random.permutation(len(self.nNVB))
        shuffle_NVB = np.random.permutation(len(self.yNVB))
        self.nNVB = [self.nNVB[i] for i in shuffle_nNVB]
        self.yNVB = [self.yNVB[i] for i in shuffle_NVB]

        pasoN, pasoY = (dist['train'][0], dist['train'][1])

        self.train_DataSet = self.nNVB[:pasoN] + self.yNVB[:pasoY]

        nextPasoN, nextPasoY = (pasoN+dist['val'][0], pasoY+dist['val'][1])

        self.val_DataSet = self.nNVB[pasoN:nextPasoN] + self.yNVB[pasoY:nextPasoY]

        self.test_DataSet = self.nNVB[nextPasoN:] + self.yNVB[nextPasoY:]

        if createFile:
            if len(SavePath) == 0:
                raise ValueError("If createFile is True, SavePath must have a value")
            
            n=0
            for img, lb in self.train_DataSet:
                trainPath = sPath / "train" / ("NOT_NVB" if lb == 0 else "NVB") / f"img{n}"
                trainPath.parent.mkdir(parents=True, exist_ok=True)
                np.save(trainPath, defs.load_file(img))
                n += 1

            n=0 if not Fold else n
            for img, lb in self.val_DataSet:
                trainPath = sPath / ("val" if not Fold else "train") / ("NOT_NVB" if lb == 0 else "NVB") / f"img{n}"
                trainPath.parent.mkdir(parents=True, exist_ok=True)
                np.save(trainPath, defs.load_file(img))
                n += 1

            n=0 if not Fold else n
            for img, lb in self.test_DataSet:
                trainPath = sPath / ("test" if not Fold else "train") / ("NOT_NVB" if lb == 0 else "NVB") / f"img{n}"
                trainPath.parent.mkdir(parents=True, exist_ok=True)
                np.save(trainPath, defs.load_file(img))
                n += 1

            self.test_DataSet = None
            self.train_DataSet = None
            self.val_DataSet = None
            self.nNVB = None
            self.yNVB = None
            self.SaveFold = sPath
            
            for f in dumpImgs.iterdir():
                f.unlink()
            dumpImgs.rmdir()
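
# Example (a sketch; the paths are hypothetical):
#     prep = RARP_PreprocessCreator(RootPath="data/raw", createFile=True,
#                                   SavePath="data/split", preResize=(512, 512))
#     # prep.SaveFold then points at the folder holding train/val/test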

class DataloaderToDataModule(LightningDataModule):
    """Converts a set of dataloaders into a lightning datamodule.

    Args:
        train_dataloader: Training dataloader
        val_dataloaders: Validation dataloader(s)
        test_dataloaders: Test dataloader(s)
    """

    def __init__(
        self,
        train_dataloader: DataLoader,
        val_dataloaders: Union[DataLoader, Sequence[DataLoader]],
    ) -> None:
        super().__init__()
        self._train_dataloader = train_dataloader
        self._val_dataloaders = val_dataloaders

    def train_dataloader(self) -> DataLoader:
        """Return training dataloader."""
        return self._train_dataloader

    def val_dataloader(self) -> Union[DataLoader, Sequence[DataLoader]]:
        """Return validation dataloader(s)."""
        return self._val_dataloaders
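
# Example (a sketch; train_dl, val_dl, model and trainer are assumed to exist):
#     dm = DataloaderToDataModule(train_dl, val_dl)
#     trainer.fit(model, datamodule=dm)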

class RARP_DataModule(LightningDataModule):
    def __init__(self, 
                 KFold:bool=False, 
                 num_fold:int=5, 
                 shuffle: bool = False,
                 stratified: bool = False,
                 train_dataloader: DataLoader = None,
                 val_dataloader: DataLoader = None,
                 test_dataloader: DataLoader = None,
                 transforms = []) -> None:
        super().__init__()

        self.fold_index = 0
        self.splits = None
        self.Kfold_split = KFold
        self.num_folds = num_fold
        self.Folds = []
        self.shuffle = shuffle
        self.stratified = stratified
        self.label_extractor = lambda batch: batch[1]

        self.trainDL = train_dataloader
        self.valDL = val_dataloader 
        self.testDL = test_dataloader
        self.dataloader_settings = None
        self.transforms = transforms

        tempFoldsOrder = list(range(num_fold))
        for _ in range(num_fold):
            self.Folds.append(tempFoldsOrder)
            tempFoldsOrder = tempFoldsOrder[1:] + tempFoldsOrder[:1]

    def get_Current_Folds(self) -> list:
        return self.Folds[self.fold_index]
    
    def get_labels(self, dataloader: DataLoader) -> Optional[List]:
        """Try to extract the training labels (for classification problems) from the underlying training dataset."""
        # Try to extract labels from the dataset through labels attribute
        if hasattr(dataloader.dataset, "labels"):
            return dataloader.dataset.labels.tolist()

        # Else iterate and try to extract
        try:
            return torch.cat([self.label_extractor(batch) for batch in dataloader], dim=0).tolist()
        except Exception:
            return None
            
    def _split(self, a, n):
        k, m = divmod(len(a), n)
        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
    
    def split_folds(self):
        if self.splits is None:
            self.splits = []
            self.FoldDataset = self.trainDL.dataset

            NO_NVB = []
            NVB = []
            for pos, label in enumerate(self.FoldDataset.targets):
                if label==0:
                    NO_NVB.append(pos) 
                else:
                    NVB.append(pos)

            if self.shuffle:
                np.random.shuffle(NO_NVB)
                np.random.shuffle(NVB)

            NO_NVB_Folds = list(self._split(NO_NVB, self.num_folds))
            NVB_Folds = list(self._split(NVB, self.num_folds))

            # Cumulative fold boundaries: the first 60% of the folds go to
            # train, the next 20% to val and the remainder to test.
            setst = [
                math.trunc(0.60 * self.num_folds), math.trunc(0.80 * self.num_folds), self.num_folds
            ]

            # Accumulate each subset separately so the grouping stays correct
            # for any number of folds, not only 5.
            train_idx, val_idx, test_idx = [], [], []
            for pos, index in enumerate(self.Folds[self.fold_index]):
                foldSamples = NO_NVB_Folds[index] + NVB_Folds[index]
                if pos < setst[0]:
                    train_idx += foldSamples
                elif pos < setst[1]:
                    val_idx += foldSamples
                else:
                    test_idx += foldSamples
            self.splits = [train_idx, val_idx, test_idx]

        
        
    def setup_folds(self):
        if self.splits is None:
            labels = None
            if self.stratified:
                labels = self.get_labels(self.trainDL)
                if labels is None:
                    raise ValueError(
                        "Tried to extract labels for stratified K folds but failed."
                        " Make sure that the dataset of your train dataloader either"
                        " has an attribute `labels` or that `label_extractor` attribute"
                        " is initialized correctly"
                    )
                splitter = StratifiedKFold(self.num_folds, shuffle=self.shuffle)
            else:
                splitter = KFold(self.num_folds, shuffle=self.shuffle)

            self.FoldDataset = self.trainDL.dataset

            self.splits = [split for split in splitter.split(range(len(self.FoldDataset)), y=labels)]

    def train_dataloader(self) -> DataLoader:
        if not self.Kfold_split:
            return self.trainDL
        else:
            self.split_folds()
            train_fold = RARP_Dataset(self.FoldDataset, self.splits[0], self.transforms[0])
            return DataLoader(train_fold, shuffle=True, **self.dataloader_setting)
        
    def val_dataloader(self) -> DataLoader:
        if not self.Kfold_split:
            return self.valDL
        else:
            self.split_folds()
            val_fold = RARP_Dataset(self.FoldDataset, self.splits[1], self.transforms[1])
            return DataLoader(val_fold, shuffle=False, **self.dataloader_setting)
        
    def test_dataloader(self) -> DataLoader:
        if not self.Kfold_split:
            return self.testDL
        else:
            self.split_folds()
            test_fold = RARP_Dataset(self.FoldDataset, self.splits[2], self.transforms[2])
            return DataLoader(test_fold, shuffle=False, **self.dataloader_setting)
    
    @property
    def dataloader_setting(self) -> dict:
        """Return the settings of the train dataloader."""
        if self.dataloader_settings is None:
            orig_dl = self.trainDL
            self.dataloader_settings = {
                "batch_size": orig_dl.batch_size,
                "num_workers": orig_dl.num_workers,
                "collate_fn": orig_dl.collate_fn,
                "pin_memory": orig_dl.pin_memory,
                "drop_last": orig_dl.drop_last,
                "timeout": orig_dl.timeout,
                "worker_init_fn": orig_dl.worker_init_fn,
                "prefetch_factor": orig_dl.prefetch_factor,
                "persistent_workers": orig_dl.persistent_workers,
            }
        return self.dataloader_settings
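
# Example k-fold loop (a sketch; the dataloaders, transforms, model and
# trainer are assumed to exist):
#     dm = RARP_DataModule(KFold=True, num_fold=5, train_dataloader=train_dl,
#                          transforms=[train_tf, eval_tf, eval_tf])
#     for fold in range(dm.num_folds):
#         dm.fold_index = fold
#         dm.splits = None  # force a fresh split for the rotated fold order
#         trainer.fit(model, datamodule=dm)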

class RARP_Dataset(Dataset):
    def __init__(self, RARP_dataset, indices, transform=None) -> None:
        super().__init__()
        self.transform = transform
        self.samples = Subset(RARP_dataset, indices)
        # Read labels from the underlying dataset instead of iterating the
        # Subset, which would load (and transform) every sample.
        self.targets = [RARP_dataset.targets[i] for i in indices]
        self.classes = ["NO_NVB", "NVB"]  # TODO: derive the classes from the underlying dataset

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, index):
        img, label = self.samples[index]

        if self.transform is not None:
            img = self.transform(img)

        return (img, label)
    
class RARP_ChannelSwap(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
    def forward(self, img):
        # It is assumed that the input image is in RGB format.
        Channels = [0, 1, 2]
        np.random.shuffle(Channels)
        
        return img[Channels]
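    # e.g. a (3, H, W) RGB tensor may come back with its channels permuted to
    # (B, R, G); labels are unchanged, so this acts as a cheap colour augmentation.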
    
class RARP_Invert(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        
    def forward(self, img):
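        # Assumes pixel values in the [0, 255] range; apply before normalization.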
        return 255 - img
    
class RARP_DINO_Augmentation():
    def __init__(
        self, 
        GloblaCropsScale=(0.4, 1), 
        LocalCropsScale=(0.05, 0.4), 
        NumLocalCrops:int=8, 
        Size:int=224, 
        device = None, 
        mean:float = None, 
        std:float = None,
        Tranform_0 = None
    ) -> None:
        self.NumLocal_Crops= NumLocalCrops
       
        self.globalCrop1 = torch.nn.Sequential(
            torchvision.transforms.RandomResizedCrop(
                Size, 
                scale=GloblaCropsScale, 
                antialias=True,
                interpolation=torchvision.transforms.InterpolationMode.BICUBIC
            ),
            torchvision.transforms.RandomHorizontalFlip(0.5),
            torchvision.transforms.RandomErasing(0.2, value="random"),
            #torchvision.transforms.RandomApply([
            #    RARP_ChannelSwap()
            #]),
            torchvision.transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2)),
            torchvision.transforms.Normalize(mean, std)
        ).to(device)
        
        self.globalCrop2 = torch.nn.Sequential(
            torchvision.transforms.RandomResizedCrop(
                Size, 
                scale=GloblaCropsScale,
                antialias=True,
                interpolation=torchvision.transforms.InterpolationMode.BICUBIC
            ),
            torchvision.transforms.RandomHorizontalFlip(0.5),
            torchvision.transforms.RandomErasing(0.8, value="random"),
            #torchvision.transforms.RandomApply([
            #    RARP_ChannelSwap()
            #]),
            torchvision.transforms.RandomApply([
                torchvision.transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2))
            ], 0.1),
            torchvision.transforms.RandomApply([
                RARP_Invert()
            ], 0.2),
            torchvision.transforms.Normalize(mean, std)
        ).to(device)
        
        self.local = torch.nn.Sequential(
            torchvision.transforms.RandomResizedCrop(
                Size, 
                scale=LocalCropsScale,
                antialias=True, 
                interpolation=torchvision.transforms.InterpolationMode.BICUBIC
            ),
            torchvision.transforms.RandomHorizontalFlip(0.5),
            torchvision.transforms.RandomErasing(0.5, value="random"),
            #torchvision.transforms.RandomApply([
            #    RARP_ChannelSwap()
            #]),
            torchvision.transforms.RandomApply([
                torchvision.transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2))
            ], 0.5),
            torchvision.transforms.Normalize(mean, std)
        ).to(device)
        
        InitResize = (256,256)        
        self.classification = torch.nn.Sequential(
            torchvision.transforms.Resize(
                InitResize, 
                antialias=True, 
                interpolation=torchvision.transforms.InterpolationMode.BICUBIC
            ),
            torchvision.transforms.CenterCrop(224),    
            torchvision.transforms.RandomAffine(
                degrees=(-5, 5), scale=(0.9, 1.1), fill=5
            ),
            torchvision.transforms.RandomHorizontalFlip(1.0),
            torchvision.transforms.Normalize(mean, std),
        ).to(device) if Tranform_0 is None else Tranform_0
                
    def __call__(self, img):
        all_Crops = []
        all_Crops.append(self.classification(img))
        all_Crops.append(self.globalCrop1(img))
        all_Crops.append(self.globalCrop2(img))
        
        all_Crops.extend([self.local(img) for _ in range(self.NumLocal_Crops)])
        
        return all_Crops
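
# Example (a sketch): on a (3, H, W) float tensor the augmentation returns
# 3 + NumLocalCrops crops, ordered [classification, global1, global2, locals...]:
#     aug = RARP_DINO_Augmentation(mean=mean, std=std, device="cpu")
#     crops = aug(img)  # len(crops) == 11 with the default NumLocalCrops=8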

class RARP_DataSetType(Enum):
    train = 0
    val = 1
    test = 2

class RandomVideoDataset(torch.utils.data.IterableDataset):
    def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
        super().__init__()

        self.samples = get_samples(root)

        # Allow for temporal jittering
        if epoch_size is None:
            epoch_size = len(self.samples)
        self.epoch_size = epoch_size

        self.clip_len = clip_len
        self.frame_transform = frame_transform
        self.video_transform = video_transform

    def __iter__(self):
        for _ in range(self.epoch_size):
            # Get a random sample; np.random.choice cannot draw from a list of
            # tuples, so pick a random index instead
            path, target = self.samples[np.random.randint(len(self.samples))]
            # Get video object
            vid = torchvision.io.VideoReader(path, "video")
            metadata = vid.get_metadata()
            video_frames = []  # video frame buffer

            # Seek and return frames
            max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
            start = np.random.uniform(0., max_seek)
            for frame in itertools.islice(vid.seek(start), self.clip_len):
                video_frames.append(self.frame_transform(frame['data']))
                current_pts = frame['pts']
            # Stack it into a tensor
            video = torch.stack(video_frames, 0)
            if self.video_transform:
                video = self.video_transform(video)
            output = {
                'path': path,
                'video': video,
                'target': target,
                'start': start,
                'end': current_pts}
            yield output
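
# Example (a sketch; the video folder is hypothetical). frame_transform is
# applied to every decoded frame and is therefore required:
#     ds = RandomVideoDataset("data/videos", clip_len=16,
#                             frame_transform=torchvision.transforms.ConvertImageDtype(torch.float32))
#     loader = DataLoader(ds, batch_size=4)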

class VideoDataset(Dataset):
    def __init__(self, time_depth, mean, std, transform=None) -> None:
        super().__init__()
        self.time_depth = time_depth
        self.mean = mean
        self.std = std
        self.transform = transform

class RARPDataset(Dataset):
    def __init__(self, path_RARP_dataset:Path, path_RARP_Folds:Path=None, split=0, DataSetType:RARP_DataSetType=RARP_DataSetType.train, transform=None) -> None:
        super().__init__()
        self.samples = pd.read_csv(path_RARP_dataset, usecols=["id", "label", "class", "path"])
        self.classes = self.samples["class"].unique().tolist()
        self.targets = self.samples["label"].to_list()
        self.IDsplit = None
        self.transform = transform
                
        if path_RARP_Folds is not None:
            if split is None:
                raise ValueError("The split index is required to select a fold")
            
            arrFolds = np.load(path_RARP_Folds, allow_pickle=True)
            self.IDsplit = np.array_split(arrFolds, len(arrFolds)//3)[split][DataSetType.value]
            
            self.samples = self.samples.loc[self.samples["id"].isin(self.IDsplit)]
            self.targets = self.samples["label"].to_list()
            
    def __len__(self):
        return self.samples.shape[0]
    
    def __getitem__(self, index):
        ID = index
        if self.IDsplit is not None:
            ID = self.IDsplit[index]
            
        rs = self.samples.loc[self.samples["id"] == ID]
            
        img = defs.load_file_tensor(rs["path"].item())
        label = rs["label"].item()
        
        if self.transform is not None:
            img = self.transform(img)
            
        return img, label
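
# Example (a sketch; the paths follow the layout written by RARP_DatasetCreator
# with SavePath="data/rarp" and the default FoldSeed=505):
#     ds = RARPDataset(Path("data/rarp_seed_505/dump/dataset.csv"),
#                      Path("data/rarp_seed_505/dump/Folds.npy"),
#                      split=0, DataSetType=RARP_DataSetType.train)
#     img, label = ds[0]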
