import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torchmetrics
import torchmetrics.classification
import numpy as np
import yaml
import Models as M
from pathlib import Path
import Loaders
import defs
import matplotlib.pyplot as plt
import seaborn as sn
from ultralytics import YOLO
import argparse
def setup_seed(seed):
    """Seed every random number generator the evaluation may touch.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and all CUDA
    devices). It also asks cuDNN for deterministic kernels and disables its
    autotuner, so repeated runs pick the same algorithms.

    Args:
        seed: Integer seed passed to every generator.
    """
    import random  # local import: the file's import block does not include ``random``

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The autotuner can choose different kernels run-to-run; keep it off.
    torch.backends.cudnn.benchmark = False
def CalcAgreement(x: torch.Tensor) -> torch.Tensor:
    """Average the scores that agree most with the rest of the group.

    For every score, the summed absolute distance to all scores is computed.
    Scores whose summed distance is at most the mean summed distance are kept
    and averaged with equal weight; the others are ignored.

    In this script it receives one sample's per-model probabilities.

    Args:
        x: Scores of any shape; the tensor is flattened first. A 0-d tensor,
            e.g. from a single-model ensemble, is therefore accepted.

    Returns:
        0-d tensor with the mean of the retained scores, in ``x``'s dtype.
    """
    scores = x.reshape(-1)
    # deltas[i] = sum_j |scores[i] - scores[j]|
    deltas = (scores.unsqueeze(1) - scores).abs().sum(dim=1)
    # The smallest delta is always <= the mean, so at least one score is kept.
    # Match the input dtype so torch.dot also works for non-float32 inputs.
    keep = (deltas <= deltas.mean()).to(scores.dtype)
    return torch.dot(keep / keep.sum(), scores)
# ---------------------------------------------------------------------------
# Evaluation entry point: loads the checkpoints listed in eval-EFold<Fold>.yaml,
# scores the evaluation split with the (ensembled) models, and reports binary
# metrics at th=0.5 and at a re-selected threshold.
# ---------------------------------------------------------------------------

# RARP_DatasetCreator arguments shared by every dataset_type.
_DATASET_COMMON_KWARGS = {"FoldSeed": 505, "createFile": True, "Fold": 5}

# dataset_type -> (source folder, dataset-specific RARP_DatasetCreator kwargs).
_DATASET_SPECS = {
    "full_size": ("./DataSet_main", {"SavePath": "./DataSet_Eval", "removeBlackBar": True}),
    "Manual_ROI": ("./DataSet_Crop1", {"SavePath": "./DataSet_Crop1_Eval", "removeBlackBar": False}),
    "Manual_ROIwD": ("./DataSet_Crop", {"SavePath": "./DataSet_Crop_Eval", "removeBlackBar": False}),
    "YOLO_ROI": (
        "./DataSet_main",
        {
            "SavePath": "./DataSet_YOLO_Eval",
            "removeBlackBar": True,
            "SegmentClass": 1,
            "SegImage": 0.75,
            "Num_Img_Slices": 2,
        },
    ),
    "YOLO_ROI_BIGDataset": (
        "./DataSet_big",
        {
            "SavePath": "./DataSet_YOLO_Eval_BD",
            "removeBlackBar": True,
            "SegmentClass": 0,
            "SegImage": 0.75,
            "Num_Img_Slices": 1,
        },
    ),
    "small_balaced": ("./DataSet_smallBalaced", {"SavePath": "./DataSet_YOLO_SB_Eval", "removeBlackBar": True}),
    "small_Ando": ("./DataSet_Ando", {"SavePath": "./DataSet_YOLO_Ando_Eval", "removeBlackBar": True}),
    "small_balaced_full_size": (
        "./DataSet_smallBalaced",
        {"SavePath": "./DataSet_YOLO_SBFS_Eval", "removeBlackBar": True},
    ),
}

# dataset_type values that load the YOLO detector and, per fold, extract ROIs with it.
_YOLO_DATASETS = frozenset({"YOLO_ROI", "YOLO_ROI_BIGDataset", "small_balaced", "small_Ando"})

# Mean / std handed to transforms.Normalize for every evaluation sample.
_EVAL_MEAN = [30.38144216, 42.03988769, 97.8896116]
_EVAL_STD = [40.63141752, 44.26910074, 50.29294373]


def _load_models(config, device):
    """Load every checkpoint selected by *config* onto *device* in eval mode.

    ``config["models"]`` maps a checkpoint kind (e.g. ``"ResNet50_ckpt"``) to a
    list of checkpoint paths, or ``None``. The i-th checkpoint of a kind is used
    in any of these cases:

    - ``Fold_Ensamble`` is truthy;
    - ``Fold_Num`` is ``None``;
    - ``Fold_Num == i``.

    Raises:
        ValueError: for an unknown checkpoint kind, or if nothing was selected.
    """
    model_classes = {
        "ResNet50_ckpt": M.RARP_NVB_ResNet50,
        "ResNet18_ckpt": M.RARP_NVB_ResNet18,
        "MovilNetV2_ckpt": M.RARP_NVB_MobileNetV2,
        "EfficientNetV2_ckpt": M.RARP_NVB_EfficientNetV2,
    }
    models = []
    for kind, checkpoints in config["models"].items():
        if kind not in model_classes:
            raise ValueError(f"Model Not yet Implemented: {kind}")
        if checkpoints is None:
            continue
        for i, ckpt_path in enumerate(checkpoints):
            if config["Fold_Ensamble"] or config["Fold_Num"] is None or config["Fold_Num"] == i:
                model = model_classes[kind].load_from_checkpoint(Path(ckpt_path), strict=False)
                models.append(model.to(device).eval())
    if not models:
        raise ValueError("No model checkpoint was selected by the evaluation config")
    return models


def _build_dataset(config):
    """Create the RARP_DatasetCreator for ``config["dataset_type"]``.

    Returns:
        ``(dataset, yolo_model)``. ``yolo_model`` is loaded from
        ``RARP_YoloV8_ROI.pt`` for the types in ``_YOLO_DATASETS``; otherwise
        it is ``None``.

    Raises:
        ValueError: if ``dataset_type`` is not one of ``_DATASET_SPECS``.
    """
    dataset_type = config["dataset_type"]
    if dataset_type not in _DATASET_SPECS:
        raise ValueError(
            f"Unknown dataset_type {dataset_type!r}; expected one of {sorted(_DATASET_SPECS)}"
        )
    source_dir, specific_kwargs = _DATASET_SPECS[dataset_type]
    yolo_model = YOLO(model="RARP_YoloV8_ROI.pt") if dataset_type in _YOLO_DATASETS else None
    dataset = Loaders.RARP_DatasetCreator(source_dir, **_DATASET_COMMON_KWARGS, **specific_kwargs)
    return dataset, yolo_model


def _resolve_eval_dir(dataset, yolo_model, config):
    """Prepare the evaluation split on disk and return the folder to read.

    If ``Fold_Num`` is set: create the folds, extract YOLO ROIs when
    *yolo_model* is given, and return ``fold_<Fold_Num>/test``.
    Otherwise: call ``CreateClases`` and return the ``dataset`` folder.
    """
    fold_num = config["Fold_Num"]
    if fold_num is None:
        dataset.CreateClases()
        return dataset.CVS_File.parent.parent / "dataset"
    dataset.CreateFolds()
    fold_dir = dataset.CVS_File.parent.parent / f"fold_{fold_num}"
    if yolo_model is not None:
        dataset.ExtractROI_YOLO(yolo_model, config["YOLO_Accuracy_min_ROI"])
    return fold_dir / "test"


def _predict(models, loader, device, mode):
    """Return ``(probabilities, labels)`` for every sample yielded by *loader*.

    Each model's output is passed through a sigmoid. The per-model results are
    concatenated along dim 1 and reduced per sample: with CalcAgreement when
    ``mode == 1``, otherwise with a plain mean.

    Raises:
        ValueError: if *loader* yields no batches.
    """
    predictions, labels = [], []
    with torch.no_grad():
        for data, label in loader:
            data = data.float().to(device)
            label = label.to(device)
            prob = torch.cat([torch.sigmoid(m(data)) for m in models], dim=1)
            print(prob, label)
            if mode == 1:
                # Iterating rows keeps each one 1-D even with a single model;
                # squeeze() would have produced a 0-d tensor there.
                prob = torch.stack([CalcAgreement(row) for row in prob])
            else:
                prob = prob.mean(dim=1)
            predictions.append(prob)
            labels.append(label)
    if not predictions:
        raise ValueError("The evaluation dataset is empty")
    return torch.cat(predictions), torch.cat(labels)


def _binary_metrics(predictions, labels, device, threshold=0.5):
    """Return accuracy/precision/recall/specificity/F1 computed at *threshold*."""

    def score(metric_cls):
        return metric_cls("binary", threshold=threshold).to(device)(predictions, labels)

    return {
        "accuracy": score(torchmetrics.Accuracy),
        "precision": score(torchmetrics.Precision),
        "recall": score(torchmetrics.Recall),
        "specificity": score(torchmetrics.Specificity),
        "f1": score(torchmetrics.F1Score),
    }


def _print_metrics(metrics, auc):
    """Print the metrics produced by _binary_metrics plus the AUROC."""
    print(f"Val Accuracy: {metrics['accuracy']:.4f}")
    print(f"Val Precision: {metrics['precision']:.4f}")
    print(f"Val Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    print(f"AUROC: {auc:.4f}")
    print(f"Val Specificity: {metrics['specificity']:.4f}")


def _plot_confusion_matrix(predictions, labels, device, classes, threshold, title):
    """Draw a binary confusion matrix at *threshold* with class-name tick labels."""
    cm = torchmetrics.ConfusionMatrix("binary", threshold=threshold).to(device)
    cm.update(predictions, labels)
    _, ax = cm.plot()
    ax.set_title(title)
    ax.set_xticklabels(classes)
    ax.set_yticklabels(classes)
    # torchmetrics draws predictions on x and true labels on y.
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Ground Truth")


def _select_threshold(config, roc_thresholds, best_index, metrics_at_half):
    """Pick the decision threshold for the second metrics pass.

    ``config["Youden-Index"] == "ROC"``: the ROC threshold at *best_index*.
    Otherwise: recall + specificity - 1, both measured at th=0.5.
    """
    if config["Youden-Index"] == "ROC":
        return roc_thresholds[best_index].item()
    # NOTE(review): recall + specificity - 1 is Youden's J statistic (range
    # [-1, 1]), not a probability cut-off. torchmetrics expects thresholds in
    # [0, 1], so a negative J will be rejected. Kept because this mode is
    # selected by config; confirm it is intended.
    return (metrics_at_half["recall"] + metrics_at_half["specificity"] - 1).item()


def main():
    """Parse CLI arguments, run the evaluation and show the figures."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--Fold", type=int, default=0)
    parser.add_argument("-m", "--Mode", type=int, default=0)
    args = parser.parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Fold {args.Fold}")
    with open(f"eval-EFold{args.Fold}.yaml", encoding="utf-8") as file:
        # safe_load: the config needs only plain YAML types, never Python objects.
        config = yaml.safe_load(file)

    models = _load_models(config, device)
    dataset, yolo_model = _build_dataset(config)
    print("Evaluation Phase")
    eval_dir = _resolve_eval_dir(dataset, yolo_model, config)
    dataset.mean, dataset.std = _EVAL_MEAN, _EVAL_STD
    setup_seed(2023)

    val_transform = torch.nn.Sequential(
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(dataset.mean, dataset.std),
    ).to(device)
    val_dataset = torchvision.datasets.DatasetFolder(
        str(eval_dir),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=val_transform,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        num_workers=0,
        shuffle=False,  # metrics do not depend on order; a fixed order eases debugging
        pin_memory=device.type == "cuda",  # pinning only helps host-to-GPU copies
    )

    predictions, labels = _predict(models, val_loader, device, args.Mode)
    print(predictions, labels)

    # Pass 1: default threshold of 0.5.
    auc = torchmetrics.AUROC("binary").to(device)(predictions, labels)
    metrics = _binary_metrics(predictions, labels, device)
    _plot_confusion_matrix(predictions, labels, device, val_dataset.classes, 0.5, "NVB Classifier (th=0.5)")
    _print_metrics(metrics, auc)
    print(val_dataset.classes)

    # ROC curve; its best tpr - fpr point may supply the new threshold.
    roc = torchmetrics.ROC("binary").to(device)
    fpr, tpr, thresholds = roc(predictions, labels)
    youden = tpr - fpr
    best_index = torch.argmax(youden)
    print(f"False-Positive Rate: {fpr}")
    print(f"True-Positive Rate: {tpr}")
    print(youden)
    print(best_index)
    print(thresholds)
    threshold = _select_threshold(config, thresholds, best_index, metrics)
    _, ax = roc.plot()
    ax.plot([0, 1], [0, 1], linestyle="--")  # chance-level diagonal
    ax.set_title(f"ROC (AUROC={auc:.4f})")
    print(f"Metris ajusted new threshold {threshold:.4f}")

    # Pass 2: re-selected threshold.
    tuned = _binary_metrics(predictions, labels, device, threshold=threshold)
    _plot_confusion_matrix(
        predictions, labels, device, val_dataset.classes, threshold, f"NVB Classifier (th={threshold:.4f})"
    )
    _print_metrics(tuned, auc)
    plt.show()


if __name__ == "__main__":
    main()