diff --git a/Loaders.py b/Loaders.py index af01cce..cbba50e 100644 --- a/Loaders.py +++ b/Loaders.py @@ -18,6 +18,16 @@ from ultralytics import YOLO import itertools +class NVB_Classes(Enum): + NOT_NVB = 0 + R_NVB = 1 + L_NVB = 2 + RL_NVB = 3 + +class NVB_Binary(Enum): + NOT_NVB = 0 + NVB = 1 + def _find_classes(dir): classes = [d.name for d in os.scandir(dir) if d.is_dir()] classes.sort() @@ -83,28 +93,34 @@ return listROIs - def __init__(self, - RootPath = "", - extension="tiff", - FoldSeed=505, - createFile=False, - SavePath="", - Fold:int=None, - preResize=None, - RGBGama=False, - removeBlackBar=True, - SegImage:float=None, - Num_Img_Slices:int = 4, - SegmentClass:int = None, - colorSpace:int=None, - ROI_Yolo:YOLO=None, - thresholdYolo_Accuracy:float=0.75) -> None: + def __init__( + self, + RootPath = "", + extension="tiff", + FoldSeed=505, + createFile=False, + SavePath="", + Fold:int=None, + preResize=None, + RGBGama=False, + removeBlackBar=True, + SegImage:float=None, + Num_Img_Slices:int = 4, + SegmentClass:int = None, + colorSpace:int=None, + ROI_Yolo:YOLO=None, + thresholdYolo_Accuracy:float=0.75, + Num_Labels:int = None + ) -> None: root = Path(RootPath) lMean = [] lStd = [] NO_NVB = [] NVB = [] + if Num_Labels is not None: + MultiClasses = [[] for _ in range(Num_Labels)] + self.Num_Labels = Num_Labels if createFile: if len(SavePath) == 0: @@ -137,7 +153,8 @@ tempImgList = self.ROI_Extract_YOLO(ROI_Yolo, tempImg, thresholdYolo_Accuracy) for k, tempImg in enumerate(tempImgList): - lineaCSV = [id, 0 if file.parent.name == "NOT_NVB" else 1, file.parent.name, (dumpImgs/f"Img{k}-{i}.npy").absolute(), file.name] + idClass = NVB_Binary[file.parent.name].value if Num_Labels is None else NVB_Classes[file.parent.name].value + lineaCSV = [id, idClass, file.parent.name, (dumpImgs/f"Img{k}-{i}.npy").absolute(), file.name] lMean.append(np.mean(tempImg, axis=tuple(range(tempImg.ndim-1)))) lStd.append(np.std(tempImg, axis=tuple(range(tempImg.ndim-1)))) lineaCSV += np.mean(tempImg, axis=tuple(range(tempImg.ndim-1))).tolist() @@ -145,10 +162,13 @@ writerOBJ.writerow(lineaCSV) #lista.append (np.mean(tempImg, axis=tuple(range(tempImg.ndim-1)))) np.save(dumpImgs/f"Img{k}-{i}.npy", tempImg) - if lineaCSV[1] == 0: - NO_NVB.append(id) + if Num_Labels is None: + if lineaCSV[1] == 0: + NO_NVB.append(id) + else: + NVB.append(id) else: - NVB.append(id) + MultiClasses[idClass].append(id) id += 1 @@ -159,14 +179,18 @@ break imgPath = dumpImgs/f"Img{k}-{i}-{j}.npy" np.save(imgPath, newImg) - lineaCSV = [id, 0 if file.parent.name == "NOT_NVB" else 1, file.parent.name, (dumpImgs/f"Img{k}-{i}-{j}.npy").absolute(), file.name] + idClass = NVB_Binary[file.parent.name].value if Num_Labels is None else NVB_Classes[file.parent.name].value + lineaCSV = [id, idClass, file.parent.name, (dumpImgs/f"Img{k}-{i}-{j}.npy").absolute(), file.name] lineaCSV += np.mean(newImg, axis=tuple(range(newImg.ndim-1))).tolist() lineaCSV += np.std(newImg, axis=tuple(range(newImg.ndim-1))).tolist() writerOBJ.writerow(lineaCSV) - if lineaCSV[1] == 0: - NO_NVB.append(id) + if Num_Labels is None: + if lineaCSV[1] == 0: + NO_NVB.append(id) + else: + NVB.append(id) else: - NVB.append(id) + MultiClasses[idClass].append(id) id += 1 @@ -180,9 +204,12 @@ data = pd.read_csv(self.CVS_File) self.mean = data[["mean_1", "mean_2", "mean_3"]].mean().to_list() self.std = data[["std_1", "std_2", "std_3"]].mean().to_list() - - NO_NVB = data.loc[data["label"] == 0]["id"].to_list() - NVB = data.loc[data["label"] == 1]["id"].to_list() + if Num_Labels is None: + NO_NVB = data.loc[data["label"] == 0]["id"].to_list() + NVB = data.loc[data["label"] == 1]["id"].to_list() + else: + for i in range(Num_Labels): + MultiClasses[i] = data.loc[data["label"] == i]["id"].to_list() if Fold is not None: self.Splits = [] @@ -194,11 +221,19 @@ if FoldSeed is not None: np.random.seed(FoldSeed) - np.random.shuffle(NO_NVB) - np.random.shuffle(NVB) - - NO_NVB_Folds = list (self._split(NO_NVB, Fold)) - NVB_Fols = list(self._split(NVB, Fold)) + if Num_Labels is None: + np.random.shuffle(NO_NVB) + np.random.shuffle(NVB) + else: + for i in range(Num_Labels): + np.random.shuffle(MultiClasses[i]) + if Num_Labels is None: + NO_NVB_Folds = list (self._split(NO_NVB, Fold)) + NVB_Fols = list(self._split(NVB, Fold)) + else: + MultiClasses_Folds = [] + for i in range(Num_Labels): + MultiClasses_Folds.append(list(self._split(MultiClasses[i], Fold))) setst = [ math.trunc(0.60 * Fold), math.trunc(0.20 * Fold), math.trunc(0.20 * Fold) @@ -209,7 +244,12 @@ for s in setst: tempArry = [] for fold in data[ultimo: ultimo + s]: - tempArry += NO_NVB_Folds[fold] + NVB_Fols[fold] + if Num_Labels is None: + tempArry += NO_NVB_Folds[fold] + NVB_Fols[fold] + else: + for i in range(Num_Labels): + tempArry += MultiClasses_Folds[i][fold] + splitsToSave.append(tempArry) ultimo += s self.Folds_File = dumpImgs/"Folds.npy" @@ -261,7 +301,8 @@ SubfoldPath = foldPath/f"{RARP_DataSetType(datasetType).name}" for _, row in database.loc[database["id"].isin(subSet)].iterrows(): PathOri = Path(row["path"]) - PathImg = SubfoldPath/("NO_NVB" if row["label"] == 0 else "NVB")/PathOri.name + folderName = "NO_NVB" if row["label"] == 0 else "NVB" if self.Num_Labels is None else NVB_Classes(row["label"]).name + PathImg = SubfoldPath/folderName/PathOri.name PathImg.parent.mkdir(parents=True, exist_ok=True) shutil.copy(PathOri, PathImg) diff --git a/Models.py b/Models.py index c3c7bf1..efa71c9 100644 --- a/Models.py +++ b/Models.py @@ -8,6 +8,30 @@ import timm import van import numpy as np +from softadapt import LossWeightedSoftAdapt + + +class SoftAdapt_RARP: + def __init__(self, beta=0.1): + self.beta = beta + self.loss_history = [] + self.update_history = True + + def update_weights(self, losses): + losses = [l.detach() for l in losses] + if len(self.loss_history) > 0: + prev_losses = torch.stack(self.loss_history[-1]) + current_losses = torch.stack(losses) + delta_losses = current_losses - prev_losses + weights = torch.exp(self.beta * delta_losses) + weights /= weights.sum() + else: + weights = torch.ones(len(losses)) / len(losses) + + if self.update_history: + self.loss_history.append(losses) + + return weights def js_divergence_sigmoid(p_logits, q_logits): @@ -408,28 +432,58 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: return torchvision.ops.focal_loss.sigmoid_focal_loss(input, target, self.alpha, self.gamma, self.reduction) +class FocalLoss(torch.nn.Module): + def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'): + super(FocalLoss, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + + def forward(self, inputs, targets): + # Convert targets to one-hot encoding + targets = torch.nn.functional.one_hot(targets, num_classes=inputs.size(1)).float() + + # Compute softmax over the inputs + probs = torch.nn.functional.softmax(inputs, dim=1) + log_probs = torch.nn.functional.log_softmax(inputs, dim=1) + + # Compute the focal loss components + focal_weight = (1 - probs) ** self.gamma + loss = -self.alpha * focal_weight * targets * log_probs + + # Apply reduction method + if self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: + return loss #TODO class RARP_NVB_MultiClassModel(L.LightningModule): def __init__(self, InitWeight = torch.tensor([1,1]), schedulerLR:bool=False, lr:float = 1e-4, - Model:torch.nn.Module = None + Model:torch.nn.Module = None, + Num_Classes:int = 2, + L1:float = 1.31E-04, + L2:float = 0 ) -> None: super().__init__() self.model = Model - self.lossFN = torch.nn.CrossEntropyLoss() + self.lossFN = FocalLoss() #torch.nn.CrossEntropyLoss() self.InitWeight = InitWeight self.scheduler = schedulerLR self.lr = lr - self.Lambda_L1 = 1.31E-04 - self.Lambda_L2 = 0 + self.Lambda_L1 = L1 + self.Lambda_L2 = L2 + self.num_classes = Num_Classes - self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=2) - self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=2) - self.test_acc = torchmetrics.Accuracy("multiclass", num_classes=2) - self.f1ScoreTest = torchmetrics.F1Score("multiclass", num_classes=2) + self.train_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes) + self.val_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes) + self.test_acc = torchmetrics.Accuracy("multiclass", num_classes=Num_Classes) + self.f1ScoreTest = torchmetrics.F1Score("multiclass", num_classes=Num_Classes) def forward(self, data): data = data.float() @@ -438,8 +492,7 @@ def _shared_step(self, batch): img, label = batch - #label = label.float() - prediction = self(img) #.flatten() + prediction = self(img) predicted_labels = torch.softmax(prediction, dim=1) loss = self.lossFN(prediction, label) @@ -1170,6 +1223,288 @@ for parms in self.teacher_Features.parameters(): parms.requires_grad = False +class RARP_NVB_DINO_MultiTask_v2(L.LightningModule): + # Define a hook function to capture the output + def _hook_fn_Student(self, module, input, output): + self.last_conv_output_S = output + + def _hook_fn_Teacher(self, module, input, output): + self.last_conv_output_T = output + + def __init__( + self, + TypeLoss=TypeLossFunction.CrossEntropy, + momentum_teacher:float = 0.9995, + lr:float = 1e-4, + L1:float = None, + L2:float = 0, + std: float = None, + mean: float = None, + ) -> None: + super().__init__() + + self.std_IMG = torch.tensor(std).view(3, 1, 1) if std is not None else None + self.mean_IMG = torch.tensor(mean).view(3, 1, 1) if mean is not None else None + + self.lr = lr + self.Lambda_L1 = L1 + self.Lambda_L2 = L2 + self.momentum_teacher = momentum_teacher + self.out_dim = 1024 + self.in_dim = 512 + self.weights = torch.tensor([1,1,1,1]) + + self.softAdapt = LossWeightedSoftAdapt(0.1) + self.loss_history = { + 'loss_DINO': [], + 'loss_Reconstruction': [], + 'loss_Binary': [], + 'loss_Multiclass': [], + } + self.update_interval = 2 + + self.train_acc = torchmetrics.Accuracy('binary') + self.val_acc = torchmetrics.Accuracy('binary') + self.test_acc = torchmetrics.Accuracy('binary') + #self.f1ScoreTest = torchmetrics.F1Score('binary') + + self.train_acc_Multi = torchmetrics.Accuracy('multiclass', num_classes=4) + self.val_acc_Multi = torchmetrics.Accuracy('multiclass', num_classes=4) + self.test_acc_Multi = torchmetrics.Accuracy('multiclass', num_classes=4) + + self.student = van.van_b2(pretrained = True, num_classes = 0) + self.teacher_Features = van.van_b2(pretrained = True, num_classes = 0) + + self.decoder = DynamicDecoder(input_channels=1024) #Decoder(input_channels=1024, hidden_channels=[512, 256, 128, 64]) + + self.student = RARP_NVB_DINO_Wrapper( + self.student, + RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2) + ) + + self.teacher_Features = RARP_NVB_DINO_Wrapper( + self.teacher_Features, + RARP_NVB_MLP(self.in_dim, self.out_dim, n_layers=2) + ) + + self.teacher_Features.load_state_dict(self.student.state_dict()) + for parms in self.teacher_Features.parameters(): + parms.requires_grad = False + + self.lossFN_DINO = RARP_NVB_DINO_Loss(self.out_dim, 0.04, 0.1, momentum_teacher) + self.lossFN = torch.nn.BCEWithLogitsLoss() + self.ReconstructionLoss = torch.nn.MSELoss() + self.lossFN_2 = torch.nn.CrossEntropyLoss(label_smoothing=0.5) + + self.last_conv_output_T = None + self.last_conv_output_S = None + + self.teacher_Features.backbone.block4[-1].register_forward_hook(self._hook_fn_Teacher) + self.student.backbone.block4[-1].register_forward_hook(self._hook_fn_Student) + + self.clasiffier = torch.nn.Sequential( + torch.nn.Linear(self.out_dim, 128), + torch.nn.Dropout(0.4), + torch.nn.SiLU(True), + + torch.nn.Linear(128, 8), + torch.nn.Dropout(0.2), + torch.nn.SiLU(True), + + torch.nn.Linear(8, 1) + ) + + self.clasiffier_RL = torch.nn.Sequential( + torch.nn.Linear(self.out_dim, 128), + torch.nn.Dropout(0.4), + torch.nn.SiLU(True), + + torch.nn.Linear(128, 8), + torch.nn.Dropout(0.4), + torch.nn.SiLU(True), + + torch.nn.Linear(8, 4) + ) + + print(f"lr={self.lr}, L1={self.Lambda_L1}") + + def _denormalize(self, tensor:torch.Tensor): + # Move mean and std to the same device as the input tensor + mean = self.mean_IMG.to(tensor.device) + std = self.std_IMG.to(tensor.device) + return tensor * std + mean + + def forward(self, data, val_step:bool = True): + if val_step: + data = data.float() + dataTeacher, dataStudent = data, data + else: + data = [d.float() for d in data] + dataTeacher, dataStudent = data[1:3], data + + TeacherDino = self.teacher_Features(dataTeacher) + Student = self.student(dataStudent) + + if not val_step: + NumChunks = len(dataStudent) + S_GlogalViews = self.last_conv_output_S.chunk(NumChunks)[1:3] + self.last_conv_output_S = torch.cat(S_GlogalViews, dim=0) + + cat_features = torch.cat((self.last_conv_output_S, self.last_conv_output_T), dim=1) + + reconstructed_image = self.decoder(cat_features) + + Cont_Net = torch.nn.functional.adaptive_avg_pool2d(cat_features, (1,1)).flatten(1) + pred = self.clasiffier(Cont_Net) + pred_RL = self.clasiffier_RL(Cont_Net) + + return (pred, pred_RL), (Student, TeacherDino), reconstructed_image + + def _shared_step(self, batch, val_step:bool = False): + img, label = batch + + + predictions, features, new_image = self(img, val_step) + StudentF, TeacherF = features + + prediction_preserv = predictions[0].flatten() + predicted_labels = torch.sigmoid(prediction_preserv) + + prediction_class = torch.softmax(predictions[1], dim=1) + + orignalImg = torch.cat([img[0].float() for _ in range(len(TeacherF))], dim=0) if not val_step else img.float() + label = torch.cat([label for _ in range(len(TeacherF))], dim=0) if not val_step else label + + label_preserv = (label != 1).float() + + #DINO Loss + loss_Dino = self.lossFN_DINO(StudentF, TeacherF) if not val_step else torch.tensor(0, device=label.device, dtype=torch.float32) + #Clasificator + loss_HL = self.lossFN(prediction_preserv, label_preserv) + loss_HL_multiClass = self.lossFN_2(prediction_class, label) + #Reconstruction + loss_img = self.ReconstructionLoss(new_image, orignalImg) + loss_img = loss_img.float() + + if self.Lambda_L1 is not None: + loss_l1 = 0 + for params in self.clasiffier.parameters(): # aqui + loss_l1 += torch.sum(torch.abs(params)) + loss_HL += self.Lambda_L1 * loss_l1 + + loss_l1 = 0 + for params in self.clasiffier_RL.parameters(): # aqui + loss_l1 += torch.sum(torch.abs(params)) + loss_HL_multiClass += self.Lambda_L1 * loss_l1 + + if self.Lambda_L2 > 0: + l2_reg = 0.0 + for param in self.clasiffier.parameters(): + l2_reg += torch.norm(param, 2) ** 2 + loss_HL += self.Lambda_L2 * l2_reg + + if not val_step: + self.loss_history["loss_DINO"].append(loss_Dino.item()) + self.loss_history["loss_Reconstruction"].append(loss_img.item()) + self.loss_history["loss_Binary"].append(loss_HL.item()) + self.loss_history["loss_Multiclass"].append(loss_HL_multiClass.item()) + + if len(self.loss_history["loss_DINO"]) == 5: #(self.current_epoch + 1) % self.update_interval == 0: + self.weights = self.softAdapt.get_component_weights( + torch.tensor(self.loss_history["loss_DINO"]), + torch.tensor(self.loss_history["loss_Reconstruction"]), + torch.tensor(self.loss_history["loss_Binary"]), + torch.tensor(self.loss_history["loss_Multiclass"]), + verbose=False + ) + + self.loss_history = { + 'loss_DINO': [], + 'loss_Reconstruction': [], + 'loss_Binary': [], + 'loss_Multiclass': [], + } + + loss = self.weights[0] * loss_Dino + self.weights[1] * loss_img + self.weights[2] * loss_HL + self.weights[3] * loss_HL_multiClass + + return loss, label, (predicted_labels, prediction_class), (self.weights[0] * loss_Dino, self.weights[2] * loss_HL, self.weights[1] * loss_img, new_image, self.weights[3] * loss_HL_multiClass) + + def training_step(self, batch, batch_idx): + self.softAdapt.update_history = True + loss, true_labels, predicted_labels, losses = self._shared_step(batch, False) + + boolLabels = (true_labels != 1).float() + + self.train_acc.update(predicted_labels[0], boolLabels) + self.train_acc_Multi.update(predicted_labels[1], true_labels) + + self.log("train_acc", self.train_acc, on_epoch=True, on_step=False) + self.log("train_acc_Multi", self.train_acc_Multi, on_epoch=True, on_step=False) + + self.log("train_loss", loss, on_epoch=True) + self.log("train_loss_Class", losses[4], on_epoch=True, on_step=False) + self.log("train_loss_img", losses[2], on_epoch=True, on_step=False) + self.log("train_loss_DINO", losses[0], on_epoch=True, on_step=False) + self.log("train_loss_GT", losses[1], on_epoch=True, on_step=False) + + if batch_idx % 50 == 0 and self.mean_IMG is not None and self.std_IMG is not None: + imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1) + imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :] + grid = torchvision.utils.make_grid(imgReconstruction) + self.logger.experiment.add_image('reconstructed_images', grid, self.global_step) + + return loss + + def on_train_batch_end(self, outputs, batch, batch_idx): + with torch.no_grad(): + for student_ps, teacher_ps in zip(self.student.parameters(), self.teacher_Features.parameters()): + teacher_ps.data.mul_(self.momentum_teacher) + teacher_ps.data.add_((1-self.momentum_teacher) * student_ps.detach().data) + + #self.logger.experiment.add_histogram ("Teacher_Center", self.lossFN_DINO.center) + + + def validation_step(self, batch, batch_idx): + self.softAdapt.update_history = False + loss, true_labels, predicted_labels, losses = self._shared_step(batch, True) + + boolLabels = (true_labels != 1).float() + + self.val_acc.update(predicted_labels[0], boolLabels) + self.val_acc_Multi.update(predicted_labels[1], true_labels) + + self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True) + self.log("val_acc_Multi", self.val_acc_Multi, on_epoch=True, on_step=False, prog_bar=True) + + self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True) + self.log("val_loss_Class", losses[4], on_epoch=True, on_step=False) + self.log("val_loss_img", losses[2], on_epoch=True, on_step=False) + self.log("val_loss_DINO", losses[0], on_epoch=True, on_step=False) + self.log("val_loss_GT", losses[1], on_epoch=True, on_step=False) + + def test_step(self, batch, batch_idx): + self.softAdapt.update_history = False + _, true_labels, predicted_labels, losses = self._shared_step(batch, True) + + boolLabels = (true_labels != 1).float() + + self.test_acc.update(predicted_labels[0], boolLabels) + self.test_acc_Multi.update(predicted_labels[1], true_labels) + + self.log("test_acc", self.test_acc, on_epoch=True, on_step=False) + self.log("test_acc_Multi", self.test_acc_Multi, on_epoch=True, on_step=False) + + if self.mean_IMG is not None and self.std_IMG is not None: + imgReconstruction = torch.clip(self._denormalize(losses[3]) / 255, 0, 1) + imgReconstruction = imgReconstruction[:, [2, 1, 0], :, :] + grid = torchvision.utils.make_grid(imgReconstruction) + self.logger.experiment.add_image('reconstructed_images_test', grid, self.global_step) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) #, weight_decay=self.Lambda_L2 + + return [optimizer] + class RARP_NVB_DINO_MultiTask(L.LightningModule): # Define a hook function to capture the output def _hook_fn_Student(self, module, input, output): diff --git a/RARP_NVB.py b/RARP_NVB.py index 11e9fab..b4d8406 100644 --- a/RARP_NVB.py +++ b/RARP_NVB.py @@ -522,9 +522,13 @@ ModelCAM = None if Ckpt_File is None else M.RARP_NVB_VAN_CAM.load_from_checkpoint(ckpFile, strict=False) case 14: TestModel = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT) - tempFC_ft = TestModel.fc.in_features - TestModel.fc = torch.nn.Linear(in_features=tempFC_ft, out_features=2) - Model = M.RARP_NVB_MultiClassModel(None, Model=TestModel) if Ckpt_File is None else M.RARP_NVB_MultiClassModel.load_from_checkpoint(ckpFile,Model=TestModel) + TestModel.fc = torch.nn.Linear(in_features=TestModel.fc.in_features, out_features=4) + Model = M.RARP_NVB_MultiClassModel( + None, + Model=TestModel, + Num_Classes=4, + L1=None + ) if Ckpt_File is None else M.RARP_NVB_MultiClassModel.load_from_checkpoint(ckpFile,Model=TestModel) ModelCAM = None case 15: #if OptConfig.get("lr") is None: @@ -579,7 +583,7 @@ ) if Ckpt_File is None else M.RARP_NVB_RN50_VAN_V2.load_from_checkpoint(ckpFile) ModelCAM = None case 20: - Model = M.RARP_NVB_DINO_MultiTask( + Model = M.RARP_NVB_DINO_MultiTask_v2( TypeLoss, std=std, mean=mean, @@ -933,6 +937,22 @@ ROI_Yolo=YoloModel ) cropSize = 256 + case 20: + Dataset = Loaders.RARP_DatasetCreator( + "./DataSet_big_Multiclass", + FoldSeed=505, + createFile=True, + SavePath="./DataSet_Multiclass", + Fold=5, + Num_Labels=4, + removeBlackBar=args.Remove_Blackbar, + RGBGama=args.BGR2RGB, + SegImage=args.imgSlice_pct, + Num_Img_Slices=args.Num_Slices, + SegmentClass=args.sClass, + colorSpace=args.ColorSpace + ) + cropSize = 720 Dataset.CreateFolds()