# RARP_server / Loaders.py
import os
import torch
from torchvision.datasets.folder import make_dataset
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, Subset
from lightning.pytorch import LightningDataModule
import lightning as L
import numpy as np
import cv2
from typing import Optional, Sequence, Tuple, Union, List, Any, Callable
from sklearn.model_selection import KFold, StratifiedKFold
import torchvision
import torchvision.transforms as T
import json
import math
import defs
import csv
import pandas as pd
from enum import Enum
import shutil
from ultralytics import YOLO
import itertools
import albumentations as A
from albumentations.pytorch import ToTensorV2
import decord
from torchcodec.decoders import SimpleVideoDecoder


class NVB_Classes(Enum):
    """Four-way neurovascular-bundle (NVB) class labels.

    The integer values double as the class ids written to the dataset CSV
    and as folder-name lookup keys (``NVB_Classes[folder_name].value``).
    """
    NOT_NVB = 0  # no bundle present
    R_NVB = 1    # presumably "right NVB" — name-based, confirm with data owner
    L_NVB = 2    # presumably "left NVB"
    RL_NVB = 3   # presumably both bundles
    
class NVB_Binary(Enum):
    """Binary NVB-present/absent labels; values are the CSV class ids."""
    NOT_NVB = 0
    NVB = 1

def _find_classes(dir):
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

def get_samples(root, extensions=(".mp4", ".avi")):
    """Build torchvision-style (path, class_index) samples for every file
    under *root* matching *extensions*, using folder names as classes."""
    class_to_idx = _find_classes(root)[1]
    return make_dataset(root, class_to_idx, extensions=extensions)

class RARP_DatasetCreator():
    """Builds the on-disk NVB image dataset: a ``dump`` folder of ``.npy``
    images plus a ``dataset.csv`` index, and (optionally) precomputed
    cross-validation fold splits that ``CreateFolds`` later materializes
    into per-fold directory trees.

    NOTE(review): all heavy preprocessing happens inside ``__init__``;
    the public methods only reorganize files it produced.
    """

    def _removeBlackBorder(self, image):
        """Crop *image* to the bounding box of its largest non-black region.

        A hue-band mask first suppresses strongly coloured pixels, then the
        largest external contour of what remains defines the crop rectangle.
        """
        image = np.array(image)
        
        copyImg = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV)
        h = copyImg[:,:,0]
        mask = np.ones(h.shape, dtype=np.uint8) * 255
        th = (25, 175)  # hue band masked out before contour detection
        mask[(h > th[0]) & (h < th[1])] = 0
        copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)
        resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)
            
        image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        # findContours returns 2 values on OpenCV >= 4, 3 on OpenCV 3.
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        crop = image[y : y + h, x : x + w]
        return crop
    
    def _crop(self, image: np.ndarray, x:tuple, y:tuple):
        """Return the sub-image covering columns x[0]:x[1], rows y[0]:y[1]."""
        return image[y[0]:y[1], x[0]:x[1], :]
    
    def _sliceImage(self, img: np.ndarray, size:float):
        """Cut four corner crops, each *size* fraction of width/height.

        Returns [upper-left, upper-right, lower-left, lower-right]; the
        crops overlap whenever size > 0.5.
        """
        h, w, _ = img.shape
        xw = math.trunc(w*size)
        xh = math.trunc(h*size)
        ul = self._crop(img, (0, xw), (0, xh))
        ur = self._crop(img, (w-xw, w), (0, xh))
        dl = self._crop(img, (0, xw), (h-xh, h))
        dr = self._crop(img, (w-xw, w), (h-xh, h))
        
        return [ul, ur, dl, dr]
    
    def _split(self, a, n):
        """Yield *n* contiguous chunks of *a* whose sizes differ by at most one."""
        k, m = divmod(len(a), n)
        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
    
    def _ROI_Mask(self, input_img:np.ndarray):
        """Segment the ROI of *input_img* with ``self.Masked_roi_model`` and
        return the masked image cropped to the largest connected component.

        NOTE(review): relies on ``self.mean`` / ``self.std`` computed in
        ``__init__`` and on ``self.Masked_roi_model`` being non-None.
        """
        transform = A.Compose(
            [
                A.Resize(224, 224, interpolation=cv2.INTER_CUBIC),
                A.Normalize(self.mean, self.std),
                ToTensorV2()
            ]
        )
        
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        self.Masked_roi_model.to(device)
        self.Masked_roi_model.eval()
        
        h, w, _ = input_img.shape
        sample = input_img.copy()  # keep the unnormalized original for masking
        
        input_img = transform(image=input_img)
        input_img = input_img["image"]
        input_img = input_img.repeat(1, 1, 1, 1)  # add a batch dimension
        input_img = input_img.to(device)
        
        with torch.no_grad():
            mask = torch.sigmoid(self.Masked_roi_model(input_img))
            
        # Back to HxWxC numpy at the original resolution, lightly smoothed.
        mask = mask[0].cpu().numpy().transpose((1, 2, 0))
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_CUBIC)
        mask = cv2.GaussianBlur(mask, (5, 5), 0)
        
        # Binarize the probability map and clean it morphologically.
        _, th_mask = cv2.threshold(mask, 0.5, 1, cv2.THRESH_BINARY)
        th_mask = cv2.morphologyEx(th_mask, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)))
        th_mask = th_mask.astype(np.uint8)
        
        contours = cv2.findContours(th_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        
        roi = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(roi)
        
        return cv2.bitwise_and(sample, sample, mask=th_mask)[y : y + h, x : x + w]
        
    
    def ROI_Extract_YOLO(self, YoloModel:YOLO, image, threshold=0.75):
        """Return crops of *image* for every YOLO detection whose confidence
        is at least *threshold*; may be an empty list."""
        listROIs = []
        res = YoloModel(image, stream=True)
        for r in res:
            pred = r.boxes.conf.cpu().numpy()
            for i, box in enumerate(r.boxes.xyxy.cpu().numpy()):
                if pred[i] >= threshold:
                    box = box.astype(int)
                    x, y = (box[0], box[1])
                    xw, yh = (box[2], box[3])
                    listROIs.append(image[y:yh, x:xw])
                        
        return listROIs
    
    def __init__(
        self,
        RootPath = "", 
        extension="tiff", 
        FoldSeed=505, 
        createFile=False, 
        SavePath="", 
        Fold:int=None, 
        preResize=None, 
        RGBGama=False, 
        removeBlackBar=True,
        SegImage:float=None,
        Num_Img_Slices:int = 4,
        SegmentClass:int = None,
        colorSpace:int=None,
        ROI_Yolo:YOLO=None,
        thresholdYolo_Accuracy:float=0.75,
        Num_Labels:int = None,
        ROI_Mask:L.LightningModule=None
    ) -> None:
        """Scan *RootPath* for images, preprocess them into ``.npy`` dumps and
        a CSV index under ``SavePath_seed_<FoldSeed>/dump``, then (optionally)
        precompute stratified fold index splits saved to ``Folds.npy``.

        Args:
            RootPath: root folder scanned recursively for ``*.{extension}``;
                the parent-folder name of each image is its class label.
            FoldSeed: seed for the fold shuffles; also part of the save path.
            createFile: when True, SavePath must be non-empty.
            Fold: number of cross-validation folds; None skips fold creation.
            preResize: optional (w, h) for an early cv2.resize.
            SegImage: fractional corner-crop size; enables slice augmentation.
            Num_Img_Slices: maximum corner slices kept per image.
            SegmentClass: restrict slicing to this class id (None = all).
            Num_Labels: None -> binary labels; otherwise multi-class count.
            ROI_Mask: segmentation model later used by ``CreateFolds``.
        """
        
        root = Path(RootPath)
        lMean = []
        lStd = []
        NO_NVB = []  # ids of negative samples (binary mode)
        NVB = []     # ids of positive samples (binary mode)
        self.Num_Labels = None
        
        self.Masked_roi_model = ROI_Mask
        
        if Num_Labels is not None:
            # Multi-class mode: one id bucket per class.
            MultiClasses = [[] for _ in range(Num_Labels)]
            self.Num_Labels = Num_Labels
                        
        if createFile:
            if len(SavePath) == 0: 
                raise Exception("If createFile is True, SavePath must have a value")
        
        sPath = Path(SavePath + f"_seed_{FoldSeed}")
        dumpImgs = sPath/"dump"
        dumpImgs.mkdir(parents=True, exist_ok=True)
        self.CVS_File = dumpImgs/"dataset.csv"
        # NOTE(review): the CSV is (re)built whenever it is missing, even if
        # createFile is False — confirm this is intended.
        if not self.CVS_File.exists():
            with open(self.CVS_File, "x", newline='') as csvfile: 
                writerOBJ = csv.writer(csvfile)
                writerOBJ.writerow(["id", "label", "class", "path", "name", "mean_1", "mean_2", "mean_3", "std_1", "std_2", "std_3"])
                id = 0  # NOTE(review): shadows the builtin `id`
                for i, file in enumerate(root.glob(f"**/*.{extension}")):
                    tempImg = cv2.imread(str(file), cv2.IMREAD_COLOR)
                    if removeBlackBar and ROI_Yolo is None:
                        tempImg = self._removeBlackBorder(tempImg)
                    if RGBGama:
                        tempImg = cv2.cvtColor(tempImg, cv2.COLOR_BGR2RGB)
                    if colorSpace is not None:
                        tempImg = cv2.cvtColor(tempImg, colorSpace)
                    if preResize is not None:
                        tempImg = cv2.resize(tempImg, preResize, interpolation=cv2.INTER_AREA)# Resolution change
                    tempImgList = [tempImg]
                    
                    #ROI w/ YOLOv8
                    if ROI_Yolo is not None:
                        tempImgList = self.ROI_Extract_YOLO(ROI_Yolo, tempImg, thresholdYolo_Accuracy)
                    
                    # One CSV row + one .npy per ROI (or per whole image).
                    for k, tempImg in enumerate(tempImgList):
                        idClass = NVB_Binary[file.parent.name].value if Num_Labels is None else NVB_Classes[file.parent.name].value
                        lineaCSV = [id, idClass, file.parent.name, (dumpImgs/f"Img{k}-{i}.npy").absolute(), file.name] 
                        lMean.append(np.mean(tempImg, axis=tuple(range(tempImg.ndim-1))))
                        lStd.append(np.std(tempImg, axis=tuple(range(tempImg.ndim-1))))
                        lineaCSV += np.mean(tempImg, axis=tuple(range(tempImg.ndim-1))).tolist()
                        lineaCSV += np.std(tempImg, axis=tuple(range(tempImg.ndim-1))).tolist()
                        writerOBJ.writerow(lineaCSV)
                        np.save(dumpImgs/f"Img{k}-{i}.npy", tempImg)
                        if Num_Labels is None:
                            if lineaCSV[1] == 0:
                                NO_NVB.append(id)
                            else:
                                NVB.append(id)
                        else:
                            MultiClasses[idClass].append(id)
                        
                        id += 1
                        
                        # Optional corner-slice augmentation, capped at
                        # Num_Img_Slices crops per image.
                        if SegImage is not None:
                            if (SegmentClass is None) or (lineaCSV[1] == SegmentClass):
                                for j, newImg in enumerate(self._sliceImage(tempImg, SegImage)):
                                    if j == Num_Img_Slices:
                                        break
                                    imgPath = dumpImgs/f"Img{k}-{i}-{j}.npy"
                                    np.save(imgPath, newImg)
                                    idClass = NVB_Binary[file.parent.name].value if Num_Labels is None else NVB_Classes[file.parent.name].value
                                    lineaCSV = [id, idClass, file.parent.name, (dumpImgs/f"Img{k}-{i}-{j}.npy").absolute(), file.name] 
                                    lineaCSV += np.mean(newImg, axis=tuple(range(newImg.ndim-1))).tolist()
                                    lineaCSV += np.std(newImg, axis=tuple(range(newImg.ndim-1))).tolist()
                                    writerOBJ.writerow(lineaCSV)
                                    if Num_Labels is None:
                                        if lineaCSV[1] == 0:
                                            NO_NVB.append(id)
                                        else:
                                            NVB.append(id)
                                    else:
                                        MultiClasses[idClass].append(id)
                                    
                                    id += 1
                    
                csvfile.close()  # redundant inside `with`, kept as-is
                
            # Dataset-level channel statistics from the per-image stats.
            temp = np.asarray(lMean)
            self.mean = list(np.mean(temp, axis=tuple(range(temp.ndim-1))))
            temp = np.asarray(lStd)
            self.std = list(np.mean(temp, axis=tuple(range(temp.ndim-1))))
        else:
            # CSV already exists: recover stats and id buckets from it.
            data = pd.read_csv(self.CVS_File)
            self.mean = data[["mean_1", "mean_2", "mean_3"]].mean().to_list()
            self.std = data[["std_1", "std_2", "std_3"]].mean().to_list()
            if Num_Labels is None:
                NO_NVB = data.loc[data["label"] == 0]["id"].to_list()
                NVB = data.loc[data["label"] == 1]["id"].to_list()
            else:
                for i in range(Num_Labels):
                    MultiClasses[i] = data.loc[data["label"] == i]["id"].to_list()
            
        if Fold is not None:
            # Rotating fold orders: [0..F-1], [1..F-1,0], ...
            self.Splits = []      
            splitsToSave = []      
            tempFoldsOrder = list(range(Fold))
            for _ in range(Fold):
                self.Splits.append(tempFoldsOrder)
                tempFoldsOrder = tempFoldsOrder[1:] + tempFoldsOrder[:1]
                
            if FoldSeed is not None:
                np.random.seed(FoldSeed)
                if Num_Labels is None:
                    np.random.shuffle(NO_NVB)
                    np.random.shuffle(NVB)
                else:
                    for i in range(Num_Labels):
                        np.random.shuffle(MultiClasses[i])
            # Split each class's ids into Fold stratified chunks.
            if Num_Labels is None:
                NO_NVB_Folds = list (self._split(NO_NVB, Fold))
                NVB_Fols = list(self._split(NVB, Fold))
            else:
                MultiClasses_Folds = []
                for i in range(Num_Labels):
                    MultiClasses_Folds.append(list(self._split(MultiClasses[i], Fold)))
            
            # Number of folds assigned to train / val / test respectively.
            setst = [
                math.trunc(0.60 * Fold), math.trunc(0.20 * Fold), math.trunc(0.20 * Fold)
            ]
            
            # For every rotation, concatenate the class-stratified chunks
            # into three id lists (train, val, test) appended in order.
            for data in self.Splits:
                ultimo = 0 
                for s in setst:
                    tempArry = []
                    for fold in data[ultimo: ultimo + s]:
                        if Num_Labels is None:
                            tempArry += NO_NVB_Folds[fold] + NVB_Fols[fold]
                        else:
                            for i in range(Num_Labels):
                                tempArry += MultiClasses_Folds[i][fold]
                                
                    splitsToSave.append(tempArry)
                    ultimo += s
            self.Folds_File = dumpImgs/"Folds.npy"
            np.save(self.Folds_File, np.asarray(splitsToSave, dtype=object))
            
    def CreateClases(self):
        """Copy every dumped image into ``dataset/NO_NVB|NVB`` class folders,
        then mark completion with a ``FOLDS_CREATED`` flag file and delete
        the ``.npy`` dumps. No-op when the flag already exists."""
        if not (self.CVS_File.parent/"FOLDS_CREATED").exists():
            root = self.CVS_File.parent.parent
            database = pd.read_csv(self.CVS_File, usecols=["id", "label", "class", "path"])
            
            for _, row in database.iterrows():
                pathOriginal = Path(row["path"])
                pathNuevo = root/"dataset"/("NO_NVB" if row["label"] == 0 else "NVB")/pathOriginal.name
                pathNuevo.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(pathOriginal, pathNuevo)
                
            # Flag file marks the step as done for future runs.
            with open(self.CVS_File.parent/"FOLDS_CREATED", "x") as file:
                file.close()
                
            for f in self.CVS_File.parent.glob("*.npy"):
                f.unlink()
                
    def ExtractROI_YOLO(self, YOLOModel:YOLO, thresholdYolo:float = 0.70):
        """Replace every dumped ``.npy`` image in place with its first YOLO
        ROI crop. Runs only after ``FOLDS_CREATED`` exists and before
        ``YOLO_ROI_CREATED`` does; sets the latter flag when done."""
        if (self.CVS_File.parent/"FOLDS_CREATED").exists() and not (self.CVS_File.parent/"YOLO_ROI_CREATED").exists():
            root = self.CVS_File.parent.parent
            for img in root.glob(f"**/*.npy"):
                if img.name == "Folds.npy":
                    continue
                tempImgList = self.ROI_Extract_YOLO(YOLOModel, np.load(img), thresholdYolo)
                # Only the first (highest-indexed-first is not guaranteed)
                # detection overwrites the file; extra ROIs are discarded.
                for k, tempImg in enumerate(tempImgList):
                    if k == 0:
                        np.save(img, tempImg)
                        
            with open(self.CVS_File.parent/"YOLO_ROI_CREATED", "x") as file:
                file.close()  
                        
    def CreateFolds(self):
        """Materialize the precomputed splits into ``fold_<n>/<set>/<class>``
        directory trees (3 splits per fold: train/val/test), optionally
        masking every image with ``self.Masked_roi_model`` first.

        NOTE(review): ``RARP_DataSetType`` is referenced but not defined in
        this file — it must come from elsewhere in the project.
        """
        if not (self.CVS_File.parent/"FOLDS_CREATED").exists():
            root = self.CVS_File.parent.parent
            database = pd.read_csv(self.CVS_File, usecols=["id", "label", "class", "path"])
            arrFolds = np.load(self.Folds_File, allow_pickle=True)
            
            if self.Masked_roi_model is not None:
                database.reset_index()
                for _, row in database.iterrows():
                    img = np.load(Path(row["path"]))
                    np.save(Path(row["path"]), self._ROI_Mask(img))
            
            database.reset_index()
            # arrFolds holds train/val/test id lists back-to-back: 3 per fold.
            for foldNum, splits in enumerate(np.array_split(arrFolds, len(arrFolds)/3)):
                foldPath = root/f"fold_{foldNum}"
                for datasetType, subSet in enumerate(splits):
                    SubfoldPath = foldPath/f"{RARP_DataSetType(datasetType).name}"
                    for _, row in database.loc[database["id"].isin(subSet)].iterrows():
                        PathOri = Path(row["path"])
                        if self.Num_Labels is not None:
                            folderName = f"{row['label']}_" + ("NO_NVB" if row["label"] == 0 else NVB_Classes(row["label"]).name)
                        else:
                            folderName = "NO_NVB" if row["label"] == 0 else "NVB" 
                        PathImg = SubfoldPath/folderName/PathOri.name
                        PathImg.parent.mkdir(parents=True, exist_ok=True)
                        shutil.copy(PathOri, PathImg)
                        
            with open(self.CVS_File.parent/"FOLDS_CREATED", "x") as file:
                file.close()
                
            for f in self.CVS_File.parent.glob("*.npy"):
                f.unlink()
                                
class RARP_DatasetFolder_ExtraData(torchvision.datasets.DatasetFolder):
    """DatasetFolder that pairs each image with tabular extra features.

    ``__getitem__`` returns ``((sample, extra), target)`` where *extra* is a
    tensor built from the first ``Extra_Data_leg`` values of the
    ``Extra_Data`` row whose ``raw_name`` column matches the file name.
    """
    def __init__(self, 
                 root: str, 
                 loader: Callable[[str], Any], 
                 Extra_Data: pd.DataFrame,
                 Extra_Data_leg:int = 4,
                 extensions: Tuple[str, ...] | None = None, 
                 transform: Callable[..., Any] | None = None, 
                 target_transform: Callable[..., Any] | None = None, 
                 is_valid_file: Callable[[str], bool] | None = None
                ) -> None:
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)
        self.Extra_Data = Extra_Data
        self.Extra_Data_leg = Extra_Data_leg

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]

        # Match the metadata row by file name and keep the first
        # Extra_Data_leg columns, flattened into a 1-D tensor.
        file_name = Path(path).name
        matched = self.Extra_Data[self.Extra_Data["raw_name"] == file_name]
        extra_values = matched.values.flatten().tolist()[:self.Extra_Data_leg]
        extra_tensor = torch.tensor(extra_values)

        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (sample, extra_tensor), target
    
class RARP_DatasetFolder_RoiExtractor(torchvision.datasets.DatasetFolder):
    """DatasetFolder for ROI-annotation training.

    Each image is expected to have a sibling ``<stem>.json`` whose first
    shape holds ROI polygon points. ``__getitem__`` returns either
    (image, normalized keypoints) or, when ``create_mask`` is True,
    (image, segmentation mask) built from a Catmull-Rom smoothing of the
    polygon. The transform is called with keyword targets
    (``image=``, ``keypoints=``/``mask=``), i.e. albumentations-style.
    """
    def __init__(
        self, 
        root, 
        loader, 
        extensions = None, 
        transform = None, 
        target_transform = None, 
        is_valid_file = None,
        create_mask:bool = False
    ):
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)
        
        # When True, targets are dense masks instead of keypoint coordinates.
        self.create_mask = create_mask
        
    def _removeBlackBorder_mask(self, image:np.ndarray, input_mask:np.ndarray):
        """Crop *image* and *input_mask* identically to the bounding box of
        the largest non-black region of *image* (hue-band filtering first,
        as in RARP_DatasetCreator._removeBlackBorder)."""
        #image = np.array(image)
        
        copyImg = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2HSV)
        h = copyImg[:,:,0]
        mask = np.ones(h.shape, dtype=np.uint8) * 255
        th = (25, 175)  # hue band masked out before contour detection
        mask[(h > th[0]) & (h < th[1])] = 0
        copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)
        resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)
            
        image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        # findContours returns 2 values on OpenCV >= 4, 3 on OpenCV 3.
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        
        return image[y : y + h, x : x + w], input_mask[y : y + h, x : x + w]
    
    def _catmull_rom_spline(self, P0, P1, P2, P3, n_points=20):
        """Sample *n_points* (x, y) points on the Catmull-Rom segment from
        P1 to P2, using P0 and P3 as outer control points."""
        points = []
        for t in np.linspace(0, 1, n_points):
            # Catmull-Rom formula
            t2 = t * t
            t3 = t2 * t
            x = 0.5 * ((2 * P1[0]) +
                    (-P0[0] + P2[0]) * t +
                    (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t2 +
                    (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t3)
            
            y = 0.5 * ((2 * P1[1]) +
                    (-P0[1] + P2[1]) * t +
                    (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t2 +
                    (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t3)
            
            points.append((x, y))
            
        return points

    def _catmull_rom_closed_loop(self, points, n_points=20):
        """Build a closed smooth curve through *points* by chaining one
        Catmull-Rom segment per vertex (indices wrap around)."""
        spline_points = []
        n = len(points)
        
        for i in range(n):
            P0 = points[(i - 1) % n]
            P1 = points[i]
            P2 = points[(i + 1) % n]
            P3 = points[(i + 2) % n]
            spline_points += self._catmull_rom_spline(P0, P1, P2, P3, n_points)
            
        return np.array(spline_points)
    
    def _create_mask_from_contour(self, spline_points: np.ndarray, mask_size:Tuple = (0, 0)):
        """Rasterize a closed curve into a binary uint8 mask of *mask_size*
        ((rows, cols)); interior pixels are 1."""
        smooth_curve_int = np.round(spline_points).astype(np.int32)
        mask = np.zeros(mask_size, dtype=np.uint8)
        
        return cv2.fillPoly(mask, [smooth_curve_int], 1)
        
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, _ = self.samples[index]
        
        # Sidecar annotation: same stem as the image, .json extension.
        pth = Path(path)
        newPth = pth.parent / f"{pth.name.split('.')[0]}.json"
        
        data = json.load(open(newPth))
        kpts = data["shapes"][0]["points"]
        
        img = self.loader(path) 
        if self.transform is not None:
            if not self.create_mask:
                # Keypoint target: transform image + points together, then
                # normalize coordinates to [0, 1].
                sample = self.transform(image=img, keypoints=kpts)
                
                img = sample["image"]
                kpts = torch.tensor(sample["keypoints"])
                _, h, w = img.shape
                # NOTE(review): divides (x, y) keypoints by (h, w); verify
                # this is intended and not swapped (w, h).
                kpts = kpts / torch.tensor([h, w])
            else:
                # Mask target: smooth the polygon, rasterize it, crop both
                # image and mask to the non-black region, then transform.
                h, w, _ = img.shape
                smood_perimeter = self._catmull_rom_closed_loop(kpts, n_points=15)
                roi_mask = self._create_mask_from_contour(smood_perimeter, (h, w))
                
                crop_img, crop_mask = self._removeBlackBorder_mask(img, roi_mask)
                crop_mask = crop_mask.astype(np.float32)
                
                sample = self.transform(image=crop_img, mask=crop_mask)
                img = sample["image"]
                kpts = sample["mask"]
            
            
        return img, kpts
        

class RARP_DatasetFolder_ExtraLabel(torchvision.datasets.DatasetFolder):
    """DatasetFolder whose target is a multi-value label vector decoded from
    the pipe-separated ``encode_l1`` column of an external DataFrame,
    matched to each sample by file name (``raw_name`` column)."""
    def __init__(
        self, 
        root: str, 
        loader: Callable[[str], Any], 
        Extra_Data: pd.DataFrame,
        extensions: Tuple[str, ...] | None = None, 
        transform: Callable[..., Any] | None = None, 
        target_transform: Callable[..., Any] | None = None, 
        is_valid_file: Callable[[str], bool] | None = None
    ) -> None:
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)
        self.Extra_Data = Extra_Data

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, _ = self.samples[index]

        # Decode "v|v|..." from the matching row into an int tensor.
        file_name = Path(path).name
        encoded = self.Extra_Data[self.Extra_Data["raw_name"] == file_name]["encode_l1"].values[0]
        label = torch.tensor([int(part) for part in str(encoded).split("|")])

        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return sample, label

class RARP_DatasetFolder_DobleTransform(torchvision.datasets.DatasetFolder):
    """DatasetFolder that returns two independently transformed views of
    every sample (e.g. for contrastive / two-view training).

    Args:
        transformT1: transform for the first view; when None the raw sample
            is used.
        transformT2: transform for the second view; defaults to transformT1
            when None.
        passOriginal: optional extra transform; when given, its output is
            appended as a third element of the returned tuple.

    Returns from ``__getitem__``:
        ``((view1, view2), target)`` or
        ``((view1, view2, passOriginal(sample)), target)``.
    """
    def __init__(self, 
                 root: str, 
                 loader: Callable[[str], Any], 
                 extensions: Tuple[str, ...] | None = None, 
                 transformT1: Callable[..., Any] | None = None, 
                 transformT2: Callable[..., Any] | None = None, 
                 target_transform: Callable[..., Any] | None = None, 
                 is_valid_file: Callable[[str], bool] | None = None,
                 passOriginal: Callable[..., Any] | None = None, 
                ) -> None:
        # transform=None: the two views are produced manually in __getitem__.
        super().__init__(root, loader, extensions, None, target_transform, is_valid_file)
        self.T1 = transformT1
        self.T2 = transformT2 if transformT2 is not None else transformT1
        # (sic) attribute name kept for backward compatibility.
        self.PassOrigianlImage = passOriginal
        
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        path, target = self.samples[index]
                
        sample = self.loader(path) 
        # BUGFIX: the original left sample1/sample2 unbound (NameError) when
        # transformT1 was None; fall back to the raw sample instead.
        sample1 = self.T1(sample) if self.T1 is not None else sample
        sample2 = self.T2(sample) if self.T2 is not None else sample

        if self.PassOrigianlImage is not None:
            return (sample1, sample2, self.PassOrigianlImage(sample)), target
        else:
            return (sample1, sample2), target
             
class RARP_PreprocessCreator():
    """Legacy preprocessor: scans a folder of images, dumps them as ``.npy``
    files and splits them into fixed-size train/val/test sets using the
    hard-coded ``dist`` counts (per-class: NOT_NVB, NVB).

    NOTE(review): superseded in spirit by RARP_DatasetCreator; kept as-is.
    """
    def removeBlackBorder(self, image):
        """Crop *image* to the bounding box of its largest bright contour
        (no hue filtering, unlike RARP_DatasetCreator._removeBlackBorder)."""
        image = np.array(image)
        image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        # findContours returns 2 values on OpenCV >= 4, 3 on OpenCV 3.
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        crop = image[y : y + h, x : x + w]
        return crop
    
    def _crop(self, image: np.ndarray, x:tuple, y:tuple):
        """Return the sub-image covering columns x[0]:x[1], rows y[0]:y[1]."""
        return image[y[0]:y[1], x[0]:x[1], :]

    def _sliceImage(self, img: np.ndarray, size:float):
        """Cut four quadrant crops at the *size* fraction point.

        NOTE(review): unlike RARP_DatasetCreator._sliceImage, right/bottom
        crops start at trunc(w*size)/trunc(h*size), so the four pieces tile
        the image only when size == 0.5 — confirm this asymmetry is wanted.
        """
        h, w, _ = img.shape
        ul = self._crop(img, (0, math.trunc(w*size)), (0, math.trunc(h*size)))
        ur = self._crop(img, (math.trunc(w*size), w), (0, math.trunc(h*size)))
        dl = self._crop(img, (0, math.trunc(w*size)), (math.trunc(h*size), h))
        dr = self._crop(img, (math.trunc(w*size), w), (math.trunc(h*size), h))
        
        return [ul, ur, dl, dr]
    
    def __init__(self, 
                 RootPath = "", 
                 extension="tiff", 
                 FoldSeed=505, 
                 createFile=False, 
                 SavePath="", 
                 Fold=False, 
                 preResize=None, 
                 RGBGama=False, 
                 removeBlackBar=True,
                 SegImage:float=None) -> None:
        """Scan *RootPath*, preprocess every ``*.{extension}`` image, and
        build train/val/test (path, label) lists; when *createFile* is True
        the sets are also written out as ``.npy`` files per class folder.

        When ``<SavePath>_seed_<FoldSeed>/train`` already exists, only the
        channel statistics are computed and the set lists stay None.
        """
        super().__init__()

        self.root = Path(RootPath)
        self.mean = 0.0
        self.std = 0.0
        self.test_DataSet = []
        self.val_DataSet = []
        self.train_DataSet = []

        self.nNVB = []  # (path, 0) pairs — NOT_NVB samples
        self.yNVB = []  # (path, 1) pairs — NVB samples

        lista = []  # per-image channel means, for dataset statistics

        # Hard-coded per-set sample counts: [NOT_NVB, NVB].
        dist = {
            'test':  [5, 2], #[4, 3], [10, 4],
            'val':   [8, 3], #[4, 3], [10, 4],
            'train': [28, 10] #[8, 6], [21, 7] 
        }
        
        sPath = Path(SavePath + f"_seed_{FoldSeed}")
        dumpImgs = sPath/"dump"
        dumpImgs.mkdir(parents=True, exist_ok=True)
        for i, file in enumerate(self.root.glob(f"**/*.{extension}")):
            tempImg = cv2.imread(str(file), cv2.IMREAD_COLOR)
            if removeBlackBar:
                tempImg = self.removeBlackBorder(tempImg)
            if RGBGama:
                tempImg = cv2.cvtColor(tempImg, cv2.COLOR_BGR2RGB)
            if preResize is not None:
                tempImg = cv2.resize(tempImg, preResize, interpolation=cv2.INTER_AREA)# Resolution change

            lista.append (np.mean(tempImg, axis=tuple(range(tempImg.ndim-1))))
            
            # Only dump images while the train folder hasn't been built yet.
            if not (sPath/"train").exists():
                imgPath = dumpImgs/f"Img{i}.npy"
                np.save(imgPath, tempImg)
                parImgClass = (imgPath, 0 if file.parent.name == "NOT_NVB" else 1)

                if file.parent.name == "NOT_NVB":
                    self.nNVB.append(parImgClass)
                else:
                    self.yNVB.append(parImgClass)
                    
                # Optional quadrant-slice augmentation.
                if SegImage is not None:
                    for j, newImg in enumerate(self._sliceImage(tempImg, SegImage)):
                        imgPath = dumpImgs/f"Img{i}-{j}.npy"
                        np.save(imgPath, newImg)
                        parImgClass = (imgPath, 0 if file.parent.name == "NOT_NVB" else 1)

                        if file.parent.name == "NOT_NVB":
                            self.nNVB.append(parImgClass)
                        else:
                            self.yNVB.append(parImgClass)

        # Dataset-level channel statistics from the per-image means.
        temp = np.asarray(lista)
        self.mean = list(np.mean(temp, axis=tuple(range(temp.ndim-1))))
        self.std = list(np.std(temp, axis=tuple(range(temp.ndim-1))))
        
        # Dataset already materialized on disk: nothing more to build.
        if (sPath / "train").exists():
            self.test_DataSet = None
            self.train_DataSet = None
            self.val_DataSet = None
            self.nNVB = None
            self.yNVB = None
            self.SaveFold = sPath
            return

        if FoldSeed is not None:
            np.random.seed(FoldSeed)

        # NOTE(review): these shuffled index lists are never applied to
        # nNVB/yNVB below — the sets are taken in scan order. Confirm.
        shuffle_nNVB = list(range(len(self.nNVB)))
        shuffle_NVB = list(range(len(self.yNVB)))

        np.random.shuffle(shuffle_nNVB)
        np.random.shuffle(shuffle_NVB)

        # Carve train/val/test out of each class list by the dist counts.
        pasoN, pasoY = (dist['train'][0], dist['train'][1])

        self.train_DataSet = self.nNVB[:pasoN] + self.yNVB[:pasoY]

        nextPasoN, nextPasoY = (pasoN+dist['val'][0], pasoY+dist['val'][1])

        self.val_DataSet = self.nNVB[pasoN:nextPasoN] + self.yNVB[pasoY:nextPasoY]

        self.test_DataSet = self.nNVB[nextPasoN:] + self.yNVB[nextPasoY:]

        if createFile:
            if len(SavePath) == 0: 
                raise Exception("If createFile is True, SavePath must have a value")
            
            # Write each set to disk; with Fold=True everything goes into
            # "train" and the counter keeps running so names don't collide.
            n=0
            for img, lb in self.train_DataSet:
                trainPath = sPath / "train" / ("NOT_NVB" if lb == 0 else "NVB") / f"img{n}"
                trainPath.parent.mkdir(parents=True, exist_ok=True)
                np.save(trainPath, defs.load_file(img))
                n += 1

            n=0 if not Fold else n
            for img, lb in self.val_DataSet:
                trainPath = sPath / ("val" if not Fold else "train") / ("NOT_NVB" if lb == 0 else "NVB") / f"img{n}"
                trainPath.parent.mkdir(parents=True, exist_ok=True)
                np.save(trainPath, defs.load_file(img))
                n += 1

            n=0 if not Fold else n
            for img, lb in self.test_DataSet:
                trainPath = sPath / ("test" if not Fold else "train") / ("NOT_NVB" if lb == 0 else "NVB") / f"img{n}"
                trainPath.parent.mkdir(parents=True, exist_ok=True)
                np.save(trainPath, defs.load_file(img))
                n += 1

            # Dumps are no longer needed once the sets are materialized.
            self.test_DataSet = None
            self.train_DataSet = None
            self.val_DataSet = None
            self.nNVB = None
            self.yNVB = None
            self.SaveFold = sPath
            
            for f in dumpImgs.iterdir():
                f.unlink()
            dumpImgs.rmdir()

class DataloaderToDataModule(LightningDataModule):
    """Adapter that exposes pre-built dataloaders as a LightningDataModule.

    Args:
        train_dataloader: Training dataloader
        val_dataloaders: Validation dataloader(s)
    """

    def __init__(
        self,
        train_dataloader: DataLoader,
        val_dataloaders: Union[DataLoader, Sequence[DataLoader]],
    ) -> None:
        super().__init__()
        # Stored verbatim; the hook methods below simply hand them back.
        self._train_dl = train_dataloader
        self._val_dls = val_dataloaders

    def train_dataloader(self) -> DataLoader:
        """Hand back the training dataloader given at construction time."""
        return self._train_dl

    def val_dataloader(self) -> Union[DataLoader, Sequence[DataLoader]]:
        """Hand back the validation dataloader(s) given at construction time."""
        return self._val_dls

class RARP_DataModule(LightningDataModule):
    """LightningDataModule wrapping pre-built dataloaders with optional
    K-fold cross-validation.

    When ``KFold`` is enabled, the training dataloader's dataset is
    re-partitioned per fold into three class-balanced index groups
    (~60% train / ~20% val / ~20% test of the folds) and fresh dataloaders
    are built from them, reusing the original train dataloader's settings.

    Args:
        KFold: Enable K-fold re-splitting of the training dataset.
        num_fold: Number of folds.
        shuffle: Shuffle indices before splitting.
        stratified: Use StratifiedKFold in :meth:`setup_folds`.
        train_dataloader: Pre-built training dataloader.
        val_dataloader: Pre-built validation dataloader.
        test_dataloader: Pre-built test dataloader.
        transforms: Sequence of three transforms (train, val, test) applied
            to the fold datasets when K-fold is active.
    """
    def __init__(self, 
                 KFold:bool=False, 
                 num_fold:int=5, 
                 shuffle: bool = False,
                 stratified: bool = False,
                 train_dataloader: DataLoader = None,
                 val_dataloader: DataLoader = None,
                 test_dataloader: DataLoader = None,
                 transforms = []) -> None:
        super().__init__()

        self.fold_index = 0
        self.splits = None
        self.Kfold_split = KFold
        self.num_folds = num_fold
        self.Folds = []
        self.shuffle = shuffle
        self.stratified = stratified
        # Default extractor assumes each batch is an (inputs, labels) pair.
        self.label_extractor = lambda batch: batch[1]

        self.trainDL = train_dataloader
        self.valDL = val_dataloader 
        self.testDL = test_dataloader
        self.dataloader_settings = None
        self.transforms = transforms

        # One fold ordering per fold index, obtained by rotating
        # [0, 1, ..., num_fold-1] one position left each time.
        rotation = list(range(num_fold))
        for _ in range(num_fold):
            self.Folds.append(rotation)
            rotation = rotation[1:] + rotation[:1]

    def get_Current_Folds(self) -> list:
        """Return the fold ordering for the currently selected fold index."""
        return self.Folds[self.fold_index]
    
    def get_labels(self, dataloader: DataLoader) -> Optional[List]:
        """Try to extract the training labels (for classification problems) from the underlying training dataset."""
        # Try to extract labels from the dataset through the `labels` attribute.
        if hasattr(dataloader.dataset, "labels"):
            return dataloader.dataset.labels.tolist()

        # Else iterate batches and extract with `label_extractor`.
        try:
            return torch.cat([self.label_extractor(batch) for batch in dataloader], dim=0).tolist()
        except Exception:
            return None
            
    def _split(self, a, n):
        """Split sequence ``a`` into ``n`` near-equal contiguous chunks."""
        k, m = divmod(len(a), n)
        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
    
    def split_folds(self):
        """Compute ``self.splits`` = [train_idx, val_idx, test_idx] for the
        current fold ordering, keeping the binary class ratio similar in
        each split.

        The first ~60% of the folds go to train, the next ~20% to val and
        the remainder to test. (The previous implementation re-appended the
        accumulated train indices on every iteration of the val range and
        emitted one split per test fold, producing malformed splits whenever
        ``num_folds`` != 5 — as flagged by the old FIX comment. The grouping
        below yields exactly three splits for any fold count, and matches
        the old output for num_folds == 5.)
        """
        if self.splits is None:
            self.FoldDataset = self.trainDL.dataset

            # Partition sample positions by binary label (0 vs non-zero).
            NO_NVB = []
            NVB = []
            for pos, label in enumerate(self.FoldDataset.targets):
                if label == 0:
                    NO_NVB.append(pos) 
                else:
                    NVB.append(pos)

            if self.shuffle:
                np.random.shuffle(NO_NVB)
                np.random.shuffle(NVB)

            NO_NVB_Folds = list(self._split(NO_NVB, self.num_folds))
            NVB_Folds = list(self._split(NVB, self.num_folds))

            # Fold-count boundaries between train / val / test groups.
            bounds = [
                math.trunc(0.60 * self.num_folds), math.trunc(0.80 * self.num_folds), self.num_folds
            ]

            train_idx, val_idx, test_idx = [], [], []
            for pos, index in enumerate(self.Folds[self.fold_index]):
                bucket = NO_NVB_Folds[index] + NVB_Folds[index]
                if pos < bounds[0]:
                    train_idx += bucket
                elif pos < bounds[1]:
                    val_idx += bucket
                else:
                    test_idx += bucket

            self.splits = [train_idx, val_idx, test_idx]

    def setup_folds(self):
        """Alternative split using sklearn's (Stratified)KFold.

        NOTE(review): not called by the dataloader methods below, which use
        :meth:`split_folds` instead — confirm which path is intended.
        """
        if self.splits is None:
            labels = None
            if self.stratified:
                labels = self.get_labels(self.trainDL)
                if labels is None:
                    raise ValueError(
                        "Tried to extract labels for stratified K folds but failed."
                        " Make sure that the dataset of your train dataloader either"
                        " has an attribute `labels` or that `label_extractor` attribute"
                        " is initialized correctly"
                    )
                splitter = StratifiedKFold(self.num_folds, shuffle=self.shuffle)
            else:
                splitter = KFold(self.num_folds, shuffle=self.shuffle)

            self.FoldDataset = self.trainDL.dataset

            self.splits = [split for split in splitter.split(range(len(self.FoldDataset)), y=labels)]

    def train_dataloader(self) -> DataLoader:
        """Training dataloader: the wrapped one, or the current train fold."""
        if not self.Kfold_split:
            return self.trainDL
        else:
            self.split_folds()
            train_fold = RARP_Dataset(self.FoldDataset, self.splits[0], self.transforms[0])
            return DataLoader(train_fold, shuffle=True, **self.dataloader_setting)
        
    def val_dataloader(self) -> DataLoader:
        """Validation dataloader: the wrapped one, or the current val fold."""
        if not self.Kfold_split:
            return self.valDL
        else:
            self.split_folds()
            val_fold = RARP_Dataset(self.FoldDataset, self.splits[1], self.transforms[1])
            return DataLoader(val_fold, shuffle=False, **self.dataloader_setting)
        
    def test_dataloader(self) -> DataLoader:
        """Test dataloader: the wrapped one, or the current test fold."""
        if not self.Kfold_split:
            return self.testDL
        else:
            self.split_folds()
            test_fold = RARP_Dataset(self.FoldDataset, self.splits[2], self.transforms[2])
            return DataLoader(test_fold, shuffle=False, **self.dataloader_setting)
    
    @property
    def dataloader_setting(self) -> dict:
        """Return (and lazily cache) the settings of the train dataloader,
        so fold dataloaders are built with identical parameters."""
        if self.dataloader_settings is None:
            orig_dl = self.trainDL
            self.dataloader_settings = {
                "batch_size": orig_dl.batch_size,
                "num_workers": orig_dl.num_workers,
                "collate_fn": orig_dl.collate_fn,
                "pin_memory": orig_dl.pin_memory,
                "drop_last": orig_dl.drop_last,
                "timeout": orig_dl.timeout,
                "worker_init_fn": orig_dl.worker_init_fn,
                "prefetch_factor": orig_dl.prefetch_factor,
                "persistent_workers": orig_dl.persistent_workers,
            }
        return self.dataloader_settings

class RARP_Dataset(Dataset):
    """A view over ``indices`` of an existing dataset, with an optional
    transform applied to each sample's image on access.

    Attributes:
        samples: torch Subset restricted to the requested indices.
        targets: labels of the selected samples, in subset order.
        classes: binary class names.  # TODO: derive these properly
    """

    def __init__(self, RARP_dataset, indices, transform=None) -> None:
        super().__init__()
        self.transform = transform
        self.samples = Subset(RARP_dataset, indices)
        self.targets = [pair[1] for pair in self.samples]
        self.classes = ["NO_NVB", "NVB"]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample, target = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, target)
    
class RARP_ChannelSwap(torch.nn.Module):
    """Randomly permutes the three colour channels of a channel-first image.

    The input is assumed to be in RGB order; the output contains the same
    channel data in a random order.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, img):
        order = list(range(3))
        np.random.shuffle(order)
        return img[order]
    
class RARP_RandomPatchMask(torch.nn.Module):
    """Placeholder for random patch masking (MAE-style).

    Stores the intended configuration but performs no work yet.
    """

    def __init__(self, patch_size=16, mask_ratio=0.75, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio

    def forward(self, img):
        # TODO: input arrives as a tensor; implement so it works on a single
        # image rather than a batch. Currently a no-op returning None.
        return None
    
class RARP_Invert(torch.nn.Module):
    """Inverts pixel intensities: every value ``v`` becomes ``max_val - v``.

    Args:
        max_val_pixel: maximum pixel value of the expected input range
            (255 for uint8-style images, 1.0 for normalized tensors).
    """

    def __init__(self, max_val_pixel=255, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.max_val = max_val_pixel

    def forward(self, img):
        # Broadcasts over any tensor shape.
        inverted = self.max_val - img
        return inverted
    
class RARP_MAE_Augmentation():
    """Two-view augmentation for MAE-style self-supervised training.

    ``__call__`` returns a list of two tensors: a lightly processed
    "original" view followed by a heavily augmented "global crop" view.
    NOTE(review): RARP_Invert is applied with max_val_pixel=1.0, which
    assumes the input tensor is already scaled to [0, 1] — confirm with
    callers.
    """
    def __init__(
        self,
        GloblaCropsScale=(0.4, 1), 
        Size:int=224, 
        device = None, 
        mean:float = None, 
        std:float = None,
        Init_Resize = (512,512),
        Tranform_0 = None
    ):
        """Build both transform pipelines.

        Args:
            GloblaCropsScale: area scale range for the random resized crop.
            Size: output spatial size of both views.
            device: device the nn.Sequential pipelines are moved to.
            mean: per-channel normalization mean.
            std: per-channel normalization std.
            Init_Resize: size of the initial center crop.
            Tranform_0: optional replacement for the "original" pipeline.
        """
        # Heavy-augmentation branch: center crop, small rotation, random
        # resized crop, flip, solarize, then randomly applied channel
        # permutation / Gaussian blur / inversion, and normalization.
        self.globalCrop = torch.nn.Sequential(
            T.CenterCrop(Init_Resize),
            T.RandomRotation(
                degrees=(-15, 15), 
                fill=5
            ),
            T.RandomResizedCrop(
                Size, 
                scale=GloblaCropsScale, 
                antialias=True,
                interpolation=T.InterpolationMode.BICUBIC
            ),
            T.RandomHorizontalFlip(0.5),
            T.RandomSolarize(0.5, p=0.7),
            T.RandomApply([
                RARP_ChannelSwap()
            ], p=0.7),
            T.RandomApply([
                T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
            ], 0.4),
            T.RandomApply([
                RARP_Invert(max_val_pixel=1.0)
            ], 0.3),
            T.Normalize(mean, std)
        ).to(device)
        
        # Deterministic "original" view: crop, resize and normalize only
        # (unless the caller supplies its own pipeline via Tranform_0).
        self.original_crop = torch.nn.Sequential(
            T.CenterCrop(Init_Resize),
            T.Resize((Size, Size), antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.Normalize(mean, std)
        ).to(device) if Tranform_0 is None else Tranform_0
        
    def __call__(self, img):
        """Return [original_view, global_crop_view] for ``img``."""
        all_Crops = []
        all_Crops.append(self.original_crop(img))
        all_Crops.append(self.globalCrop(img))
        
        return all_Crops
     
class RARP_DINO_Augmentation():
    """Multi-crop augmentation for DINO-style self-supervised training.

    ``__call__`` returns a list of views: one "classification" view, two
    global crops with different augmentation strengths, and
    ``NumLocalCrops`` local crops.
    """
    def __init__(
        self, 
        GloblaCropsScale=(0.4, 1), 
        LocalCropsScale=(0.05, 0.4), 
        NumLocalCrops:int=8, 
        Size:int=224, 
        device = None, 
        mean:float = None, 
        std:float = None,
        Tranform_0 = None,
        Init_Resize = (512,512)
    ) -> None:
        """Build the crop pipelines.

        Args:
            GloblaCropsScale: area scale range for the two global crops.
            LocalCropsScale: area scale range for the local crops.
            NumLocalCrops: number of local crops produced per call.
            Size: output spatial size of every crop.
            device: device the nn.Sequential pipelines are moved to.
            mean: per-channel normalization mean.
            std: per-channel normalization std.
            Tranform_0: optional replacement for the classification pipeline.
            Init_Resize: size of the initial center crop for the DINO crops.
        """
        self.NumLocal_Crops= NumLocalCrops
       
        # First global crop: mild erasing, always-on Gaussian blur.
        self.globalCrop1 = torch.nn.Sequential(
            T.CenterCrop(Init_Resize),
            #T.Resize(Init_Resize, antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomRotation(
                degrees=(-15, 15), 
                fill=5
            ),
            T.RandomResizedCrop(
                Size, 
                scale=GloblaCropsScale, 
                antialias=True,
                interpolation=T.InterpolationMode.BICUBIC
            ),
            #T.RandomHorizontalFlip(0.5),
            T.RandomErasing(0.2, value="random"),
            #T.RandomApply([
            #    RARP_ChannelSwap()
            #]),
            T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
            T.Normalize(mean, std)
        ).to(device)
        
        # Second global crop: aggressive erasing plus randomly applied
        # channel permutation, blur and inversion.
        self.globalCrop2 = torch.nn.Sequential(
            T.CenterCrop(Init_Resize),
            #T.Resize(Init_Resize, antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomRotation(
                degrees=(-15, 15), 
                fill=5
            ),
            T.RandomResizedCrop(
                Size, 
                scale=GloblaCropsScale,
                antialias=True,
                interpolation=T.InterpolationMode.BICUBIC
            ),
            #T.RandomHorizontalFlip(0.5),
            T.RandomErasing(0.8, value="random"),
            T.RandomApply([
                RARP_ChannelSwap()
            ]),
            T.RandomApply([
                T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
            ], 0.3),
            T.RandomApply([
                RARP_Invert(max_val_pixel=1.0)
            ], 0.4),
            T.Normalize(mean, std)
        ).to(device)
        
        # Local crops: smaller area scale, gentler rotation.
        # NOTE(review): output size is the same `Size` as the global crops —
        # confirm that local crops are not meant to be smaller.
        self.local = torch.nn.Sequential(
            T.CenterCrop(Init_Resize),
            #T.Resize(Init_Resize, antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.RandomRotation(
                degrees=(-5, 5), 
                fill=5
            ),
            T.RandomResizedCrop(
                Size, 
                scale=LocalCropsScale,
                antialias=True, 
                interpolation=T.InterpolationMode.BICUBIC
            ),
            T.RandomHorizontalFlip(0.5),
            T.RandomSolarize(0.5),
            T.RandomErasing(0.1, value="random"),
            T.RandomApply([
                RARP_ChannelSwap()
            ], 0.3),
            T.RandomApply([
                T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
            ], 0.5),
            T.Normalize(mean, std)
        ).to(device)
        
        # Classification view uses its own fixed resize, not Init_Resize.
        InitResize = (256,256)        
        self.classification = torch.nn.Sequential(
            T.Resize(
                InitResize, 
                antialias=True, 
                interpolation=T.InterpolationMode.BICUBIC
            ),
            T.CenterCrop(224),    
            T.RandomAffine(
                degrees=(-5, 5), scale=(0.9, 1.1), fill=5
            ),
            # NOTE(review): p=1.0 flips every image — confirm 0.5 was not intended.
            T.RandomHorizontalFlip(1.0),
            T.Normalize(mean, std),
        ).to(device) if Tranform_0 == None else Tranform_0
                
    def __call__(self, img):
        """Return [classification, global1, global2, local * NumLocal_Crops]."""
        all_Crops = []
        all_Crops.append(self.classification(img))
        all_Crops.append(self.globalCrop1(img))
        all_Crops.append(self.globalCrop2(img))
        
        all_Crops.extend([self.local(img) for _ in range(self.NumLocal_Crops)])
        
        return all_Crops

class RARP_DINO_Albumentations():
    """DINO-style multi-crop augmentation built on Albumentations.

    Unlike the torchvision variant, the DINO pipelines here expect numpy
    HWC images (Albumentations convention) and return tensors via
    ``ToTensorV2``. The classification view is supplied externally.
    """
    def __init__(
        self, 
        GloblaCropsScale=(0.4, 1), 
        LocalCropsScale=(0.05, 0.4), 
        NumLocalCrops = 8, 
        Size = 224, 
        device=None, 
        mean = None, 
        std = None, 
        Tranform_0=None, 
        Init_Resize=(512, 512),
        Seed=505
    ):
        """Build the Albumentations pipelines.

        Args:
            GloblaCropsScale: affine scale range for the global crops.
            LocalCropsScale: affine scale range for the local crops.
            NumLocalCrops: number of local crops per call.
            Size: crop height and width.
            device: unused here — Albumentations runs on CPU numpy arrays.
            mean: normalization mean.
            std: normalization std.
            Tranform_0: callable producing the classification view.
                NOTE(review): if left as None, ``__call__`` will raise when
                invoking ``self.classification(img)`` — confirm callers
                always pass it.
            Init_Resize: unused (the Resize steps are commented out).
            Seed: seed passed to each A.Compose for reproducibility.
        """
        self.NumLocal_Crops= NumLocalCrops
       
        # First global crop: affine + crop + flip + colour jitter + blur/fog.
        self.globalCrop1 = A.Compose([
            #A.Resize(Init_Resize[0], Init_Resize[1], interpolation=cv2.INTER_CUBIC),
            A.Affine(scale=GloblaCropsScale, rotate=(-15, 15), interpolation=cv2.INTER_CUBIC),
            A.RandomCrop(Size, Size),
            A.HorizontalFlip(0.6),
            A.ColorJitter(brightness=1.1, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),
            A.ToGray(p=0.4),
            A.GaussianBlur(p=1),
            A.RandomFog(p=0.5),
            A.Normalize(mean, std),
            ToTensorV2()
        ], seed=Seed)
        
        # Second global crop: same as the first plus solarization.
        self.globalCrop2 = A.Compose([
            #A.Resize(Init_Resize[0], Init_Resize[1], interpolation=cv2.INTER_CUBIC),
            A.Affine(scale=GloblaCropsScale, rotate=(-15, 15), interpolation=cv2.INTER_CUBIC),
            A.RandomCrop(Size, Size),
            A.HorizontalFlip(0.5),
            A.ColorJitter(brightness=1.1, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),
            A.ToGray(p=0.4),
            A.GaussianBlur(p=1),
            A.Solarize(p=0.3),
            A.RandomFog(p=0.5),
            A.Normalize(mean, std),
            ToTensorV2()
        ], seed=Seed)
        
        # Local crops: smaller affine scale range, no flip/blur/grayscale.
        self.local = A.Compose([
            #A.Resize(Init_Resize[0], Init_Resize[1], interpolation=cv2.INTER_CUBIC),
            A.Affine(scale=LocalCropsScale, rotate=(-15, 15), interpolation=cv2.INTER_CUBIC),
            A.RandomCrop(Size, Size),
            A.ColorJitter(brightness=1.1, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),
            A.Solarize(p=0.3),
            A.RandomFog(p=0.5),
            A.Normalize(mean, std),
            ToTensorV2()
        ], seed=Seed)
                      
        # Externally supplied classification pipeline (torchvision-style
        # callable: invoked as classification(img), not classification(image=img)).
        self.classification = Tranform_0
                
    def __call__(self, img):
        """Return [classification, global1, global2, local * NumLocal_Crops]."""
        all_Crops = []
                        
        all_Crops.append(self.classification(img))
        all_Crops.append(self.globalCrop1(image=img)["image"])
        all_Crops.append(self.globalCrop2(image=img)["image"])
        
        all_Crops.extend([self.local(image=img)["image"] for _ in range(self.NumLocal_Crops)])
        
        return all_Crops

class RARP_Windowed_Video_MIL_Dataset(Dataset):
    """Multiple-instance-learning dataset over pre-extracted video frames.

    Each item corresponds to one video "case" and yields ``num_windows``
    fixed-length clips sampled across the full video (the MIL bag), plus a
    per-window validity mask and metadata. Frame arrays are memory-mapped
    from ``.npy`` files at construction so only the sampled frames are read.
    Optionally a key-frame image (or its cached features and soft labels)
    is returned as well.
    """
    def __init__(
        self,
        items,
        num_windows:int = 8,
        window_length:int = 64,
        transform=None,
        transform_frame=None,
        ext:str = "npy",
        key_frames:bool = False,
        key_frame_transform=None,
        train_mode:bool=False,
        key_frame_only:bool=False,
        load_key_frame_cache:bool=False,
        Fold_index:int = None,
        no_norm_video:bool= False
    ):
        """Memory-map every video array and record its frame count.

        Args:
            items: sequence of dicts; each needs ``path``, ``case`` and
                ``label`` keys (and ``key_frame`` when key frames are used).
            num_windows: windows (bag size) sampled per video.
            window_length: frames per window.
            transform: transform applied to a whole clip tensor.
            transform_frame: transform applied frame-by-frame to a clip.
            ext: file extension of the frame arrays.
            key_frames: also load the key-frame image / cached features.
            key_frame_transform: transform applied to the key-frame image.
            train_mode: use jittered window sampling instead of uniform.
            key_frame_only: return only (key_frame, case_id).
            load_key_frame_cache: load cached key-frame features instead of
                decoding the image.
            Fold_index: fold number used to locate the feature cache file.
            no_norm_video: skip the /255 scaling of the clip tensor.
        """
        super().__init__()
        
        self.samples = items
        self.W = num_windows
        self.L = window_length
        self.clip_transform = transform
        self.frame_transform = transform_frame
        self.key_frame_transform = key_frame_transform
        self.ext_file = ext
        self.load_key_frames = key_frames
        self.train = train_mode
        self.load_only_key_frame = key_frame_only
        self.load_key_frame_features_from_cache = load_key_frame_cache
        self.FOLD = Fold_index
        self.NO_Normalization = no_norm_video
        
        # Memory-map each case's frame array once, up front.
        self.frame_counts = []
        self.arrays = []
        self.case_index = {}
        for indx, item in enumerate(items):
            vpath = Path(item["path"]).resolve().with_suffix(f".{self.ext_file}")
            self.case_index[item["case"]] = indx
            arr = np.load(vpath, mmap_mode="r")
            self.arrays.append(arr)
            self.frame_counts.append(arr.shape[0])
            
    def __len__(self):
        return len(self.samples)
    
    def _removeBlackBorder(self, image):
        """Crop away the black border around the endoscopic view.

        Masks out pixels whose hue lies inside (25, 175), thresholds, opens
        the result, and crops to the bounding box of the largest contour.
        Assumes a BGR input image (OpenCV convention).
        """
        copyImg = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h = copyImg[:,:,0]
        mask = np.ones(h.shape, dtype=np.uint8) * 255
        th = (25, 175)
        mask[(h > th[0]) & (h < th[1])] = 0
        copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)
        resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)
            
        image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # findContours returns 2 values in OpenCV 4.x, 3 in 3.x.
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        
        crop = image[y : y + h, x : x + w]
        
        return crop
    
    def _sample_windows_train (self, T):
        """Sample W (start, end) windows with random jitter around evenly
        spaced anchors. Jitter is ±10% of the window length; `h` from
        np.random.randint is exclusive."""
        if T <= self.L:
            # Video shorter than one window: every window covers it all.
            return [(0, T)] * self.W
        
        stride = (T - self.L) / max(self.W - 1, 1)
        starts = []
        for i in range(self.W):
            b = int(i * stride)
            jitter = int(0.1 * self.L)
            l = max(0, b - jitter)
            h = min(T - self.L, b + jitter)
            starts.append(np.random.randint(l, h) if h > l else l)
            
        return [(s, s + self.L) for s in starts]
    
    def _sample_windows_val (self, T):
        """Sample W evenly spaced, deterministic (start, end) windows."""
        if T <= self.L:
            return [(0, T)] * self.W
        
        stride = (T - self.L) / max(self.W - 1, 1)
        starts = []
        for i in range(self.W):
            s = int(round(i * stride))
            s = min(s, T - self.L)       # safety clamp
            starts.append(s)
            
        return [(s, s + self.L) for s in starts]
    
    def _load_frames_inrage(self, video_index:int, start:int, end:int) -> Tuple[torch.Tensor, int]:
        """Load frames [start, end) of a video, zero-padding up to length L.

        Returns:
            (clip, true_len): the padded clip tensor and the number of real
            (non-padded) frames.
        """
        arr = self.arrays[video_index]
        clip_np = arr[start:end].copy()
        # Scale to [0, 1] unless normalization is disabled.
        clip = torch.from_numpy(clip_np).float() / (255.0 if not self.NO_Normalization else 1)
        
        if clip.shape[0] < self.L:
            pad_len = self.L - clip.shape[0]
            pad = torch.zeros((pad_len, ) + clip.shape[1:], dtype=clip.dtype, device=clip.device)
            clip = torch.cat([clip, pad], dim=0)
            
        return clip, clip_np.shape[0]
    
    def __getitem__(self, idx):
        """Return one MIL bag for case ``idx``.

        Depending on flags, returns one of:
        - (winds, label, masks, meta)
        - (winds, label, masks, key_frame, meta)
        - (winds, label, masks, img_features, soft_labels, meta)
        - (key_frame, case_id) when ``key_frame_only`` is set.
        """
        item = self.samples[idx]
        label = torch.tensor(item["label"], dtype=torch.int)
        key_frame = None
        soft_labels = None
        
        if self.load_key_frames:
            if not self.load_key_frame_features_from_cache:
                # Decode the key-frame image, crop its black border and
                # convert HWC uint8 -> CHW float.
                key_frame_img = cv2.imread(str(Path(item["key_frame"]).resolve()), cv2.IMREAD_COLOR)
                key_frame_img = self._removeBlackBorder(key_frame_img)
                
                key_frame = torch.from_numpy(key_frame_img.transpose((2, 0, 1))).float()
                
                if self.key_frame_transform is not None:
                    key_frame = self.key_frame_transform(key_frame)
            else:
                # Load pre-computed features/soft labels for this fold+case.
                path_cache = Path(item["path"]).resolve().parent / "cache"
                file_name = f"F{self.FOLD}_{item['case']}.npz"
                
                cached_features = np.load((path_cache / file_name))
                
                key_frame = torch.from_numpy(cached_features["img_features"]).float()
                soft_labels = torch.from_numpy(cached_features["soft_label"]).float()           
        
        if self.load_only_key_frame:
            return key_frame, torch.tensor(item["case"], dtype=torch.int)
        
        T_video = self.frame_counts[idx]
        # Jittered sampling during training, deterministic otherwise.
        wind_idx = self._sample_windows_train(T_video) if self.train else self._sample_windows_val(T_video)
        
        window_tensors = []
        window_masks = []
        
        for (s, e) in wind_idx:
            clip, Lp = self._load_frames_inrage(idx, s, e)
            
            # Mask marks real frames (True) vs zero padding (False).
            mask = torch.zeros(self.L, dtype=torch.bool)
            mask[:Lp] = True
            
            if self.clip_transform is not None:
                clip = self.clip_transform(clip)
                
            if self.frame_transform is not None:
                for t in range(clip.shape[0]):
                    clip[t] = self.frame_transform(clip[t])
                    
            window_tensors.append(clip)
            window_masks.append(mask)
            
        winds = torch.stack(window_tensors, dim=0)
        masks = torch.stack(window_masks, dim=0)
        
        starts = [s for (s, e) in wind_idx]
        ends   = [e for (s, e) in wind_idx]
        win_i  = list(range(len(wind_idx)))
        
        meta = {
            "case_id": item["case"],
            "win_idx": torch.tensor(win_i, dtype=torch.int),        
            "win_start": torch.tensor(starts, dtype=torch.int),     
            "win_end": torch.tensor(ends, dtype=torch.int), 
        }
        
        if not self.load_key_frames:
            return winds, label, masks, meta
        else:
            if not self.load_key_frame_features_from_cache:
                return winds, label, masks, key_frame, meta
            else:
                return winds, label, masks, key_frame, soft_labels, meta #key_frame = image features
            

class RARP_Windowed_Video_frames_Dataset(Dataset):
    """Windowed dataset over pre-extracted video frame arrays.

    Training mode: each index samples ONE random fixed-length window from a
    video (``k_windows`` passes over the videos per epoch). Eval mode: the
    index space is the full set of sliding windows over every video
    (stride ``stride``, last window clamped to the end). Frame arrays are
    memory-mapped ``.npy`` files; clips shorter than the window are
    zero-padded and accompanied by a boolean validity mask.
    """
    def _sliding_windows(self, total_frames:int,  conver_last=True):
        """Return the start indices of stride-spaced windows; optionally add
        a final window ending exactly at the last frame."""
        if total_frames <= self.L:
            return [0]
        
        starts = list(range(0, total_frames - self.L + 1, self.S))
        
        if conver_last and (starts[-1] != total_frames - self.L):
            starts.append(total_frames - self.L)
            
        return starts
    
    def _random_window_indices(self, total_frames: int):
        """Return frame indices of one random window (unused by __getitem__
        in this version — NOTE(review): confirm whether still needed)."""
        if total_frames <= self.L:
            return list(range(total_frames))
        
        # np.random.randint upper bound is exclusive.
        start = np.random.randint(0, total_frames - self.L)
        return list(range(start, start + self.L))
    
    def _removeBlackBorder(self, image):
        """Crop away the black border around the endoscopic view.

        Masks out pixels whose hue lies inside (25, 175), thresholds, opens
        the result, and crops to the bounding box of the largest contour.
        Assumes a BGR input image (OpenCV convention).
        """
        copyImg = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h = copyImg[:,:,0]
        mask = np.ones(h.shape, dtype=np.uint8) * 255
        th = (25, 175)
        mask[(h > th[0]) & (h < th[1])] = 0
        copyImg = cv2.cvtColor(copyImg, cv2.COLOR_HSV2BGR)
        resROI = cv2.bitwise_and(copyImg, copyImg, mask=mask)
            
        image_gray = cv2.cvtColor(resROI, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        morph = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
        contours = cv2.findContours(morph, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # findContours returns 2 values in OpenCV 4.x, 3 in 3.x.
        contours = contours[0] if len(contours) == 2 else contours[1]
        bigCont = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(bigCont)
        
        crop = image[y : y + h, x : x + w]
        
        return crop
    
    def _load_frames_inrage(self, video_index:int, start:int, end:int) -> Tuple[torch.Tensor, int]:
        """Load frames [start, end) of a video scaled to [0, 1], zero-padding
        up to length L. Returns (clip, true_len)."""
        arr = self.arrays[video_index]
        clip_np = arr[start:end].copy()
        clip = torch.from_numpy(clip_np).float() / 255.0
        
        if clip.shape[0] < self.L:
            pad_len = self.L - clip.shape[0]
            pad = torch.zeros((pad_len, ) + clip.shape[1:], dtype=clip.dtype, device=clip.device)
            clip = torch.cat([clip, pad], dim=0)
            
        return clip, clip_np.shape[0]
    
    def __init__(
        self,
        items,
        resize=(360, 640),
        train_mode:bool = True,
        window_length:int = 64,
        stride:int = 32, # 50% overlap
        multi_label=False,
        transform=None,
        transform_frame=None,
        mean=[],   # NOTE(review): mutable default — never mutated here, but prefer None
        std=[],    # NOTE(review): mutable default — never mutated here, but prefer None
        k_windows:int = 1,
        ext:str = "npy",
        key_frames:bool = False,
        key_frame_transform=None
    ):
        """Memory-map every video array; in eval mode also precompute the
        full (video, start, true_len) sliding-window index map.

        Args:
            items: sequence of dicts with ``path`` and ``label`` keys
                (and ``key_frame`` when key frames are used).
            resize: stored target size (not applied in this class directly).
            train_mode: random-window sampling vs sliding-window eval.
            window_length: frames per window.
            stride: sliding-window stride (eval mode).
            multi_label: stored flag; not used in the visible code.
            transform: transform applied to the whole clip tensor.
            transform_frame: per-frame transform (train mode only).
            mean, std: stored normalization stats; not applied here.
            k_windows: random windows drawn per video per epoch (train mode).
            ext: file extension of the frame arrays.
            key_frames: also return the key-frame image.
            key_frame_transform: transform applied to the key-frame image.
        """
        super().__init__()
        
        assert k_windows >= 1, "Not posible to have less than 1 windows sampling"
        
        self.samples = items
        self.mode = "train" if train_mode else "eval"
        self.L = window_length
        self.S = stride
        self.resize = resize
        self.transform = transform
        self.transform_by_frame = transform_frame
        self.key_frame_transform = key_frame_transform
        self.multi_label = multi_label
        self.mean = mean
        self.std = std
        self.k_wind = k_windows
        self.ext_file = ext
        self.load_key_frames = key_frames
        
        # Memory-map each video's frame array once, up front.
        self.frame_counts = []
        self.arrays = []
        for item in items:
            vpath = Path(item["path"]).resolve().with_suffix(f".{self.ext_file}")
            arr = np.load(vpath, mmap_mode="r")
            self.arrays.append(arr)
            self.frame_counts.append(arr.shape[0])
        
        if self.mode == "eval":
            # One dataset index per sliding window over every video.
            self.index_map = []
            for idx, item in enumerate(self.samples):
                T_total = self.frame_counts[idx] # count the amount of frames 
                for start in self._sliding_windows(T_total):
                    true_len = min(self.L, T_total - start)
                    self.index_map.append((idx, start, true_len))
        else:
            self.num_videos = len(items)
                    
    def __len__(self):
        # Train: k random windows per video per epoch; eval: all windows.
        if self.mode == "train":
            return self.num_videos * self.k_wind
        else:
            return len(self.index_map)

    def __getitem__(self, index):
        """Return (clip, label, mask, meta) — plus key_frame when enabled.

        The mask marks real frames (True) vs zero padding (False).
        """
        if self.mode == "train":
            # Map the flat index back onto a video; the window is random.
            vid_index = index % self.num_videos
            item = self.samples[vid_index]
            #frames_path = Path(item["path"]).resolve()
            label = torch.tensor(item["label"], dtype=torch.float32)
            T_total = self.frame_counts[vid_index] 
                                   
            if T_total <= self.L:
                start = 0
            else:
                start = np.random.randint(0, T_total - self.L)
                
            end = start + self.L
            
            clip, true_len = self._load_frames_inrage(vid_index, start, end)
            
            if self.transform is not None:
                clip = self.transform(clip)
                
            if self.transform_by_frame is not None:
                for t in range(clip.shape[0]):
                    clip[t] = self.transform_by_frame(clip[t])           
                        
            mask = torch.zeros(self.L, dtype=torch.bool)
            mask[:true_len] = True
            
            meta = {
                "video_idx": vid_index,
                "start": start
            }
            
            if not self.load_key_frames:
                return clip, label, mask, meta
            else:
                # Decode the key-frame image, crop its black border and
                # convert HWC uint8 -> CHW float.
                key_frame_img = cv2.imread(str(Path(item["key_frame"]).resolve()), cv2.IMREAD_COLOR)
                key_frame_img = self._removeBlackBorder(key_frame_img)
                
                key_frame = torch.from_numpy(key_frame_img.transpose((2, 0, 1))).float()
                
                if self.key_frame_transform is not None:
                    key_frame = self.key_frame_transform(key_frame)
                
                return clip, label, mask, meta, key_frame
        else: #eval loading
            # Deterministic: the precomputed index map fixes the window.
            vi, start, true_len = self.index_map[index]
            
            item = self.samples[vi]
            #frames_path = Path(item["path"]).resolve()
            label = torch.tensor(item["label"], dtype=torch.float32)
            #T_total = self.frame_counts[vi]
            
            end = start + self.L
            
            clip, _ = self._load_frames_inrage(vi, start, end)
            
            if self.transform is not None:
                clip = self.transform(clip)
             
            mask = torch.zeros(self.L, dtype=torch.bool)
            mask[:true_len] = True
            
            meta = {
                "video_idx": vi,
                "start": start,
                "true_len": true_len,
            }
            if not self.load_key_frames:
                return clip, label, mask, meta
            else:
                key_frame_img = cv2.imread(str(Path(item["key_frame"]).resolve()), cv2.IMREAD_COLOR)
                key_frame_img = self._removeBlackBorder(key_frame_img)
                
                key_frame = torch.from_numpy(key_frame_img.transpose((2, 0, 1))).float()
                
                if self.key_frame_transform is not None:
                    key_frame = self.key_frame_transform(key_frame)
                
                return clip, label, mask, meta, key_frame

class RARP_Video_Dataset(Dataset):
    """Decodes whole video files with decord and returns, per item, a
    uniformly sampled ``target_T``-frame clip as a normalized
    [T, C, H, W] tensor together with its label.

    Frames are optionally cropped, channel-swapped to BGR, resized, then
    normalized with the stored mean/std.
    """
    def __init__(self, items, size_hw=(224, 224), crop = None, target_T=600, decode_resize=None, transform=None, transform_frame:bool=True, mean:torch.Tensor=None, std:torch.Tensor=None):
        """Store decoding/processing options.

        Args:
            items: sequence of dicts with ``path`` and ``label`` keys.
            size_hw: output (H, W) each frame is resized to.
            crop: optional (X, Y, W, H) crop applied before resizing.
            target_T: number of frames uniformly sampled per video.
            decode_resize: optional (w, h) decode-time resize for speed.
            transform: transform applied per-frame or per-clip (see
                ``transform_frame``).
            transform_frame: if True, apply ``transform`` to each frame;
                otherwise to the whole stacked clip.
            mean, std: normalization tensors, broadcastable to [T, C, H, W];
                defaults are the ImageNet statistics.
        """
        super().__init__()
        
        self.transform  = transform
        self.items = items
        self.size_hw = size_hw
        self.crop_size = crop
        self.target_T = target_T
        self.decode_resize = decode_resize  # (w, h) or None
        self.transform_frame_by_frame = transform_frame
        
        # ImageNet defaults, shaped (1, 3, 1, 1) to broadcast over [T, C, H, W].
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1) if mean is None else mean
        self.std  = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1) if std is None else std
        
    def __len__(self):
        return len(self.items)
    
    def _frames_to_tensor(self, frames_hwc_uint8):
        """frames_hwc_uint8: [T, H, W, 3] -> [T, C, H, W] float32 normalized"""
        T1 = frames_hwc_uint8.shape[0]
        out = []
        for t in range(T1):
            # Optional fixed crop before resize.
            if self.crop_size is not None:
                X, Y, W, H = self.crop_size
                frame = frames_hwc_uint8[t][Y:Y+H, X:X+W]
            else:
                frame = frames_hwc_uint8[t]
            img = torch.from_numpy(frame).permute(2,0,1).float() # [C,H,W]
            img = img[[2, 1, 0], :, : ] # swap channel order (RGB -> BGR)
            img = torchvision.transforms.functional.resize(img, self.size_hw, antialias=True, interpolation=torchvision.transforms.InterpolationMode.BICUBIC)
            if self.transform is not None and self.transform_frame_by_frame:
                img = self.transform(img)
            out.append(img)
        x = torch.stack(out, dim=0)  # [T,C,H,W]
        if self.transform is not None and not self.transform_frame_by_frame:
            x = self.transform(x)
        # NOTE(review): frames are NOT divided by 255 here, yet mean/std
        # default to [0,1]-range ImageNet stats — confirm this is intended.
        x = (x - self.mean) / self.std
        return x
    
    def __getitem__(self, idx):
        """Return (video, label) where video is [target_T, C, H, W]."""
        rec = self.items[idx]
        path = rec["path"]
        label = torch.tensor(rec["label"], dtype=torch.float32)

        # Optional decode-time resize for speed (decord takes width, height)
        if self.decode_resize is not None:
            w, h = self.decode_resize
            vr = decord.VideoReader(path, ctx=decord.cpu(0), width=w, height=h)
        else:
            vr = decord.VideoReader(path, ctx=decord.cpu(0))

        n = len(vr)  # number of frames in this 20s clip
        # Uniformly sample exactly target_T indices across [0, n-1]
        idxs = np.linspace(0, max(n-1, 0), num=self.target_T, dtype=np.int64)
        frames = vr.get_batch(idxs).asnumpy()  # [T, H, W, 3] uint8

        video = self._frames_to_tensor(frames)      # [T, C, H, W]

        return (video, label)
        

class RARP_VideoFrame_Dataset(Dataset):
    """Dataset yielding the individual frames of one video file.

    The decord reader is created lazily (on first access) so the dataset
    object can be constructed cheaply and shipped to DataLoader workers.
    """

    def __init__(self, video_path, transform=None):
        super().__init__()
        self.video_path = video_path
        self.transform = transform
        self.length = None  # cached frame count, filled on first __len__
        self.reader = None  # lazily-created decord.VideoReader

    def __len__(self):
        if self.length is None:
            probe = decord.VideoReader(self.video_path)
            self.length = len(probe)
        return self.length

    def __getitem__(self, index):
        if self.reader is None:
            self.reader = decord.VideoReader(self.video_path)

        raw = self.reader[index].asnumpy()
        # HWC uint8 -> CHW float in [0, 1]
        frame_img = torch.from_numpy(raw.astype(float).transpose((2, 0, 1))) / 255

        if self.transform is not None:
            frame_img = self.transform(frame_img)

        return frame_img
    
        

class RARP_DataSetType(Enum):
    # Index of the (train, val, test) slot inside each stored fold triple;
    # RARPDataset uses ``.value`` to select the matching ID list.
    train = 0
    val = 1
    test = 2

class RandomVideoDataset(torch.utils.data.IterableDataset):
    """Iterable dataset yielding randomly-located clips from videos under ``root``.

    Each iteration picks a random video, seeks to a random start time and
    decodes ``clip_len`` consecutive frames.

    Parameters
    ----------
    root : str or Path
        Directory laid out as ``root/<class_name>/<video>`` (see get_samples).
    epoch_size : int, optional
        Clips yielded per epoch; defaults to the number of videos found.
    frame_transform : callable, optional
        Applied to every decoded frame tensor.
    video_transform : callable, optional
        Applied to the stacked [clip_len, C, H, W] clip.
    clip_len : int
        Number of consecutive frames per clip.
    """

    def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
        super().__init__()

        self.samples = get_samples(root)

        # Allow for temporal jittering
        if epoch_size is None:
            epoch_size = len(self.samples)
        self.epoch_size = epoch_size

        self.clip_len = clip_len
        self.frame_transform = frame_transform
        self.video_transform = video_transform

    def __iter__(self):
        for _ in range(self.epoch_size):
            # BUG FIX: np.random.choice requires a 1-D candidate array, so the
            # original call on a list of (path, target) tuples raised
            # "ValueError: a must be 1-dimensional". Pick a random index instead.
            path, target = self.samples[np.random.randint(len(self.samples))]
            # Get video object
            vid = torchvision.io.VideoReader(path, "video")
            metadata = vid.get_metadata()
            video_frames = []  # video frame buffer

            # Latest start time that still leaves room for clip_len frames;
            # clamped at 0 for videos shorter than one clip.
            max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
            start = np.random.uniform(0., max(max_seek, 0.))
            current_pts = start  # fallback so 'end' is defined even if no frame decodes
            for frame in itertools.islice(vid.seek(start), self.clip_len):
                data = frame['data']
                # frame_transform defaults to None, so only apply it when given
                if self.frame_transform is not None:
                    data = self.frame_transform(data)
                video_frames.append(data)
                current_pts = frame['pts']
            # Stack it into a tensor
            video = torch.stack(video_frames, 0)
            if self.video_transform:
                video = self.video_transform(video)
            output = {
                'path': path,
                'video': video,
                'target': target,
                'start': start,
                'end': current_pts}
            yield output

class VideoDataset(Dataset):
    """Skeleton video dataset.

    The original stub silently discarded every constructor argument; they
    are now stored on the instance so later code can use them.
    NOTE(review): ``__len__`` / ``__getitem__`` are not implemented here.
    """

    def __init__(self, time_depth, mean, std, transform=None) -> None:
        super().__init__()
        self.time_depth = time_depth  # frames per clip
        self.mean = mean              # normalization mean
        self.std = std                # normalization std
        self.transform = transform    # optional per-sample transform

class RARPDataset(Dataset):
    """Image dataset backed by a CSV index, with optional precomputed folds.

    Parameters
    ----------
    path_RARP_dataset : Path
        CSV file with at least the columns ``id``, ``label``, ``class``, ``path``.
    path_RARP_Folds : Path, optional
        ``.npy`` file with fold ID lists stored flat as consecutive
        (train, val, test) triples.
    split : int
        Fold index to use when ``path_RARP_Folds`` is given.
    DataSetType : RARP_DataSetType
        Which slot of the selected fold triple to use (train/val/test).
    transform : callable, optional
        Applied to every loaded image tensor.

    Raises
    ------
    ValueError
        If folds are requested but ``split`` is None.
    """

    def __init__(self, path_RARP_dataset:Path, path_RARP_Folds:Path=None, split=0, DataSetType:RARP_DataSetType="train", transform=None) -> None:
        super().__init__()
        self.samples = pd.read_csv(path_RARP_dataset, usecols=["id", "label", "class", "path"])
        self.classes = self.samples["class"].unique().tolist()
        self.targets = self.samples["label"].to_list()
        self.IDsplit = None
        self.transform = transform

        if path_RARP_Folds is not None:
            if split is None:
                raise ValueError("A split index is required when using folds")

            arrFolds = np.load(path_RARP_Folds, allow_pickle=True)
            # Folds are stored flat as (train, val, test) triples: integer-divide
            # to get the number of triples (the original float division only worked
            # because np.array_split coerces it with int()).
            self.IDsplit = np.array_split(arrFolds, len(arrFolds) // 3)[split][DataSetType.value]

            self.samples = self.samples.loc[self.samples["id"].isin(self.IDsplit)]
            # Keep targets a plain list, consistent with the no-folds branch above
            # (the original left a pandas Series here).
            self.targets = self.samples["label"].to_list()

    def __len__(self):
        return self.samples.shape[0]

    def __getitem__(self, index):
        # Without folds the row id is assumed to equal the dataset index;
        # with folds, map the index through the selected ID split.
        ID = index
        if self.IDsplit is not None:
            ID = self.IDsplit[index]

        rs = self.samples.loc[self.samples["id"] == ID]

        img = defs.load_file_tensor(rs["path"].item())
        label = rs["label"].item()

        if self.transform is not None:
            img = self.transform(img)

        return img, label