import os
import torch
from torchvision.datasets.folder import make_dataset
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, Subset
from lightning.pytorch import LightningDataModule
import lightning as L
import numpy as np
import cv2
from typing import Optional, Sequence, Tuple, Union, List, Any, Callable
from sklearn.model_selection import KFold, StratifiedKFold
import torchvision
import torchvision.transforms as T
import json
import math
import defs
import csv
import pandas as pd
from enum import Enum
import shutil
from ultralytics import YOLO
import itertools
import albumentations as A
from albumentations.pytorch import ToTensorV2

class NVB_Classes(Enum):
    """4-way neurovascular-bundle (NVB) label space.

    Member names match the dataset's class-folder names (labels are resolved
    via NVB_Classes[folder_name] in RARP_DatasetCreator); the R_/L_/RL_
    prefixes presumably denote right/left/both bundles — inferred from names.
    """
    NOT_NVB = 0
    R_NVB = 1
    L_NVB = 2
    RL_NVB = 3
    
class NVB_Binary(Enum):
    """Binary NVB label space; member names match the dataset's class-folder
    names (labels are resolved via NVB_Binary[folder_name])."""
    NOT_NVB = 0
    NVB = 1

def _find_classes(dir):
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

def get_samples(root, extensions=(".mp4", ".avi")):
    """Build (path, class_index) video samples under *root* using torchvision's
    make_dataset with the folder-derived class mapping."""
    class_to_idx = _find_classes(root)[1]
    return make_dataset(root, class_to_idx, extensions=extensions)

class RARP_DatasetCreator():
    """End-to-end dataset builder for the RARP NVB-classification images.

    __init__ scans a folder tree of raw images, preprocesses each one, dumps
    results to <SavePath>_seed_<FoldSeed>/dump/*.npy, indexes them in
    dataset.csv (with per-channel mean/std), and optionally precomputes K-fold
    train/val/test id lists in Folds.npy.  CreateFolds()/CreateClases() then
    materialize a class-folder layout for torchvision-style loaders.
    """
    def _removeBlackBorder(self, image):
        """Crop *image* to the bounding box of its largest non-black contour
        (strips the black frame around endoscopic captures)."""
        image = np.array(image)
        
        # Hue mask: zero pixels whose hue lies in (25, 175) so the threshold
        # below is driven mainly by the border / non-tissue region.
        copyImg = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV)
        h = copyImg[:,:,0]
        mask = np.ones(h.shape, dtype=np.uint8) * 255
        th = (25, 175)
        mask[(h > th[0]) & (h < th[1])] = 0
        copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)
        resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)
            
        image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        # findContours returns 2 values on OpenCV >= 4 and 3 values on OpenCV 3.
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        crop = image[y : y + h, x : x + w]
        return crop
    
    def _crop(self, image:np.ndarray, x:tuple, y:tuple):
        """Return image[y0:y1, x0:x1, :] where x/y are (start, stop) tuples."""
        return image[y[0]:y[1], x[0]:x[1], :]
    
    def _sliceImage(self, img:np.ndarray, size:float):
        """Cut four corner crops (upper-left/right, lower-left/right), each
        spanning *size* fraction of the image width and height."""
        h, w, _ = img.shape
        xw = math.trunc(w*size)
        xh = math.trunc(h*size)
        ul = self._crop(img, (0, xw), (0, xh))
        ur = self._crop(img, (w-xw, w), (0, xh))
        dl = self._crop(img, (0, xw), (h-xh, h))
        dr = self._crop(img, (w-xw, w), (h-xh, h))
        
        return [ul, ur, dl, dr]
    
    def _split(self, a, n):
        """Yield *n* contiguous chunks of *a* whose lengths differ by at most one."""
        k, m = divmod(len(a), n)
        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
    
    def _ROI_Mask(self, input_img:np.ndarray):
        """Segment *input_img* with self.Masked_roi_model and return the masked
        image cropped to the bounding box of the largest predicted region."""
        transform = A.Compose(
            [
                A.Resize(224, 224, interpolation=cv2.INTER_CUBIC),
                A.Normalize(self.mean, self.std),
                ToTensorV2()
            ]
        )
        
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        self.Masked_roi_model.to(device)
        self.Masked_roi_model.eval()
        
        h, w, _ = input_img.shape
        sample = input_img.copy()  # keep the unnormalized original for the final crop
        
        input_img = transform(image=input_img)
        input_img = input_img["image"]
        # repeat with all-1 factors simply adds a leading batch dimension of 1.
        input_img = input_img.repeat(1, 1, 1, 1)
        input_img = input_img.to(device)
        
        with torch.no_grad():
            mask = torch.sigmoid(self.Masked_roi_model(input_img))
            
        # Back to HxWxC numpy at the original resolution, lightly smoothed.
        mask = mask[0].cpu().numpy().transpose((1, 2, 0))
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC)
        mask = cv2.GaussianBlur(mask, (5, 5), 0)
        
        # Binarize at 0.5, then open with an ellipse kernel to drop speckles.
        _, th_mask = cv2.threshold(mask, 0.5, 1, cv2.THRESH_BINARY)
        th_mask = cv2.morphologyEx(th_mask, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)))
        th_mask = th_mask.astype(np.uint8)
        
        # findContours returns 2 values on OpenCV >= 4 and 3 values on OpenCV 3.
        contours = cv2.findContours(th_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        
        roi = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(roi)
        
        # Apply the binary mask, then crop to the ROI bounding box.
        return cv2.bitwise_and(sample, sample, mask=th_mask)[y : y + h, x : x + w]
        
    
    def ROI_Extract_YOLO(self, YoloModel:YOLO, image, threshold=0.75):
        """Return crops of *image* for every YOLO detection with confidence
        >= *threshold* (may be empty if nothing clears the threshold)."""
        listROIs = []
        res = YoloModel(image, stream=True)
        for r in res:
            pred = r.boxes.conf.cpu().numpy()
            for i, box in enumerate(r.boxes.xyxy.cpu().numpy()):
                if pred[i] >= threshold:
                    box = box.astype(int)
                    x, y = (box[0], box[1])
                    xw, yh = (box[2], box[3])
                    listROIs.append(image[y:yh, x:xw])
                        
        return listROIs
    
    def __init__(
        self,
        RootPath = "", 
        extension="tiff", 
        FoldSeed=505, 
        createFile=False, 
        SavePath="", 
        Fold:int=None, 
        preResize=None, 
        RGBGama=False, 
        removeBlackBar=True,
        SegImage:float=None,
        Num_Img_Slices:int = 4,
        SegmentClass:int = None,
        colorSpace:int=None,
        ROI_Yolo:YOLO=None,
        thresholdYolo_Accuracy:float=0.75,
        Num_Labels:int = None,
        ROI_Mask:L.LightningModule=None
    ) -> None:
        """Build the dump + CSV index (first run) or reload statistics from an
        existing dataset.csv, then optionally plan Fold-way splits.

        Args:
            RootPath: folder tree with one subfolder per class of raw images.
            extension: raw image file extension to glob for.
            FoldSeed: seed used in the save-folder name and the fold shuffle.
            createFile: when True, SavePath must be non-empty (validation only;
                actual creation is driven by whether dataset.csv exists).
            SavePath: output prefix; data goes to <SavePath>_seed_<FoldSeed>/.
            Fold: number of folds to plan (None = skip fold planning).
            preResize: optional size tuple passed to cv2.resize.
            RGBGama: convert BGR -> RGB when True.
            removeBlackBar: crop black borders (skipped when ROI_Yolo is used).
            SegImage: corner-slice fraction; adds sliced copies when set.
            Num_Img_Slices: max number of corner slices kept per image.
            SegmentClass: restrict slicing to this label (None = all labels).
            colorSpace: optional cv2 color-conversion code applied to images.
            ROI_Yolo: optional YOLO model; dumps its detections instead of the
                full frame.
            thresholdYolo_Accuracy: confidence threshold for YOLO detections.
            Num_Labels: number of classes for the multiclass layout (None =
                binary NVB / NOT_NVB).
            ROI_Mask: optional segmentation model used later by CreateFolds.
        """
        
        # Accumulators: per-image channel statistics and per-class id lists.
        root = Path(RootPath)
        lMean = []
        lStd = []
        NO_NVB = []
        NVB = []
        self.Num_Labels = None
        
        self.Masked_roi_model = ROI_Mask
        
        if Num_Labels is not None:
            MultiClasses = [[] for _ in range(Num_Labels)]
            self.Num_Labels = Num_Labels
                        
        if createFile:
            if len(SavePath) == 0: 
                raise Exception("If createFile is True, SavePath must have a value")
        
        sPath = Path(SavePath + f"_seed_{FoldSeed}")
        dumpImgs = sPath/"dump"
        dumpImgs.mkdir(parents=True, exist_ok=True)
        self.CVS_File = dumpImgs/"dataset.csv"
        # Creation path runs only when the CSV index does not exist yet.
        if not self.CVS_File.exists():
            with open(self.CVS_File, "x", newline='') as csvfile: 
                writerOBJ = csv.writer(csvfile)
                writerOBJ.writerow(["id", "label", "class", "path", "name", "mean_1", "mean_2", "mean_3", "std_1", "std_2", "std_3"])
                id = 0
                for i, file in enumerate(root.glob(f"**/*.{extension}")):
                    tempImg = cv2.imread(str(file), cv2.IMREAD_COLOR)
                    if removeBlackBar and ROI_Yolo is None:
                        tempImg = self._removeBlackBorder(tempImg)
                    if RGBGama:
                        tempImg = cv2.cvtColor(tempImg, cv2.COLOR_BGR2RGB)
                    if colorSpace is not None:
                        tempImg = cv2.cvtColor(tempImg, colorSpace)
                    if preResize is not None:
                        tempImg = cv2.resize(tempImg, preResize, interpolation=cv2.INTER_AREA)  # resolution change
                    tempImgList = [tempImg]
                    
                    #ROI w/ YOLOv8
                    if ROI_Yolo is not None:
                        tempImgList = self.ROI_Extract_YOLO(ROI_Yolo, tempImg, thresholdYolo_Accuracy)
                    
                    # One CSV row + one .npy dump per (image, detection) pair.
                    for k, tempImg in enumerate(tempImgList):
                        # Class label comes from the parent folder name.
                        idClass = NVB_Binary[file.parent.name].value if Num_Labels is None else NVB_Classes[file.parent.name].value
                        lineaCSV = [id, idClass, file.parent.name, (dumpImgs/f"Img{k}-{i}.npy").absolute(), file.name] 
                        lMean.append(np.mean(tempImg, axis=tuple(range(tempImg.ndim-1))))
                        lStd.append(np.std(tempImg, axis=tuple(range(tempImg.ndim-1))))
                        lineaCSV += np.mean(tempImg, axis=tuple(range(tempImg.ndim-1))).tolist()
                        lineaCSV += np.std(tempImg, axis=tuple(range(tempImg.ndim-1))).tolist()
                        writerOBJ.writerow(lineaCSV)
                        #lista.append (np.mean(tempImg, axis=tuple(range(tempImg.ndim-1))))
                        np.save(dumpImgs/f"Img{k}-{i}.npy", tempImg)
                        if Num_Labels is None:
                            if lineaCSV[1] == 0:
                                NO_NVB.append(id)
                            else:
                                NVB.append(id)
                        else:
                            MultiClasses[idClass].append(id)
                        
                        id += 1
                        
                        # Optionally also dump up to Num_Img_Slices corner crops
                        # (restricted to SegmentClass when given).
                        if SegImage is not None:
                            if (SegmentClass is None) or (lineaCSV[1] == SegmentClass):
                                for j, newImg in enumerate(self._sliceImage(tempImg, SegImage)):
                                    if j == Num_Img_Slices:
                                        break
                                    imgPath = dumpImgs/f"Img{k}-{i}-{j}.npy"
                                    np.save(imgPath, newImg)
                                    idClass = NVB_Binary[file.parent.name].value if Num_Labels is None else NVB_Classes[file.parent.name].value
                                    lineaCSV = [id, idClass, file.parent.name, (dumpImgs/f"Img{k}-{i}-{j}.npy").absolute(), file.name] 
                                    lineaCSV += np.mean(newImg, axis=tuple(range(newImg.ndim-1))).tolist()
                                    lineaCSV += np.std(newImg, axis=tuple(range(newImg.ndim-1))).tolist()
                                    writerOBJ.writerow(lineaCSV)
                                    if Num_Labels is None:
                                        if lineaCSV[1] == 0:
                                            NO_NVB.append(id)
                                        else:
                                            NVB.append(id)
                                    else:
                                        MultiClasses[idClass].append(id)
                                    
                                    id += 1
                    
                # redundant: the with-statement already closes the file
                csvfile.close()
                
            # Dataset-level statistics: mean of the per-image means, and the
            # mean of the per-image stds (note: not a pooled std).
            temp = np.asarray(lMean)
            self.mean = list(np.mean(temp, axis=tuple(range(temp.ndim-1))))
            temp = np.asarray(lStd)
            self.std = list(np.mean(temp, axis=tuple(range(temp.ndim-1))))
        else:
            # Reload path: statistics and per-class id lists come from the CSV.
            data = pd.read_csv(self.CVS_File)
            self.mean = data[["mean_1", "mean_2", "mean_3"]].mean().to_list()
            self.std = data[["std_1", "std_2", "std_3"]].mean().to_list()
            if Num_Labels is None:
                NO_NVB = data.loc[data["label"] == 0]["id"].to_list()
                NVB = data.loc[data["label"] == 1]["id"].to_list()
            else:
                for i in range(Num_Labels):
                    MultiClasses[i] = data.loc[data["label"] == i]["id"].to_list()
            
        if Fold is not None:
            # Splits[i] is the fold order rotated left by i positions, e.g. for
            # Fold=3: [0,1,2], [1,2,0], [2,0,1].
            self.Splits = []      
            splitsToSave = []      
            tempFoldsOrder = list(range(Fold))
            for _ in range(Fold):
                self.Splits.append(tempFoldsOrder)
                tempFoldsOrder = tempFoldsOrder[1:] + tempFoldsOrder[:1]
                
            # Shuffle ids within each class so folds are stratified but random.
            if FoldSeed is not None:
                np.random.seed(FoldSeed)
                if Num_Labels is None:
                    np.random.shuffle(NO_NVB)
                    np.random.shuffle(NVB)
                else:
                    for i in range(Num_Labels):
                        np.random.shuffle(MultiClasses[i])
            if Num_Labels is None:
                NO_NVB_Folds = list (self._split(NO_NVB, Fold))
                NVB_Fols = list(self._split(NVB, Fold))
            else:
                MultiClasses_Folds = []
                for i in range(Num_Labels):
                    MultiClasses_Folds.append(list(self._split(MultiClasses[i], Fold)))
            
            # Number of folds assigned to train/val/test respectively (60/20/20).
            setst = [
                math.trunc(0.60 * Fold), math.trunc(0.20 * Fold), math.trunc(0.20 * Fold)
            ]
            
            # For each rotation, emit three consecutive id lists (train, val,
            # test) into splitsToSave; CreateFolds later reads them in triples.
            for data in self.Splits:
                ultimo = 0 
                for s in setst:
                    tempArry = []
                    for fold in data[ultimo: ultimo + s]:
                        if Num_Labels is None:
                            tempArry += NO_NVB_Folds[fold] + NVB_Fols[fold]
                        else:
                            for i in range(Num_Labels):
                                tempArry += MultiClasses_Folds[i][fold]
                                
                    splitsToSave.append(tempArry)
                    ultimo += s
            self.Folds_File = dumpImgs/"Folds.npy"
            np.save(self.Folds_File, np.asarray(splitsToSave, dtype=object))
            
    def CreateClases(self):
        """Flatten the indexed images into a binary dataset/NO_NVB|NVB class
        layout, write the FOLDS_CREATED marker, and delete the .npy dumps."""
        if not (self.CVS_File.parent/"FOLDS_CREATED").exists():
            root = self.CVS_File.parent.parent
            database = pd.read_csv(self.CVS_File, usecols=["id", "label", "class", "path"])
            
            for _, row in database.iterrows():
                pathOriginal = Path(row["path"])
                pathNuevo = root/"dataset"/("NO_NVB" if row["label"] == 0 else "NVB")/pathOriginal.name
                pathNuevo.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(pathOriginal, pathNuevo)
                
            # Marker file signalling that the layout exists (created empty).
            with open(self.CVS_File.parent/"FOLDS_CREATED", "x") as file:
                file.close()
                
            for f in self.CVS_File.parent.glob("*.npy"):
                f.unlink()
                
    def ExtractROI_YOLO(self, YOLOModel:YOLO, thresholdYolo:float = 0.70):
        """Replace every dumped .npy image (except Folds.npy) with its first
        accepted YOLO detection; runs only after folds exist and only once
        (guarded by the YOLO_ROI_CREATED marker)."""
        if (self.CVS_File.parent/"FOLDS_CREATED").exists() and not (self.CVS_File.parent/"YOLO_ROI_CREATED").exists():
            root = self.CVS_File.parent.parent
            for img in root.glob(f"**/*.npy"):
                if img.name == "Folds.npy":
                    continue
                tempImgList = self.ROI_Extract_YOLO(YOLOModel, np.load(img), thresholdYolo)
                # Only the first accepted detection overwrites the file.
                for k, tempImg in enumerate(tempImgList):
                    if k == 0:
                        np.save(img, tempImg)
                    #else:
                    #    newPath = img.parent/(img.name.replace(".npy", "") +f"-{k}.npy")
                    #    np.save(newPath, tempImg) 
                        
            with open(self.CVS_File.parent/"YOLO_ROI_CREATED", "x") as file:
                file.close()  
                        
    def CreateFolds(self):
        """Materialize the folds planned in __init__: copy every id of each
        saved split into fold_<k>/<subset>/<class>/ folders, write the
        FOLDS_CREATED marker, and delete the .npy dumps."""
        if not (self.CVS_File.parent/"FOLDS_CREATED").exists():
            root = self.CVS_File.parent.parent
            database = pd.read_csv(self.CVS_File, usecols=["id", "label", "class", "path"])
            arrFolds = np.load(self.Folds_File, allow_pickle=True)
            
            # Optionally replace every dump with its model-segmented ROI first.
            if self.Masked_roi_model is not None:
                database.reset_index()
                for _, row in database.iterrows():
                    img = np.load(Path(row["path"]))
                    np.save(Path(row["path"]), self._ROI_Mask(img))
            
            # note: reset_index() returns a new frame, so this call is a no-op.
            database.reset_index()
            # Splits were saved as consecutive [train, val, test] triples.
            for foldNum, splits in enumerate(np.array_split(arrFolds, len(arrFolds)/3)):
                foldPath = root/f"fold_{foldNum}"
                for datasetType, subSet in enumerate(splits):
                    # NOTE(review): RARP_DataSetType is not defined in this
                    # module — presumably provided by an import (defs?); verify.
                    SubfoldPath = foldPath/f"{RARP_DataSetType(datasetType).name}"
                    for _, row in database.loc[database["id"].isin(subSet)].iterrows():
                        PathOri = Path(row["path"])
                        if self.Num_Labels is not None:
                            folderName = f"{row['label']}_" + ("NO_NVB" if row["label"] == 0 else NVB_Classes(row["label"]).name)
                        else:
                            folderName = "NO_NVB" if row["label"] == 0 else "NVB" 
                        PathImg = SubfoldPath/folderName/PathOri.name
                        PathImg.parent.mkdir(parents=True, exist_ok=True)
                        shutil.copy(PathOri, PathImg)
                        
            # Marker file signalling that the fold layout exists.
            with open(self.CVS_File.parent/"FOLDS_CREATED", "x") as file:
                file.close()
                
            for f in self.CVS_File.parent.glob("*.npy"):
                f.unlink()
                                
class RARP_DatasetFolder_ExtraData(torchvision.datasets.DatasetFolder):
    """DatasetFolder variant that pairs each image with a vector of extra
    per-image features looked up by file name in a DataFrame.

    Each item is ((sample, extra_feature_tensor), target).
    """
    def __init__(self, 
                 root: str, 
                 loader: Callable[[str], Any], 
                 Extra_Data: pd.DataFrame,
                 Extra_Data_leg:int = 4,
                 extensions: Tuple[str, ...] | None = None, 
                 transform: Callable[..., Any] | None = None, 
                 target_transform: Callable[..., Any] | None = None, 
                 is_valid_file: Callable[[str], bool] | None = None
                ) -> None:
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)
        # DataFrame indexed by the "raw_name" column, plus how many of the
        # matched row's values to keep as the feature vector.
        self.Extra_Data = Extra_Data
        self.Extra_Data_leg = Extra_Data_leg
        
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ((transformed sample, extra-feature tensor), target)."""
        path, target = self.samples[index]

        # Match this file's name against the "raw_name" column and keep the
        # first Extra_Data_leg values of the matching row(s).
        file_name = Path(path).name
        matched = self.Extra_Data[self.Extra_Data["raw_name"] == file_name]
        extra_features = torch.tensor(matched.values.flatten().tolist()[:self.Extra_Data_leg])

        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (sample, extra_features), target
    
class RARP_DatasetFolder_RoiExtractor(torchvision.datasets.DatasetFolder):
    """DatasetFolder for ROI/keypoint training data: each image has a sibling
    .json annotation (labelme-style) whose first shape holds the points.

    With create_mask=False, __getitem__ returns (transformed image, normalized
    keypoints); with create_mask=True it returns (transformed image, binary
    mask rasterized from a Catmull-Rom spline through the points).
    """
    def __init__(
        self, 
        root, 
        loader, 
        extensions = None, 
        transform = None, 
        target_transform = None, 
        is_valid_file = None,
        create_mask:bool = False
    ):
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)
        
        # When True, __getitem__ yields a filled-polygon mask instead of keypoints.
        self.create_mask = create_mask
        
    def _removeBlackBorder_mask(self, image:np.ndarray, input_mask:np.ndarray):
        """Crop *image* and *input_mask* with the same bounding box of the
        largest non-black contour (same approach as RARP_DatasetCreator)."""
        #image = np.array(image)
        
        # Hue mask: zero pixels with hue in (25, 175) so thresholding keys on
        # the border region rather than tissue.
        copyImg = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV)
        h = copyImg[:,:,0]
        mask = np.ones(h.shape, dtype=np.uint8) * 255
        th = (25, 175)
        mask[(h > th[0]) & (h < th[1])] = 0
        copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)
        resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)
            
        image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        # findContours returns 2 values on OpenCV >= 4 and 3 values on OpenCV 3.
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        
        return image[y : y + h, x : x + w], input_mask[y : y + h, x : x + w]
    
    def _catmull_rom_spline(self, P0, P1, P2, P3, n_points=20):
        """Sample *n_points* (x, y) points on the Catmull-Rom segment between
        P1 and P2 (P0 and P3 are the neighbouring control points)."""
        points = []
        for t in np.linspace(0, 1, n_points):
            # Catmull-Rom formula
            t2 = t * t
            t3 = t2 * t
            x = 0.5 * ((2 * P1[0]) +
                    (-P0[0] + P2[0]) * t +
                    (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t2 +
                    (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t3)
            
            y = 0.5 * ((2 * P1[1]) +
                    (-P0[1] + P2[1]) * t +
                    (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t2 +
                    (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t3)
            
            points.append((x, y))
            
        return points

    def _catmull_rom_closed_loop(self, points, n_points=20):
        """Run a Catmull-Rom spline through *points* treated as a closed loop
        (indices wrap around) and return the sampled perimeter as an array."""
        spline_points = []
        n = len(points)
        
        for i in range(n):
            P0 = points[(i - 1) % n]
            P1 = points[i]
            P2 = points[(i + 1) % n]
            P3 = points[(i + 2) % n]
            spline_points += self._catmull_rom_spline(P0, P1, P2, P3, n_points)
            
        return np.array(spline_points)
    
    def _create_mask_from_contour(self, spline_points:np.ndarray, mask_size:Tuple = (0, 0)):
        """Rasterize the closed contour into a 0/1 uint8 mask of *mask_size*."""
        smooth_curve_int = np.round(spline_points).astype(np.int32)
        mask = np.zeros(mask_size, dtype=np.uint8)
        
        return cv2.fillPoly(mask, [smooth_curve_int], 1)
        
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return (image, keypoints) or (image, mask) depending on create_mask.

        The annotation file sits next to the image with the same stem and a
        .json suffix; keypoints come from data["shapes"][0]["points"].
        """
        path, _ = self.samples[index]
        
        pth = Path(path)
        newPth = pth.parent / f"{pth.name.split('.')[0]}.json"
        
        # NOTE(review): file handle is never closed explicitly.
        data = json.load(open(newPth))
        kpts = data["shapes"][0]["points"]
        
        img = self.loader(path) 
        if self.transform is not None:
            if not self.create_mask:
                # transform is expected to accept albumentations-style
                # image=/keypoints= targets and return a dict.
                sample = self.transform(image=img, keypoints=kpts)
                
                img = sample["image"]
                kpts = torch.tensor(sample["keypoints"])
                _, h, w = img.shape
                # NOTE(review): (x, y) keypoints are divided by [h, w] — verify
                # that the intended normalization order isn't (w, h).
                kpts = kpts / torch.tensor([h, w])
            else:
                # Build a filled mask from the smoothed closed perimeter, crop
                # both image and mask to the non-black region, then transform.
                h, w, _ = img.shape
                smood_perimeter = self._catmull_rom_closed_loop(kpts, n_points=15)
                roi_mask = self._create_mask_from_contour(smood_perimeter, (h, w))
                
                crop_img, crop_mask = self._removeBlackBorder_mask(img, roi_mask)
                crop_mask = crop_mask.astype(np.float32)
                
                sample = self.transform(image=crop_img, mask=crop_mask)
                img = sample["image"]
                kpts = sample["mask"]
            
            
        return img, kpts
        

class RARP_DatasetFolder_ExtraLabel(torchvision.datasets.DatasetFolder):
    """DatasetFolder that swaps the folder-derived target for a multi-label
    vector decoded from a DataFrame's "encode_l1" column ("|"-separated ints,
    matched by file name in "raw_name")."""
    def __init__(
        self, 
        root: str, 
        loader: Callable[[str], Any], 
        Extra_Data: pd.DataFrame,
        extensions: Tuple[str, ...] | None = None, 
        transform: Callable[..., Any] | None = None, 
        target_transform: Callable[..., Any] | None = None, 
        is_valid_file: Callable[[str], bool] | None = None
    ) -> None:
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)
        self.Extra_Data = Extra_Data
        
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return (transformed sample, decoded label tensor); the folder target
        is ignored and target_transform is applied to the decoded labels."""
        path, _ = self.samples[index]

        file_name = Path(path).name
        encoded = str(self.Extra_Data[self.Extra_Data["raw_name"] == file_name]["encode_l1"].values[0])
        Extra_data = torch.tensor([int(part) for part in encoded.split("|")])

        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            Extra_data = self.target_transform(Extra_data)

        return sample, Extra_data

class RARP_DatasetFolder_DobleTransform(torchvision.datasets.DatasetFolder):
    """DatasetFolder that returns two transformed views of each sample
    (contrastive-learning style), optionally plus the original sample.

    __getitem__ returns ((view1, view2), target), or
    ((view1, view2, passOriginal(sample)), target) when passOriginal is set.
    """
    def __init__(self, 
                 root: str, 
                 loader: Callable[[str], Any], 
                 extensions: Tuple[str, ...] | None = None, 
                 transformT1: Callable[..., Any] | None = None, 
                 transformT2: Callable[..., Any] | None = None, 
                 target_transform: Callable[..., Any] | None = None, 
                 is_valid_file: Callable[[str], bool] | None = None,
                 passOriginal: Callable[..., Any] | None = None, 
                ) -> None:
        # transform is applied locally as T1/T2, so the base class gets None.
        super().__init__(root, loader, extensions, None, target_transform, is_valid_file)
        self.T1 = transformT1
        # Fall back to T1 so a single transform still yields two sampled views.
        self.T2 = transformT2 if transformT2 is not None else transformT1
        self.PassOrigianlImage = passOriginal
        
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load one sample and return its two transformed views (and optionally
        the original) together with the class target."""
        path, target = self.samples[index]
                
        sample = self.loader(path) 
        # Fix: fall back to the untransformed sample when a transform is absent
        # (previously sample1/sample2 were unbound, raising NameError at the
        # return statement whenever T1 or T2 was None).
        sample1 = self.T1(sample) if self.T1 is not None else sample
        sample2 = self.T2(sample) if self.T2 is not None else sample

        if self.PassOrigianlImage is not None:
            return (sample1, sample2, self.PassOrigianlImage(sample)), target
        return (sample1, sample2), target
             
class RARP_PreprocessCreator():
    """Older/simpler dataset builder: dumps preprocessed images, computes
    channel statistics, and writes fixed-size train/val/test class folders
    under <SavePath>_seed_<FoldSeed>/ using hard-coded per-class counts."""
    def removeBlackBorder(self, image):
        """Crop *image* to the bounding box of its largest bright contour
        (removes the surrounding black border)."""
        image = np.array(image)
        image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        # findContours returns 2 values on OpenCV >= 4 and 3 values on OpenCV 3.
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        crop = image[y : y + h, x : x + w]
        return crop
    
    def _crop(self, image:np.ndarray, x:tuple, y:tuple):
        """Return image[y0:y1, x0:x1, :] where x/y are (start, stop) tuples."""
        return image[y[0]:y[1], x[0]:x[1], :]

    def _sliceImage(self, img:np.ndarray, size:float):
        """Cut four quadrant crops split at (w*size, h*size).

        Note: unlike RARP_DatasetCreator._sliceImage, the right/bottom crops
        start at the split point, so the four crops tile the image.
        """
        h, w, _ = img.shape
        ul = self._crop(img, (0, math.trunc(w*size)), (0, math.trunc(h*size)))
        ur = self._crop(img, (math.trunc(w*size), w), (0, math.trunc(h*size)))
        dl = self._crop(img, (0, math.trunc(w*size)), (math.trunc(h*size), h))
        dr = self._crop(img, (math.trunc(w*size), w), (math.trunc(h*size), h))
        
        return [ul, ur, dl, dr]
    
    def __init__(self, 
                 RootPath = "", 
                 extension="tiff", 
                 FoldSeed=505, 
                 createFile=False, 
                 SavePath="", 
                 Fold=False, 
                 preResize=None, 
                 RGBGama=False, 
                 removeBlackBar=True,
                 SegImage:float=None) -> None:
        """Scan RootPath for images, preprocess and dump them, then split them
        into train/val/test using the hard-coded `dist` counts.  With
        createFile=True the splits are written out as class folders and the
        temporary dump directory is removed; with Fold=True everything is
        written into train/ instead of separate subsets."""
        super().__init__()

        self.root = Path(RootPath)
        self.mean = 0.0
        self.std = 0.0
        self.test_DataSet = []
        self.val_DataSet = []
        self.train_DataSet = []

        # (npy-path, label) pairs per class: nNVB = NOT_NVB, yNVB = NVB.
        self.nNVB = []
        self.yNVB = []

        lista = []  # per-image channel means for the dataset-level statistics

        # Hard-coded subset sizes as [NOT_NVB count, NVB count].
        dist = {
            'test':  [5, 2], #[4, 3], [10, 4],
            'val':   [8, 3], #[4, 3], [10, 4],
            'train': [28, 10] #[8, 6], [21, 7] 
        }
        
        sPath = Path(SavePath + f"_seed_{FoldSeed}")
        dumpImgs = sPath/"dump"
        dumpImgs.mkdir(parents=True, exist_ok=True)
        for i, file in enumerate(self.root.glob(f"**/*.{extension}")):
            tempImg = cv2.imread(str(file), cv2.IMREAD_COLOR)
            if removeBlackBar:
                tempImg = self.removeBlackBorder(tempImg)
            if RGBGama:
                tempImg = cv2.cvtColor(tempImg, cv2.COLOR_BGR2RGB)
            if preResize is not None:
                tempImg = cv2.resize(tempImg, preResize, interpolation=cv2.INTER_AREA)  # resolution change

            lista.append (np.mean(tempImg, axis=tuple(range(tempImg.ndim-1))))
            
            # Only dump images when the final layout hasn't been created yet.
            if not (sPath/"train").exists():
                imgPath = dumpImgs/f"Img{i}.npy"
                np.save(imgPath, tempImg)
                #parImgClass = (tempImg.astype(float), 0 if file.parent.name == "NOT_NVB" else 1)
                parImgClass = (imgPath, 0 if file.parent.name == "NOT_NVB" else 1)

                if file.parent.name == "NOT_NVB":
                    self.nNVB.append(parImgClass)
                else:
                    self.yNVB.append(parImgClass)
                    
                # Optionally also dump the four quadrant slices of the image.
                if SegImage is not None:
                    for j, newImg in enumerate(self._sliceImage(tempImg, SegImage)):
                        imgPath = dumpImgs/f"Img{i}-{j}.npy"
                        np.save(imgPath, newImg)
                        #parImgClass = (tempImg.astype(float), 0 if file.parent.name == "NOT_NVB" else 1)
                        parImgClass = (imgPath, 0 if file.parent.name == "NOT_NVB" else 1)

                        if file.parent.name == "NOT_NVB":
                            self.nNVB.append(parImgClass)
                        else:
                            self.yNVB.append(parImgClass)

        temp = np.asarray(lista)
        self.mean = list(np.mean(temp, axis=tuple(range(temp.ndim-1))))
        self.std = list(np.std(temp, axis=tuple(range(temp.ndim-1))))
        
        # If the dataset layout already exists, keep only the statistics.
        if (sPath / "train").exists():
            self.test_DataSet = None
            self.train_DataSet = None
            self.val_DataSet = None
            self.nNVB = None
            self.yNVB = None
            self.SaveFold = sPath
            return

        if FoldSeed is not None:
            np.random.seed(FoldSeed)

        # NOTE(review): these shuffled index lists are never applied below —
        # the splits slice self.nNVB/self.yNVB in glob order; confirm intended.
        shuffle_nNVB = list(range(len(self.nNVB)))
        shuffle_NVB = list(range(len(self.yNVB)))

        np.random.shuffle(shuffle_nNVB)
        np.random.shuffle(shuffle_NVB)

        # First dist['train'] items per class go to train, the next
        # dist['val'] to validation, and the remainder to test.
        pasoN, pasoY = (dist['train'][0], dist['train'][1])

        self.train_DataSet = self.nNVB[:pasoN] + self.yNVB[:pasoY]

        nextPasoN, nextPasoY = (pasoN+dist['val'][0], pasoY+dist['val'][1])

        self.val_DataSet = self.nNVB[pasoN:nextPasoN] + self.yNVB[pasoY:nextPasoY]

        self.test_DataSet = self.nNVB[nextPasoN:] + self.yNVB[nextPasoY:]

        if createFile:
            if len(SavePath) == 0: 
                raise Exception("If createFile is True, SavePath must have a value")
            
            n=0
            for img, lb in self.train_DataSet:
                trainPath = sPath / "train" / ("NOT_NVB" if lb == 0 else "NVB") / f"img{n}"
                trainPath.parent.mkdir(parents=True, exist_ok=True)
                np.save(trainPath, defs.load_file(img))
                n += 1

            # With Fold=True, val/test are funneled into train/ as well and the
            # running counter n keeps file names unique.
            n=0 if not Fold else n
            for img, lb in self.val_DataSet:
                trainPath = sPath / ("val" if not Fold else "train") / ("NOT_NVB" if lb == 0 else "NVB") / f"img{n}"
                trainPath.parent.mkdir(parents=True, exist_ok=True)
                np.save(trainPath, defs.load_file(img))
                n += 1

            n=0 if not Fold else n
            for img, lb in self.test_DataSet:
                trainPath = sPath / ("test" if not Fold else "train") / ("NOT_NVB" if lb == 0 else "NVB") / f"img{n}"
                trainPath.parent.mkdir(parents=True, exist_ok=True)
                np.save(trainPath, defs.load_file(img))
                n += 1

            # Release in-memory lists and delete the temporary dump directory.
            self.test_DataSet = None
            self.train_DataSet = None
            self.val_DataSet = None
            self.nNVB = None
            self.yNVB = None
            self.SaveFold = sPath
            
            for f in dumpImgs.iterdir():
                f.unlink()
            dumpImgs.rmdir()

class DataloaderToDataModule(LightningDataModule):
    """Wrap pre-built dataloaders in a LightningDataModule.

    Args:
        train_dataloader: Training dataloader.
        val_dataloaders: Validation dataloader(s).
    """

    def __init__(
        self,
        train_dataloader: DataLoader,
        val_dataloaders: Union[DataLoader, Sequence[DataLoader]],
    ) -> None:
        super().__init__()
        self._train_dataloader, self._val_dataloaders = train_dataloader, val_dataloaders

    def train_dataloader(self) -> DataLoader:
        """Return the wrapped training dataloader."""
        return self._train_dataloader

    def val_dataloader(self) -> Union[DataLoader, Sequence[DataLoader]]:
        """Return the wrapped validation dataloader(s)."""
        return self._val_dataloaders

class RARP_DataModule(LightningDataModule):
    """Lightning data module with optional class-balanced K-fold splitting.

    When ``KFold`` is False the module simply hands back the dataloaders it
    was constructed with.  When True, the train dataloader's dataset is
    re-split per ``fold_index`` into train/val/test parts covering roughly
    60% / 20% / 20% of the folds.
    """

    def __init__(self,
                 KFold: bool = False,
                 num_fold: int = 5,
                 shuffle: bool = False,
                 stratified: bool = False,
                 train_dataloader: DataLoader = None,
                 val_dataloader: DataLoader = None,
                 test_dataloader: DataLoader = None,
                 transforms=None) -> None:
        super().__init__()

        self.fold_index = 0   # which fold rotation is currently active
        self.splits = None    # cached [train, val, test] index lists
        self.Kfold_split = KFold
        self.num_folds = num_fold
        self.Folds = []
        self.shuffle = shuffle
        self.stratified = stratified
        # Default label extractor assumes (inputs, labels) batches.
        self.label_extractor = lambda batch: batch[1]

        self.trainDL = train_dataloader
        self.valDL = val_dataloader
        self.testDL = test_dataloader
        self.dataloader_settings = None
        # `transforms=[]` as a default argument would be shared between
        # instances (mutable default); normalize None to a fresh list.
        self.transforms = [] if transforms is None else transforms

        # Pre-compute every left-rotation of the fold order; fold_index
        # selects which rotation drives the current split.
        order = list(range(num_fold))
        for _ in range(num_fold):
            self.Folds.append(order)
            order = order[1:] + order[:1]

    def get_Current_Folds(self) -> list:
        """Return the fold order used for the active ``fold_index``."""
        return self.Folds[self.fold_index]

    def get_labels(self, dataloader: DataLoader) -> Optional[List]:
        """Best-effort extraction of classification labels from a dataloader.

        Tries the dataset's ``labels`` attribute first, then falls back to
        iterating the dataloader and applying ``label_extractor``.  Returns
        None when neither strategy works.
        """
        if hasattr(dataloader.dataset, "labels"):
            return dataloader.dataset.labels.tolist()

        try:
            return torch.cat([self.label_extractor(batch) for batch in dataloader], dim=0).tolist()
        except Exception:
            return None

    def _split(self, a, n):
        """Yield ``n`` nearly equal contiguous chunks of sequence ``a``."""
        k, m = divmod(len(a), n)
        return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))

    def split_folds(self):
        """Build and cache class-balanced ``[train, val, test]`` index splits.

        Dataset indices are grouped by binary label, chopped into
        ``num_folds`` chunks per class, and assigned by the rotated fold
        order: the first 60% of folds go to train, the next 20% to val and
        the remainder to test.  (The previous implementation appended the
        train split once per val-fold position, producing malformed splits
        whenever ``num_fold != 5``.)
        """
        if self.splits is None:
            self.FoldDataset = self.trainDL.dataset

            no_nvb, nvb = [], []
            for pos, label in enumerate(self.FoldDataset.targets):
                (no_nvb if label == 0 else nvb).append(pos)

            if self.shuffle:
                np.random.shuffle(no_nvb)
                np.random.shuffle(nvb)

            no_nvb_folds = list(self._split(no_nvb, self.num_folds))
            nvb_folds = list(self._split(nvb, self.num_folds))

            train_end = math.trunc(0.60 * self.num_folds)
            val_end = math.trunc(0.80 * self.num_folds)

            train_idx, val_idx, test_idx = [], [], []
            for pos, index in enumerate(self.Folds[self.fold_index]):
                chunk = no_nvb_folds[index] + nvb_folds[index]
                if pos < train_end:
                    train_idx += chunk
                elif pos < val_end:
                    val_idx += chunk
                else:
                    test_idx += chunk

            self.splits = [train_idx, val_idx, test_idx]

    def setup_folds(self):
        """Compute sklearn K-fold (optionally stratified) splits once.

        Raises:
            ValueError: if ``stratified`` is set but labels cannot be
                extracted from the train dataloader.
        """
        if self.splits is None:
            labels = None
            if self.stratified:
                labels = self.get_labels(self.trainDL)
                if labels is None:
                    raise ValueError(
                        "Tried to extract labels for stratified K folds but failed."
                        " Make sure that the dataset of your train dataloader either"
                        " has an attribute `labels` or that `label_extractor` attribute"
                        " is initialized correctly"
                    )
                splitter = StratifiedKFold(self.num_folds, shuffle=self.shuffle)
            else:
                splitter = KFold(self.num_folds, shuffle=self.shuffle)

            self.FoldDataset = self.trainDL.dataset
            self.splits = list(splitter.split(range(len(self.FoldDataset)), y=labels))

    def train_dataloader(self) -> DataLoader:
        """Training dataloader (split 0 of the folds when K-fold is on)."""
        if not self.Kfold_split:
            return self.trainDL
        self.split_folds()
        train_fold = RARP_Dataset(self.FoldDataset, self.splits[0], self.transforms[0])
        return DataLoader(train_fold, shuffle=True, **self.dataloader_setting)

    def val_dataloader(self) -> DataLoader:
        """Validation dataloader (split 1 of the folds when K-fold is on)."""
        if not self.Kfold_split:
            return self.valDL
        self.split_folds()
        val_fold = RARP_Dataset(self.FoldDataset, self.splits[1], self.transforms[1])
        return DataLoader(val_fold, shuffle=False, **self.dataloader_setting)

    def test_dataloader(self) -> DataLoader:
        """Test dataloader (split 2 of the folds when K-fold is on)."""
        if not self.Kfold_split:
            return self.testDL
        self.split_folds()
        test_fold = RARP_Dataset(self.FoldDataset, self.splits[2], self.transforms[2])
        return DataLoader(test_fold, shuffle=False, **self.dataloader_setting)

    @property
    def dataloader_setting(self) -> dict:
        """DataLoader kwargs cloned from the original train dataloader (cached)."""
        if self.dataloader_settings is None:
            orig_dl = self.trainDL
            self.dataloader_settings = {
                "batch_size": orig_dl.batch_size,
                "num_workers": orig_dl.num_workers,
                "collate_fn": orig_dl.collate_fn,
                "pin_memory": orig_dl.pin_memory,
                "drop_last": orig_dl.drop_last,
                "timeout": orig_dl.timeout,
                "worker_init_fn": orig_dl.worker_init_fn,
                "prefetch_factor": orig_dl.prefetch_factor,
                "persistent_workers": orig_dl.persistent_workers,
            }
        return self.dataloader_settings

class RARP_Dataset(Dataset):
    """View over a subset of another dataset, with an optional transform.

    Args:
        RARP_dataset: base dataset yielding ``(image, label)`` pairs.
        indices: positions of the base dataset included in this view.
        transform: optional callable applied to the image in ``__getitem__``.
    """

    def __init__(self, RARP_dataset, indices, transform=None) -> None:
        super().__init__()
        self.transform = transform
        self.samples = Subset(RARP_dataset, indices)

        # Prefer the base dataset's `targets` so we don't load every sample
        # (potentially image I/O) just to read its label.
        base_targets = getattr(RARP_dataset, "targets", None)
        if base_targets is not None:
            positional = list(base_targets)
            self.targets = [positional[i] for i in indices]
        else:
            self.targets = [s[1] for s in self.samples]

        # Derive class names from the base dataset when it exposes them;
        # fall back to the historical hard-coded binary labels.
        self.classes = list(getattr(RARP_dataset, "classes", ["NO_NVB", "NVB"]))

    def __len__(self):
        """Number of samples in the subset."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return the (optionally transformed) sample at ``index``."""
        img, label = self.samples[index]

        if self.transform is not None:
            img = self.transform(img)

        return (img, label)
    
class RARP_ChannelSwap(torch.nn.Module):
    """Randomly permutes the channel (first) dimension of an image tensor."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, img):
        """Return ``img`` with its channels in a random order.

        The input is assumed channel-first (e.g. RGB).  Generalized from the
        hard-coded ``[0, 1, 2]`` so it works for any channel count.
        """
        channels = list(range(img.shape[0]))
        np.random.shuffle(channels)
        return img[channels]
    
class RARP_Invert(torch.nn.Module):
    """Inverts pixel intensities, assuming values lie in the 0-255 range."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, img):
        # Reflect every intensity around the 8-bit range: x -> 255 - x.
        inverted = torch.full_like(img, 255) - img
        return inverted
    
class RARP_DINO_Augmentation():
    """DINO-style multi-crop augmentation.

    Produces, per input image: one classification-style crop, two global
    crops and ``NumLocalCrops`` local crops.  Every pipeline is an
    ``nn.Sequential`` so it can be moved to ``device``.

    Args:
        GloblaCropsScale: RandomResizedCrop scale range for global crops.
        LocalCropsScale: RandomResizedCrop scale range for local crops.
        NumLocalCrops: number of local crops appended per image.
        Size: output side length of the global/local crops.
        device: device the transform modules are moved to.
        mean: normalization mean passed to ``T.Normalize``.
        std: normalization std passed to ``T.Normalize``.
        Tranform_0: optional override for the classification pipeline.
        Init_Resize: initial resize applied before cropping.
    """

    def __init__(
        self,
        GloblaCropsScale=(0.4, 1),
        LocalCropsScale=(0.05, 0.4),
        NumLocalCrops: int = 8,
        Size: int = 224,
        device=None,
        mean: float = None,
        std: float = None,
        Tranform_0=None,
        Init_Resize=(512, 512)
    ) -> None:
        self.NumLocal_Crops = NumLocalCrops

        # First global crop: mild erasing, always-on blur.
        self.globalCrop1 = torch.nn.Sequential(
            T.Resize(Init_Resize, antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomRotation(
                degrees=(-15, 15),
                fill=5
            ),
            T.RandomResizedCrop(
                Size,
                scale=GloblaCropsScale,
                antialias=True,
                interpolation=T.InterpolationMode.BICUBIC
            ),
            T.RandomErasing(0.2, value="random"),
            T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
            T.Normalize(mean, std)
        ).to(device)

        # Second global crop: heavier erasing, channel swap, occasional
        # blur (p=0.1) and color inversion (p=0.2).
        self.globalCrop2 = torch.nn.Sequential(
            T.Resize(Init_Resize, antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomRotation(
                degrees=(-15, 15),
                fill=5
            ),
            T.RandomResizedCrop(
                Size,
                scale=GloblaCropsScale,
                antialias=True,
                interpolation=T.InterpolationMode.BICUBIC
            ),
            T.RandomErasing(0.8, value="random"),
            T.RandomApply([
                RARP_ChannelSwap()
            ]),
            T.RandomApply([
                T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
            ], 0.1),
            T.RandomApply([
                RARP_Invert()
            ], 0.2),
            T.Normalize(mean, std)
        ).to(device)

        # Local crops: small scale range, gentler rotation.
        self.local = torch.nn.Sequential(
            T.Resize(Init_Resize, antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomRotation(
                degrees=(-5, 5),
                fill=5
            ),
            T.RandomResizedCrop(
                Size,
                scale=LocalCropsScale,
                antialias=True,
                interpolation=T.InterpolationMode.BICUBIC
            ),
            T.RandomErasing(0.1, value="random"),
            T.RandomApply([
                RARP_ChannelSwap()
            ], 0.5),
            T.RandomApply([
                T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
            ], 0.5),
            T.Normalize(mean, std)
        ).to(device)

        # Classification crop: center-crop pipeline, unless the caller
        # supplies an override.  (Was `Tranform_0 == None`; identity check
        # against None is the correct idiom.)
        if Tranform_0 is None:
            InitResize = (256, 256)
            self.classification = torch.nn.Sequential(
                T.Resize(
                    InitResize,
                    antialias=True,
                    interpolation=T.InterpolationMode.BICUBIC
                ),
                T.CenterCrop(224),
                T.RandomAffine(
                    degrees=(-5, 5), scale=(0.9, 1.1), fill=5
                ),
                # NOTE(review): p=1.0 flips every image — confirm intended.
                T.RandomHorizontalFlip(1.0),
                T.Normalize(mean, std),
            ).to(device)
        else:
            self.classification = Tranform_0

    def __call__(self, img):
        """Return ``[classification, global1, global2, *local crops]`` of img."""
        all_Crops = [
            self.classification(img),
            self.globalCrop1(img),
            self.globalCrop2(img),
        ]
        all_Crops.extend(self.local(img) for _ in range(self.NumLocal_Crops))
        return all_Crops

class RARP_DataSetType(Enum):
    """Position of each dataset split inside a (train, val, test) fold triple."""
    train = 0
    val = 1
    test = 2

class RandomVideoDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that yields random fixed-length clips from videos.

    Args:
        root: directory whose class subfolders contain video files.
        epoch_size: clips yielded per epoch; defaults to one per video.
        frame_transform: callable applied to each decoded frame tensor.
        video_transform: optional callable applied to the stacked clip.
        clip_len: number of consecutive frames per clip.
    """

    def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
        super().__init__()

        self.samples = get_samples(root)

        # Allow for temporal jittering
        if epoch_size is None:
            epoch_size = len(self.samples)
        self.epoch_size = epoch_size

        self.clip_len = clip_len
        self.frame_transform = frame_transform
        self.video_transform = video_transform

    def __iter__(self):
        for _ in range(self.epoch_size):
            # Pick a random video.  np.random.choice cannot sample from a
            # list of (path, target) tuples (it requires a 1-D array, and
            # would raise ValueError), so draw an index instead.
            path, target = self.samples[np.random.randint(len(self.samples))]
            # Get video object
            vid = torchvision.io.VideoReader(path, "video")
            metadata = vid.get_metadata()
            video_frames = []  # video frame buffer

            # Random temporal jitter: choose a start time that leaves room
            # for clip_len frames before the end of the video.
            max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
            start = np.random.uniform(0., max_seek)
            # Initialize so 'end' is defined even if no frame is decoded
            # (the original left current_pts unbound in that case).
            current_pts = start
            for frame in itertools.islice(vid.seek(start), self.clip_len):
                video_frames.append(self.frame_transform(frame['data']))
                current_pts = frame['pts']
            # Stack it into a (clip_len, C, H, W) tensor
            video = torch.stack(video_frames, 0)
            if self.video_transform:
                video = self.video_transform(video)
            output = {
                'path': path,
                'video': video,
                'target': target,
                'start': start,
                'end': current_pts}
            yield output

class VideoDataset(Dataset):
    """Skeleton video dataset; currently only stores its configuration.

    Args:
        time_depth: number of frames per clip.
        mean: normalization mean.
        std: normalization std.
        transform: optional transform applied to samples.
    """

    def __init__(self, time_depth, mean, std, transform=None) -> None:
        super().__init__()
        # The original constructor silently discarded its arguments; store
        # them so subclasses / later methods can use the configuration.
        self.time_depth = time_depth
        self.mean = mean
        self.std = std
        self.transform = transform

class RARPDataset(Dataset):
    """CSV-backed image dataset, optionally restricted to one fold split.

    Args:
        path_RARP_dataset: CSV with columns ``id``, ``label``, ``class``, ``path``.
        path_RARP_Folds: optional ``.npy`` file holding fold id-lists laid
            out as flat (train, val, test) triples.
        split: which fold triple to use when ``path_RARP_Folds`` is given.
        DataSetType: which split of the triple to select; accepts either a
            ``RARP_DataSetType`` member or its name as a string.
        transform: optional callable applied to each image.
    """

    def __init__(self, path_RARP_dataset: Path, path_RARP_Folds: Path = None, split=0,
                 DataSetType: "RARP_DataSetType" = "train", transform=None) -> None:
        super().__init__()
        self.samples = pd.read_csv(path_RARP_dataset, usecols=["id", "label", "class", "path"])
        self.classes = self.samples["class"].unique().tolist()
        self.targets = self.samples["label"].to_list()
        self.IDsplit = None
        self.transform = transform

        if path_RARP_Folds is not None:
            if split is None:
                raise Exception("Is required the split index to do Folds")

            # The historical default "train" is a plain string; normalize it
            # to the enum so `.value` below works (it used to crash).
            if isinstance(DataSetType, str):
                DataSetType = RARP_DataSetType[DataSetType]

            arrFolds = np.load(path_RARP_Folds, allow_pickle=True)
            # Group the flat array into (train, val, test) triples, pick the
            # requested triple, then the requested split of it.  Integer
            # division replaces the previous float division.
            self.IDsplit = np.array_split(arrFolds, len(arrFolds) // 3)[split][DataSetType.value]

            self.samples = self.samples.loc[self.samples["id"].isin(self.IDsplit)]
            self.targets = self.samples["label"]

    def __len__(self):
        """Number of rows currently visible (after any fold filtering)."""
        return self.samples.shape[0]

    def __getitem__(self, index):
        """Load and return ``(image, label)`` for the sample at ``index``.

        Without folds, ``index`` is assumed to match the ``id`` column
        directly; with folds, the id is looked up in the selected split.
        """
        ID = index
        if self.IDsplit is not None:
            ID = self.IDsplit[index]

        rs = self.samples.loc[self.samples["id"] == ID]

        img = defs.load_file_tensor(rs["path"].item())
        label = rs["label"].item()

        if self.transform is not None:
            img = self.transform(img)

        return img, label
