import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import van
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger
import Loaders
import numpy as np
import torchmetrics
import defs
import argparse
from pathlib import Path
from tqdm import tqdm
import pandas as pd
def setup_seed(seed):
    """Seed all RNGs (torch CPU/CUDA, numpy, Lightning) for reproducible runs.

    NOTE(review): ``seed_everything`` already seeds torch, CUDA, numpy and
    ``random``, so the explicit torch/numpy calls are redundant but harmless.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # workers=True also seeds DataLoader worker processes.
    seed_everything(seed, workers=True)
    # Force cuDNN to pick deterministic kernels (may cost some speed).
    torch.backends.cudnn.deterministic = True
class RARP_NVB_Classification_Head(torch.nn.Module):
def __init__(self, in_features:int, out_features:int, layer:list=[], activation_fn:torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
super().__init__(*args, **kwargs)
self.activation = activation_fn
if len (layer) == 0:
self.head = torch.nn.Linear(in_features, out_features)
else:
temp_head = []
next_input = in_features
for num in layer:
temp_head.append(torch.nn.Linear(next_input, num))
temp_head.append(self.activation)
temp_head.append(torch.nn.Dropout(0.4))
next_input = num
temp_head[-1] = torch.nn.Dropout(0.2)
temp_head.append(torch.nn.Linear(next_input, out_features))
self.head = torch.nn.Sequential(*temp_head)
del temp_head
def forward(self, x):
return self.head(x)
class RARP_VAN(L.LightningModule):
    """Binary NVB classifier: VAN-B2 image encoder + small MLP head.

    Optimised with BCE-with-logits; training loss can additionally carry an
    L1 penalty on the classifier-head parameters (weight ``lambda_L1``,
    disabled when ``None``).
    """
    def __init__(
        self,
        van_model:str = "",
        lr:float = 1e-4,
        clasiffier_layers = [],
        lambda_L1:float = 1.31E-04
    ):
        super().__init__()
        # `van_model` is a filesystem path; keep it out of checkpointed hparams.
        self.save_hyperparameters(ignore=["van_model"])
        # One metric object per stage so Lightning resets them independently.
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        if van_model:
            # Load encoder weights from a local state-dict file.
            self.van_encoder = van.van_b2(pretrained=False, num_classes=0)
            self.van_encoder.load_state_dict(torch.load(van_model))
        else:
            self.van_encoder = van.van_b2(pretrained=True, num_classes=0)
            print("pre-train ImageNet")
        self.image_emb = 512  # VAN-B2 feature width
        self.clasiffier = RARP_NVB_Classification_Head(self.image_emb, 1, clasiffier_layers, torch.nn.SiLU(True))
        self.lossFN_clasiffier = torch.nn.BCEWithLogitsLoss()

    def _calc_L1(self, params):
        # L1 norm of the given parameters, scaled by the configured weight.
        total = sum(p.abs().sum() for p in params)
        return self.hparams.lambda_L1 * total

    def forward(self, data):
        # Encode the image batch, then score the features with the head.
        features = self.van_encoder(data.float())
        return self.clasiffier(features)

    def _shared_step(self, batch, val_step:bool=False):
        """Run one batch; return (loss, float labels, sigmoid probabilities)."""
        data, label = batch
        label = label.float()
        logits = self(data).flatten()
        probs = torch.sigmoid(logits)
        loss = self.lossFN_clasiffier(logits, label)
        # The L1 penalty is a training-only term.
        if not val_step and self.hparams.lambda_L1 is not None:
            loss = loss + self._calc_L1(self.clasiffier.parameters())
        return loss, label, probs

    def training_step(self, batch, batch_idx):
        loss, targets, probs = self._shared_step(batch, False)
        self.train_acc.update(probs, targets)
        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, targets, probs = self._shared_step(batch, True)
        self.val_acc.update(probs, targets)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Loss is not reported at test time; only accuracy and F1.
        _, targets, probs = self._shared_step(batch, True)
        self.test_acc.update(probs, targets)
        self.f1ScoreTest.update(probs, targets)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=self.hparams.lr)]
class RARP_ConvNext(RARP_VAN):
    """Variant of RARP_VAN that swaps the VAN-B2 encoder for ConvNeXt-Small."""
    def __init__(self, model = "", lr = 0.0001, clasiffier_layers=[], lambda_L1 = 0.000131):
        # NOTE(review): passing "" makes the parent build (and possibly download)
        # the pretrained VAN-B2 encoder only for it to be replaced below —
        # wasteful but functionally harmless.
        super().__init__("", lr, clasiffier_layers, lambda_L1)
        if len(model) == 0:
            # ImageNet-pretrained ConvNeXt-Small; strip the final classifier
            # layer so the module emits pooled features instead of logits.
            self.van_encoder = torchvision.models.convnext_small(weights=torchvision.models.ConvNeXt_Small_Weights.DEFAULT)
            self.van_encoder.classifier[-1] = torch.nn.Identity()
        else:
            # Custom weights: the Identity swap happens before load_state_dict,
            # so `model` is presumably a state dict saved from this modified
            # architecture — confirm against how the weights were produced.
            self.van_encoder = torchvision.models.convnext_small()
            self.van_encoder.classifier[-1] = torch.nn.Identity()
            self.van_encoder.load_state_dict(torch.load(model))
        # ConvNeXt-Small features are 768-dim (vs 512 for VAN-B2), so the
        # classification head is rebuilt to match.
        self.image_emb = 768
        self.clasiffier = RARP_NVB_Classification_Head(self.image_emb, 1, clasiffier_layers, torch.nn.SiLU(True))
if __name__ == "__main__":
    # ---- command-line interface -----------------------------------------
    parser = argparse.ArgumentParser()
    # `choices` rejects unsupported values up front: previously an unknown
    # Phase/Encoder fell through the match statements and silently did
    # nothing (the old help text also advertised 'eval', which no branch
    # handles — the implemented phase is 'eval_all').
    parser.add_argument("--Phase", default="train", type=str, choices=["train", "eval_all"], help="'train' or 'eval_all'")
    parser.add_argument("--Fold", type=int, default=0)
    parser.add_argument("--Workers", type=int, default=5)
    parser.add_argument("--Log_Name", type=str, default="logs_debug", help="the name of the directory of the log chkp")
    parser.add_argument("-p", "--Pre_train", type=str, default="RARP")
    parser.add_argument("-w", "--Weigth", type=str, default="")
    parser.add_argument("-lv", "--Log_version", type=int)
    parser.add_argument("-e", "--Encoder", type=str, default="VAN", choices=["VAN", "ConvNext"])
    args = parser.parse_args()
Dataset = Loaders.RARP_DatasetCreator(
"./DataSet_Ando_All_no20Crop",
FoldSeed=505,
createFile=True,
SavePath="./DataSet_AndoAll20_crop",
Fold=5,
removeBlackBar=False,
)
Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
Dataset.CreateFolds()
setup_seed(2023)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Fold = args.Fold
InitResize=(512, 512)
batchSize = 8
numWorkers = args.Workers
MaxEpochs = 100
LogFileName = args.Log_Name
rootFile = Dataset.CVS_File.parent.parent/f"fold_{Fold}"
checkPtCallback = [
callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
]
traintransform = torch.nn.Sequential(
transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.RandomRotation(
degrees=(-15, 15),
fill=5
),
transforms.RandomResizedCrop(
224,
scale=(0.4, 1),
antialias=True,
interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.RandomApply([
transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
], 0.3),
transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)
valtransform = torch.nn.Sequential(
transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)
testtransform = torch.nn.Sequential(
transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize(Dataset.mean, Dataset.std)
).to(device)
trainDataset = torchvision.datasets.DatasetFolder(
str (rootFile/"train"),
loader = defs.load_file_tensor,
extensions = "npy",
transform = traintransform
)
valDataset = torchvision.datasets.DatasetFolder(
str (rootFile/"val"),
loader = defs.load_file_tensor,
extensions = "npy",
transform = valtransform
)
testDataset = torchvision.datasets.DatasetFolder(
str (rootFile/"test"),
loader=defs.load_file_tensor,
extensions="npy",
transform=testtransform
)
Train_DataLoader = DataLoader(
trainDataset,
batch_size=batchSize,
num_workers=numWorkers,
shuffle=True,
drop_last=True,
pin_memory=True,
persistent_workers=numWorkers>0
)
Val_DataLoader = DataLoader(
valDataset,
batch_size=batchSize,
num_workers=numWorkers,
shuffle=False,
pin_memory=True,
persistent_workers=numWorkers>0
)
Test_DataLoader = DataLoader(
testDataset,
batch_size=batchSize,
num_workers=numWorkers,
shuffle=False,
pin_memory=True,
persistent_workers=numWorkers>0
)
match(args.Phase):
case "train":
Model = None
match (args.Encoder):
case "VAN":
Model = RARP_VAN(args.Weigth if args.Pre_train == "RARP" else "", clasiffier_layers=[], lr=1e-4, lambda_L1=None)#lambda_L1=2.22e-6
case "ConvNext":
Model = RARP_ConvNext(args.Weigth if args.Pre_train == "RARP" else "", clasiffier_layers=[], lr=1e-4, lambda_L1=None)
trainer = L.Trainer(
deterministic=True,
accelerator='gpu',
devices=1,
logger=TensorBoardLogger(save_dir=f"./{LogFileName}"),
log_every_n_steps=5,
callbacks=checkPtCallback,
max_epochs=MaxEpochs,
)
print("Train Phase")
trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")
case "eval_all":
print("Evaluation Phase")
rows = []
pathCkptFile = Path(f"./{LogFileName}/lightning_logs/version_{args.Log_version}/checkpoints/")
for ckpFile in pathCkptFile.glob("*.ckpt"):
print(ckpFile.name)
Model = None
match (args.Encoder):
case "VAN":
Model = RARP_VAN.load_from_checkpoint(ckpFile)
case "ConvNext":
Model = RARP_ConvNext.load_from_checkpoint(ckpFile)
Model.to(device)
Model.eval()
Predictions = []
Labels = []
with torch.no_grad():
for data, label in tqdm(iter(Test_DataLoader)):
data = data.float().to(device)
label = label.to(device)
pred = Model(data).flatten()
Predictions.append(torch.sigmoid(pred))
Labels.append(label)
Predictions = torch.cat(Predictions)
Labels = torch.cat(Labels)
print(Predictions, Labels)
acc = torchmetrics.Accuracy('binary').to(device)(Predictions, Labels)
precision = torchmetrics.Precision('binary').to(device)(Predictions, Labels)
recall = torchmetrics.Recall('binary').to(device)(Predictions, Labels)
auc = torchmetrics.AUROC('binary').to(device)(Predictions, Labels)
f1Score = torchmetrics.F1Score('binary').to(device)(Predictions, Labels)
specificty = torchmetrics.Specificity("binary").to(device)(Predictions, Labels)
table = [
["0.5000", f"{acc.item():.4f}", f"{precision.item():.4f}", f"{recall.item():.4f}", f"{f1Score.item():.4f}", f"{auc.item():.4f}", f"{specificty.item():.4f}", ""]
]
for i in range(2):
aucCurve = torchmetrics.ROC("binary").to(device)
fpr, tpr, thhols = aucCurve(Predictions, Labels)
index = torch.argmax(tpr - fpr)
th2 = (recall + specificty - 1).item()
th2 = 0.5 if th2 <= 0 else th2
th1 = thhols[index].item() if i == 0 else th2
accY = torchmetrics.Accuracy('binary', threshold=th1).to(device)(Predictions, Labels)
precisionY = torchmetrics.Precision('binary', threshold=th1).to(device)(Predictions, Labels)
recallY = torchmetrics.Recall('binary', threshold=th1).to(device)(Predictions, Labels)
specifictyY = torchmetrics.Specificity("binary", threshold=th1).to(device)(Predictions, Labels)
f1ScoreY = torchmetrics.F1Score('binary', threshold=th1).to(device)(Predictions, Labels)
#cm2 = torchmetrics.ConfusionMatrix('binary', threshold=th1).to(device)
#cm2.update(Predictions, Labels)
#_, ax = cm2.plot()
#ax.set_title(f"NVB Classifier (th={th1:.4f})")
table.append([
f"{th1:.4f}",
f"{accY.item():.4f}",
f"{precisionY.item():.4f}",
f"{recallY.item():.4f}",
f"{f1ScoreY.item():.4f}",
f"{auc.item():.4f}",
f"{specifictyY.item():.4f}",
ckpFile.name
])
rows += table
df = pd.DataFrame(rows, columns=["Youden", "Acc","Precision","Recall","F1","AUROC","Specificity","CheckPoint"])
df.style.highlight_max(color="red", axis=0)
print(df)