# RARP / RARP_NVB2.py
# @delAguila, 22 Nov 2024, 18 KB -- "init commit"
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import torchmetrics
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from CustomCallback import StepDropout
import Loaders
import defs
import argparse
import seaborn as sn 
import Models as M
import pandas as pd
import warnings
from ultralytics import YOLO

# Select PyTorch's 'high' float32 matmul precision mode (allows faster,
# reduced-precision float32 matrix multiplications on supported hardware).
torch.set_float32_matmul_precision('high')
# Request deterministic cuDNN algorithms; setup_seed() sets the same flag again.
torch.backends.cudnn.deterministic = True

def Calc_Eval_table_New(TrainModel:M.RARP_NVB_Model):
    """Evaluate ``TrainModel`` on the module-level ``Test_DataLoader``.

    Models that are instances of ``M.RARP_NVB_Model_test2`` are scored as a
    2-class "multiclass" problem on softmax outputs; every other model is
    scored as a "binary" problem on sigmoid outputs (after squeezing dim 1).

    Relies on the module-level globals ``device`` and ``Test_DataLoader``.

    Args:
        TrainModel: the model to evaluate; it is moved to ``device`` and put
            in eval mode.

    Returns:
        ``[accuracy, precision, recall, f1, auroc]`` as Python floats.
    """
    TrainModel.to(device)
    TrainModel.eval()

    multiclass = isinstance(TrainModel, M.RARP_NVB_Model_test2)

    Predictions = []
    Labels = []

    with torch.no_grad():
        for img, label in tqdm(iter(Test_DataLoader)):
            img = img.float().to(device)
            pred = TrainModel(img)
            if multiclass:
                Predictions.append(torch.softmax(pred, dim=1))
            else:
                Predictions.append(torch.sigmoid(pred.squeeze(1)))
            Labels.append(label.to(device))

    Predictions = torch.cat(Predictions)
    # Targets must stay integer-typed: torchmetrics' binary AUROC validation
    # rejects floating-point targets, which the previous label.float() cast
    # produced in the binary branch.
    Labels = torch.cat(Labels).long()

    print(Predictions, Labels)

    if multiclass:
        task, task_kwargs = "multiclass", {"num_classes": 2}
    else:
        task, task_kwargs = "binary", {}

    # Order here is the order of the returned list.
    metric_classes = (
        torchmetrics.Accuracy,
        torchmetrics.Precision,
        torchmetrics.Recall,
        torchmetrics.F1Score,
        torchmetrics.AUROC,
    )
    return [
        metric_cls(task, **task_kwargs).to(device)(Predictions, Labels).item()
        for metric_cls in metric_classes
    ]

def Calc_Eval_table(TrainModel:M.RARP_NVB_Model):
    """Score ``TrainModel`` on the module-level ``Test_DataLoader`` as a binary classifier.

    The model output is flattened and passed through a sigmoid before the
    torchmetrics "binary" metrics are computed. Relies on the module-level
    globals ``device`` and ``Test_DataLoader``.

    Returns:
        ``[accuracy, precision, recall, f1, auroc]``, each formatted as a
        string with four decimals.
    """
    TrainModel.to(device)
    TrainModel.eval()

    batch_probs, batch_targets = [], []
    with torch.no_grad():
        for batch, target in tqdm(iter(Test_DataLoader)):
            batch = batch.float().to(device)
            target = target.to(device)
            logits = TrainModel(batch).flatten()
            batch_probs.append(torch.sigmoid(logits))
            batch_targets.append(target)

    Predictions = torch.cat(batch_probs)
    Labels = torch.cat(batch_targets)

    print(Predictions, Labels)

    # Metrics are evaluated in the same order as before; the returned list is
    # assembled separately below.
    metric_classes = {
        "acc": torchmetrics.Accuracy,
        "precision": torchmetrics.Precision,
        "recall": torchmetrics.Recall,
        "auc": torchmetrics.AUROC,
        "f1": torchmetrics.F1Score,
    }
    values = {
        key: metric_cls('binary').to(device)(Predictions, Labels).item()
        for key, metric_cls in metric_classes.items()
    }

    return [f"{values[key]:.4f}" for key in ("acc", "precision", "recall", "f1", "auc")]

def setup_seed(seed):
    """Seed the PyTorch (CPU and all CUDA devices) and NumPy random generators.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
   
def CAM(model:M.RARP_NVB_Model, img:torch.Tensor, device):
    """Compute a class-activation map (CAM) for a single image.

    The model is called on ``img`` (with a batch dimension added) and must
    return ``(pred, feature)`` where ``feature`` is 4-D. The feature map is
    flattened to ``(channels, h*w)`` and multiplied by the first parameter of
    the model's classification head, then min-max normalised.

    Args:
        model: a CAM-capable model. Supported types are
            ``M.RARP_NVB_ResNet18_CAM`` / ``M.RARP_NVB_ResNet50_CAM`` (head at
            ``model.model.fc``) and ``M.RARP_NVB_MobileNetV2_CAM`` (head at
            ``model.model.classifier``).
        img: a single image tensor without batch dimension.
        device: device on which the forward pass is run.

    Returns:
        ``(cam_img, prob)``: the ``(h, w)`` activation map on the CPU and the
        sigmoid of the model prediction.

    Raises:
        NotImplementedError: if ``model`` is not one of the supported types
            (previously this surfaced as an unbound-local error).
    """
    with torch.no_grad():
        img = img.to(device).float().unsqueeze(0)
        pred, feature = model(img)

    _, c, h, w = feature.shape
    # NOTE(review): this reshape assumes a batch of exactly one image.
    feature = feature.reshape((c, h*w))

    if isinstance(model, (M.RARP_NVB_ResNet18_CAM, M.RARP_NVB_ResNet50_CAM)):
        head = model.model.fc
    elif isinstance(model, M.RARP_NVB_MobileNetV2_CAM):
        head = model.model.classifier
    else:
        raise NotImplementedError(
            f"CAM is not implemented for model type {type(model).__name__}"
        )

    # First parameter of the head; presumably the linear layer's weight
    # matrix -- confirm against the model definitions.
    pesos = next(iter(head.parameters())).detach()
    cam = torch.matmul(pesos, feature)

    # Min-max normalise. A constant map has a zero peak after the shift;
    # skip the division then to avoid filling the map with NaN (0/0).
    cam = cam - torch.min(cam)
    peak = torch.max(cam)
    if peak > 0:
        cam = cam / peak
    cam_img = cam.reshape(h, w).cpu()

    return cam_img, torch.sigmoid(pred)

def CAMVisualizer(img, heatmap, pred, label, mean, std, ax, row):
    """Draw one image and its CAM overlay into a grid of matplotlib axes.

    The image is converted to channel-last layout, rescaled with
    ``(std * img + mean) / 255``, clipped to [0, 1] and its channel order is
    reversed (presumably BGR -> RGB; confirm against the loader). The heatmap
    is resized to the image size and drawn semi-transparently over a second
    copy of the image.

    Sample ``row`` is placed in grid row ``row % 4``; samples 0-3 use columns
    0-1, samples 4-7 use columns 2-3, and samples above 7 use columns 4-5.
    """
    frame = img.numpy().transpose((1, 2, 0))
    height, width = frame.shape[0], frame.shape[1]
    overlay = transforms.functional.resize(heatmap.unsqueeze(0), (height, width), antialias=True)[0]

    frame = np.clip((std * frame + mean) / 255, 0, 1)
    frame = frame[..., ::-1].copy()

    # Pick the column pair this sample belongs to.
    if row > 7:
        first_col = 4
    elif row > 3:
        first_col = 2
    else:
        first_col = 0

    grid_row = ax[row % 4]
    plain_ax = grid_row[first_col + 0]
    overlay_ax = grid_row[first_col + 1]

    plain_ax.imshow(frame)
    plain_ax.axis('off')

    overlay_ax.imshow(frame)
    overlay_ax.imshow(overlay, alpha=0.5, cmap="jet")
    overlay_ax.axis('off')
    overlay_ax.set_title(f"Pred.: {pred.item():.4f}; Label: {label}")
    
    #plt.title()
    
def ShowCAM(TrainedModel:M.RARP_NVB_Model, mean, std, title=""):
    """Plot class-activation maps for samples of the module-level ``testDataset``.

    A 4x6 grid of axes is created. When the dataset has more than 12 samples,
    up to 6 randomly chosen indices from the first class block and up to 5
    from the second are shown (this assumes the samples are ordered by class
    index -- TODO confirm); otherwise every sample is shown.

    Relies on the module-level globals ``device`` and ``testDataset``.
    """
    TrainedModel.to(device)
    TrainedModel.eval()

    grid_layout = {
        "left":0,
        "bottom":0.01,
        "right":1,
        "top":0.914,
        "wspace":0,
        "hspace":0.164
    }
    fig, axis = plt.subplots(4, 6, gridspec_kw=grid_layout)

    with torch.no_grad():
        if len(testDataset) > 12:
            class_counts = np.unique(testDataset.targets, return_counts=True)[1]
            NOTNVB_Indexs = np.asarray(range(class_counts[0]))
            NVB_Indexs = np.asarray(range(class_counts[0], class_counts[0] + class_counts[1]))
            # Shuffle order matters for reproducibility under a fixed seed.
            np.random.shuffle(NOTNVB_Indexs)
            np.random.shuffle(NVB_Indexs)

            chosen = list(NOTNVB_Indexs[:6]) + list(NVB_Indexs[:5])
            # Lazy generator: each sample is loaded right before it is drawn.
            samples = (testDataset[index] for index in chosen)
        else:
            samples = tqdm(testDataset)

        for slot, (img, label) in enumerate(samples):
            cam, pred = CAM(TrainedModel, img, device)
            CAMVisualizer(img, cam, pred, label, mean, std, axis, slot)

    fig.suptitle(title)
    
def Calc_Eval(TrainModel:M.RARP_NVB_Model):
    """Evaluate ``TrainModel`` sample-by-sample on the module-level ``testDataset``.

    Prints the predictions/labels, the binary accuracy, precision, recall,
    F1 and AUROC, and the class names, then shows the confusion matrix as a
    seaborn heatmap. Relies on the module-level globals ``device``,
    ``testDataset`` and ``args``.
    """
    TrainModel.to(device)
    TrainModel.eval()

    probs, targets = [], []
    with torch.no_grad():
        for sample, target in tqdm(testDataset):
            batch = sample.to(device).float().unsqueeze(0)
            output = TrainModel(batch)[0].cpu()
            probs.append(torch.sigmoid(output))
            targets.append(target)

    Predictions = torch.cat(probs)
    Labels = torch.tensor(targets).int()

    print(Predictions, Labels)

    def binary_score(metric_cls):
        # Fresh metric object per call, evaluated on the CPU tensors above.
        return metric_cls('binary')(Predictions, Labels)

    acc = binary_score(torchmetrics.Accuracy)
    precision = binary_score(torchmetrics.Precision)
    recall = binary_score(torchmetrics.Recall)
    cm = binary_score(torchmetrics.ConfusionMatrix)
    auc = binary_score(torchmetrics.AUROC)
    f1Score = binary_score(torchmetrics.F1Score)

    report = (
        ("Val Accuracy", acc),
        ("Val Precision", precision),
        ("Val Recall", recall),
        ("F1 Score", f1Score),
        ("AUROC", auc),
    )
    for caption, value in report:
        print(f"{caption}: {value:.4f}")
    print(testDataset.classes)

    class_names = testDataset.classes
    ax = sn.heatmap(
        cm,
        cmap="Greens",
        cbar=False,
        annot=True,
        annot_kws={"size": 18},
        fmt='g',
        xticklabels=class_names,
        yticklabels=class_names,
    )
    ax.set_title(f"NVB Classifier Split #{args.Fold+1}")
    ax.set_xlabel('Predict')
    ax.set_ylabel('True')
    plt.show()
    

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--Phase", default="train", type=str, help="'train' or 'eval'")
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("--Workers", type=int, default=0)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("--Model", type=int, default=0, help="0 = ResNet18, 1 = ResNet50")
    parser.add_argument("-lv", "--Log_version", type=int)
    parser.add_argument("-le", "--Log_epoch", type=int)
    parser.add_argument("-ls", "--Log_step", type=int)
    parser.add_argument("--Remove_Blackbar", type=bool, default=True)
    parser.add_argument("--BGR2RGB", type=bool, default=False)
    parser.add_argument("--CAM", type=bool, default=False)
    parser.add_argument("-roi", "--Use_ROI_Dataset", type=int, default=0)
    parser.add_argument("-s", "--imgSlice_pct", type=float, default=None)
    parser.add_argument("-ns", "--Num_Slices", type=int, default=4)
    parser.add_argument("-wl", "--Wloss",type=bool, default=False)
    parser.add_argument("--sClass",type=int, default=None)
    parser.add_argument("-tl", "--TypeLoss", type=int, default=0)
    parser.add_argument("-cs", "--ColorSpace", type=int, default=None)

    args = parser.parse_args()
    
    if args.CAM and args.Phase == "train":
        raise Exception("Clases Activation Clases only in eval o eval_all")
    
    match args.Use_ROI_Dataset:
        case 1:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Crop",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSetCrop",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 256  
        case 0:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_main",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 720
        case 3:
            YoloModel = YOLO(model="RARP_YoloV8_ROI.pt")
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_main",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSet_YOLO",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace,
                ROI_Yolo=YoloModel
            )
            cropSize = 256
        case 2:
            Dataset = Loaders.RARP_DatasetCreator(
                "./DataSet_Crop1",
                FoldSeed=505,
                createFile=True,
                SavePath="./DataSetCrop1",
                Fold=5,
                removeBlackBar=args.Remove_Blackbar,
                RGBGama=args.BGR2RGB,
                SegImage=args.imgSlice_pct,
                Num_Img_Slices=args.Num_Slices,
                SegmentClass=args.sClass,
                colorSpace=args.ColorSpace
            )
            cropSize = 256 
        
    Dataset.CreateFolds()
    
    setup_seed(2023)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batchSize = 17 #17 #8, 32
    numWorkers = args.Workers
    checkPtCallback = ModelCheckpoint(monitor='val_acc', save_top_k=10, mode='max')
    
    traintransform = torch.nn.Sequential(
        transforms.Normalize(Dataset.mean, Dataset.std),
        transforms.Resize(cropSize, antialias=True),
        transforms.RandomHorizontalFlip(0.6),
        transforms.RandomAffine(
            degrees=(-5, 5), translate=(0, 0.05), scale=(0.9, 1.1), 
            fill=5
        ),
        transforms.RandomResizedCrop((224, 224), scale=(0.35, 1), antialias=True),
    ).to(device)

    valtransform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)

    testtransform =  torch.nn.Sequential(
        transforms.Resize(256, antialias=True),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)
    
    rootFile = Dataset.CVS_File.parent.parent/f"fold_{args.Fold}"
    
    trainDataset = torchvision.datasets.DatasetFolder(
        str (rootFile/"train"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=traintransform
    )
    
    valDataset = torchvision.datasets.DatasetFolder(
        str (rootFile/"val"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=valtransform
    )

    testDataset = torchvision.datasets.DatasetFolder(
        str (rootFile/"test"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=testtransform
    )
    
    Train_DataLoader = DataLoader(
        trainDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=True, 
        pin_memory=True
    )
    Val_DataLoader = DataLoader(
        valDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=False, 
        pin_memory=True
    )
    Test_DataLoader = DataLoader(
        testDataset, 
        batch_size=batchSize, 
        num_workers=numWorkers, 
        shuffle=False, 
        pin_memory=True
    )
    
    if args.CAM:
        testCAMDataset = torchvision.datasets.DatasetFolder(
            str (rootFile/"test"),
            loader=defs.load_file_tensor,
            extensions="npy",
            transform=torch.nn.Sequential(
                transforms.Resize(256, antialias=True),
                transforms.Normalize(Dataset.mean, Dataset.std)
            ).to(device)
        )
        
        TestCAM_DataLoader = DataLoader(
            testCAMDataset, 
            batch_size=batchSize, 
            num_workers=numWorkers, 
            shuffle=False, 
            pin_memory=True
        )

    print(f"Currtent Fold Splits {Dataset.Splits[args.Fold]}")
    print(f"Unique Values in sets")
    info = np.unique(trainDataset.targets, return_counts=True), np.unique(valDataset.targets, return_counts=True), np.unique(testDataset.targets, return_counts=True)
    print(info)
    
    neg = 0
    pos = 0
    for i in info:
        neg += i[1][0]
        pos += i[1][1]
        
    total = neg + pos
    
    InitWeight = torch.tensor([total/(neg), total/(pos)]).to(device) if args.Wloss else None
    TypeLoss = M.TypeLossFunction(args.TypeLoss)
    match args.Model:
        case 0:
            Model = M.RARP_NVB_ResNet50(InitWeight, TypeLoss)
            ModelCAM = M.RARP_NVB_ResNet50_CAM()
        case 1:
            Model = M.RARP_NVB_ResNet18(InitWeight, TypeLoss)
            ModelCAM = M.RARP_NVB_ResNet18_CAM()
        case 2:
            Model = M.RARP_NVB_MobileNetV2(InitWeight, TypeLoss)
            ModelCAM = M.RARP_NVB_MobileNetV2_CAM()
        case 3:
            Model = M.RARP_NVB_EfficientNetV2(InitWeight, TypeLoss)
            ModelCAM = M.RARP_NVB_EfficientNetV2_CAM()
        case 4:
            models = [
                M.RARP_NVB_ResNet50.load_from_checkpoint(Path("./log_ResNet50_X6/lightning_logs/version_14/checkpoints/epoch=23-step=96.ckpt"), strict=False),
                M.RARP_NVB_ResNet18.load_from_checkpoint(Path("./log_restnet18_X6/lightning_logs/version_14/checkpoints/epoch=22-step=92.ckpt"), strict=False),
                M.RARP_NVB_ResNet50.load_from_checkpoint(Path("./log_ResNet50_X6/lightning_logs/version_9/checkpoints/epoch=44-step=180.ckpt"), strict=False),
                M.RARP_NVB_ResNet18.load_from_checkpoint(Path("./log_restnet18_X6/lightning_logs/version_9/checkpoints/epoch=25-step=104.ckpt"), strict=False),
            ]
            Model = M.RARP_Ensemble(models, InitWeight, TypeLoss)
        case _:
            raise Exception("Model Not yet Implemented")
    
    NameModel = type(Model).__name__
    print(f"Model Used: {NameModel}")
    LogFileName = f"{args.Log_Name}" #-{NameModel}

    warnings.simplefilter("ignore")
    trainer = L.Trainer(
        accelerator='gpu', 
        devices=1, 
        logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
        log_every_n_steps=1, 
        #callbacks=checkPtCallback, 
        callbacks=[checkPtCallback, StepDropout(5,  base_drop_rate=0.2, gamma=0.05)], 
        max_epochs=50,
    )

    if args.Phase == "train":
        print("Train Phase")
        trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
        #trainer.callbacks
        trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")
    elif args.Phase == "eval_all":
        print("Evaluation Phase")
        rows = []
        pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
        for ckpFile in pathCkptFile.glob("*.ckpt"):
            print(ckpFile.name)
            temp = Calc_Eval_table(Model.load_from_checkpoint(ckpFile, strict=False))
            #temp = Calc_Eval_table_New(Model.load_from_checkpoint(ckpFile, strict=False))
            temp.append(ckpFile.name)
            rows.append(temp)
            if args.CAM:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    print("CAM")
                    ShowCAM(ModelCAM.load_from_checkpoint(ckpFile, strict=False), Dataset.mean, Dataset.std, ckpFile.name)
        
        df = pd.DataFrame(rows, columns=["Acc","Precision","Recall","F1","AUROC","CheckPoint"])        
        df.style.highlight_max(color="red", axis=0)
        print(df)
        plt.show()
    else:
        print("Evaluation Phase")
        trainLog = [args.Log_version, args.Log_epoch, args.Log_step] 
        pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{trainLog[0]}/checkpoints/epoch={trainLog[1]}-step={trainLog[2]}.ckpt")
        Calc_Eval(Model.load_from_checkpoint(pathCkptFile))
        if args.CAM:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                ShowCAM(ModelCAM.load_from_checkpoint(pathCkptFile, strict=False), Dataset.mean, Dataset.std, pathCkptFile.name)
                plt.show()