# RARP / RARP_NVB_eval.py
# Initial commit by delAguila, 22 Nov 2024 (~10 KB).
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torchmetrics
import torchmetrics.classification
import numpy as np
import yaml
import Models as M
from pathlib import Path
import Loaders
import defs
import matplotlib.pyplot as plt
import seaborn as sn 
from ultralytics import YOLO
import argparse

def setup_seed(seed):
    """Seed NumPy and torch (CPU and every CUDA device) for reproducibility.

    Also forces cuDNN into deterministic mode so repeated runs with the
    same seed produce identical results.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
def CalcAgreement(x: torch.Tensor) -> torch.Tensor:
    """Agreement-weighted mean of a 1-D tensor of model probabilities.

    For each element, the total absolute distance to every other element is
    computed.  Elements whose total distance is at most the mean distance are
    treated as "in agreement" and averaged with equal weight; outliers are
    discarded.

    Args:
        x: 1-D tensor of per-model scores for a single sample.

    Returns:
        0-dim tensor holding the mean of the agreeing scores.
    """
    # (n, n) matrix of pairwise absolute differences, summed per row.
    deltas = torch.sum(torch.abs(x.unsqueeze(1) - x), 1)
    deltasMean = deltas.mean()

    # Keep only elements at or below the mean total distance.  At least one
    # element always qualifies (the minimum is never above the mean), so the
    # normalisation cannot divide by zero.
    factor = (deltas <= deltasMean).float()
    factor = factor / factor.count_nonzero()

    return torch.dot(factor, x)
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # -f: which cross-validation fold's config file to load.
    parser.add_argument("-f", "--Fold", type=int, default=0)
    # -m: fusion mode; 1 = agreement-weighted fusion, anything else = mean.
    parser.add_argument("-m", "--Mode", type=int, default=0)

    args = parser.parse_args()

    Models = []

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Fold {args.Fold}")
    with open(f"eval-EFold{args.Fold}.yaml") as file:
        configFile = yaml.load(file, Loader=yaml.FullLoader)

    # Map each config key to its Lightning model class.  Replaces four
    # copy-pasted match-cases that differed only in the class constructed.
    ckpt_classes = {
        "ResNet50_ckpt": M.RARP_NVB_ResNet50,
        "ResNet18_ckpt": M.RARP_NVB_ResNet18,
        "MovilNetV2_ckpt": M.RARP_NVB_MobileNetV2,
        "EfficientNetV2_ckpt": M.RARP_NVB_EfficientNetV2,
    }
    for models in configFile["models"]:
        if models not in ckpt_classes:
            raise Exception("Model Not yet Implemented")
        if configFile["models"][models] is not None:
            for i, pathckpt in enumerate(configFile["models"][models]):
                # Load every fold's checkpoint when ensembling; otherwise
                # only the checkpoint whose index matches the selected fold.
                if configFile["Fold_Ensamble"] or configFile['Fold_Num'] is None or configFile['Fold_Num'] == i:
                    Models.append(ckpt_classes[models].load_from_checkpoint(Path(pathckpt), strict=False).to(device).eval())
        
    match configFile["dataset_type"]:        
        case "full_size":
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_main",
                FoldSeed=505,
                createFile=True,
                Fold=5,
                SavePath="./DataSet_Eval",
                removeBlackBar=True
            )
        case "Manual_ROI":
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Crop1",
                FoldSeed=505,
                createFile=True,
                Fold=5,
                SavePath="./DataSet_Crop1_Eval",
                removeBlackBar=False
            )
        case "Manual_ROIwD":
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Crop",
                FoldSeed=505,
                createFile=True,
                Fold=5,
                SavePath="./DataSet_Crop_Eval",
                removeBlackBar=False
            )
        case "YOLO_ROI":
            YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_main",
                FoldSeed=505,
                createFile=True,
                Fold=5,
                SavePath="./DataSet_YOLO_Eval",
                removeBlackBar=True,
                SegmentClass=1,
                SegImage=0.75,
                Num_Img_Slices=2
            )
        case "YOLO_ROI_BIGDataset":
            YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_big",
                FoldSeed=505,
                createFile=True,
                Fold=5,
                SavePath="./DataSet_YOLO_Eval_BD",
                removeBlackBar=True,
                SegmentClass=0,
                SegImage=0.75,
                Num_Img_Slices=1
            )
        case "small_balaced":
            YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_smallBalaced",
                FoldSeed=505,
                createFile=True,
                Fold=5,
                SavePath="./DataSet_YOLO_SB_Eval",
                removeBlackBar=True
            )
        case "small_Ando":
            YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Ando",
                FoldSeed=505,
                createFile=True,
                Fold=5,
                SavePath="./DataSet_YOLO_Ando_Eval",
                removeBlackBar=True
            )
        case "small_balaced_full_size":
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_smallBalaced",
                FoldSeed=505,
                createFile=True,
                Fold=5,
                SavePath="./DataSet_YOLO_SBFS_Eval",
                removeBlackBar=True
            )
            
    print("Evaluation Phase")
    # Fold mode: materialise the CV folds and evaluate the "test" split of the
    # selected fold; otherwise evaluate the whole class-sorted dataset.
    if configFile['Fold_Num'] is not None:
        Dataset.CreateFolds()
        rootFile = Dataset.CVS_File.parent.parent/f"fold_{configFile['Fold_Num']}"
        # These dataset types need the YOLO detector to crop the ROI first.
        if (configFile["dataset_type"] in ["YOLO_ROI", "small_balaced", "YOLO_ROI_BIGDataset", "small_Ando"]):
            Dataset.ExtractROI_YOLO(YoloModel, configFile["YOLO_Accuracy_min_ROI"])
    else:
        Dataset.CreateClases()
        rootFile = Dataset.CVS_File.parent.parent/"dataset"

    # Hard-coded per-channel normalisation statistics on a 0-255 scale —
    # presumably precomputed on the training split; verify they match the
    # dataset actually being evaluated.
    Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])

    setup_seed(2023)



    # Validation-time preprocessing only: resize to the 224x224 input used by
    # the backbones and normalise with the statistics above.
    valtransform = torch.nn.Sequential(
        transforms.Resize((224, 224), antialias=True),
        #transforms.Resize(256, antialias=True),
        #transforms.CenterCrop(224),
        #transforms.RandomHorizontalFlip(0.7),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)

    # Samples are stored as .npy tensors; class names come from the folder
    # layout ("test" subfolder in fold mode, dataset root otherwise).
    valDataset = torchvision.datasets.DatasetFolder(
        str (rootFile/("test" if configFile['Fold_Num'] is not None else "")),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=valtransform
    )

    Val_DataLoader = DataLoader(
        valDataset, 
        batch_size=16, 
        num_workers=0, 
        shuffle=True, 
        pin_memory=True
    )

            
    Predictions = []
    Labels = []
    res = []

    with torch.no_grad():
        for data, label in iter(Val_DataLoader):
            data = data.float().to(device)
            label = label.to(device)
            prob = [torch.sigmoid(m(data)) for m in Models]
            prob = torch.cat(prob, dim=1)
            print (prob, label)
            prob = torch.tensor([CalcAgreement(d.squeeze()) for d in prob.split(1, 0)]).to(device) if args.Mode == 1 else prob.mean(dim=1)
            #prob = 
            Predictions.append(prob)
            Labels.append(label)
            
    Predictions = torch.cat(Predictions)
    Labels = torch.cat(Labels)

    print(Predictions, Labels)

    # Metrics at the default 0.5 decision threshold.
    acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
    auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
    cm = torchmetrics.ConfusionMatrix('binary').to(device)
    cm.update(Predictions, Labels)
    _, ax = cm.plot()
    ax.set_title(f"NVB Classifier (th=0.5)")
    ax.set_xticklabels(valDataset.classes)
    ax.set_yticklabels(valDataset.classes)
    # NOTE(review): confirm which axis torchmetrics puts the true labels on —
    # 'Ground Truth' on x may be mislabelled.
    ax.set_xlabel('Ground Truth')



    print(f"Val Accuracy: {acc:.4f}")
    print(f"Val Precision: {precision:.4f}")
    print(f"Val Recall: {recall:.4f}")
    print(f"F1 Score: {f1Score:.4f}")
    print(f"AUROC: {auc:.4f}")
    print(f"Val Specificity: {specificty:.4f}")
    print(valDataset.classes)

    #ax = sn.heatmap(cm.cpu(), cmap="Greens", cbar=False, annot=True, annot_kws={"size": 18}, fmt='g', xticklabels=valDataset.classes, yticklabels=valDataset.classes)
    #ax.set_title(f"NVB Classifier")  
    #ax.set_xlabel('Predict')  
    #ax.set_ylabel('True')  

    # ROC curve and Youden-index-based threshold selection.
    aucCurve = torchmetrics.ROC("binary").to(device)
    fpr, tpr, thhols = aucCurve(Predictions, Labels)
    # Index of the ROC point maximising Youden's J = TPR - FPR.
    index = torch.argmax(tpr - fpr)
    print(f"False-Positive Rate: {fpr}")
    print(f"True-Positive Rate: {tpr}")
    print(tpr-fpr)
    print(index)
    print(thhols)
    # "ROC": use the threshold at the J-maximising ROC point.  Otherwise the
    # J statistic itself (recall + specificity - 1, from the 0.5-threshold
    # metrics above) is used as the cut-off — NOTE(review): J is a quality
    # score, not a probability threshold; confirm this branch is intentional.
    th1 = thhols[index].item() if configFile['Youden-Index'] == "ROC" else (recall + specificty - 1).item() 
    _, ax = aucCurve.plot()
    ax.plot([0,1], linestyle='--')
    #ax.plot(torch.max(tpr - fpr).cpu(), torch.max(tpr - fpr).cpu(), "bo", markersize=5)
    #ax.plot(th1, th1, "ro", markersize=5)
    ax.set_title(f"ROC (AUROC={auc:.4f})") 

    # Re-compute all thresholded metrics at the adjusted threshold th1
    # (AUROC is threshold-free, so it is reused from above).
    print(f"Metris ajusted new threshold {th1:.4f}")
    acc = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
    precision = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
    recall = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
    specificty = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
    f1Score = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
    cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
    cm2.update(Predictions, Labels)
    _, ax = cm2.plot()
    ax.set_title(f"NVB Classifier (th={th1:.4f})")
    ax.set_xticklabels(valDataset.classes)
    ax.set_yticklabels(valDataset.classes)
    ax.set_xlabel('Ground Truth')

    print(f"Val Accuracy: {acc:.4f}")
    print(f"Val Precision: {precision:.4f}")
    print(f"Val Recall: {recall:.4f}")
    print(f"F1 Score: {f1Score:.4f}")
    print(f"AUROC: {auc:.4f}")
    print(f"Val Specificity: {specificty:.4f}")


    # Display all accumulated figures (confusion matrices and ROC curve).
    plt.show()