import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import van
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
import Loaders
import numpy as np
import torchmetrics
import defs
import optuna
from optuna.integration import PyTorchLightningPruningCallback
def setup_seed(seed):
    """Seed every RNG source used by training for reproducibility.

    Seeds torch (CPU + all CUDA devices), numpy, and Lightning's global
    seed (which also covers dataloader workers), then forces deterministic
    cuDNN kernels (slower, but repeatable runs).
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Lightning re-seeds python/numpy/torch and propagates to DataLoader workers.
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
class RARP_NVB_Classification_Head(torch.nn.Module):
    """MLP classification head producing raw logits.

    With no hidden layers this is a single ``Linear(in_features, out_features)``.
    Otherwise each hidden width in ``layer`` contributes
    ``Linear -> activation -> Dropout``; the dropout directly before the output
    projection is lowered from 0.4 to 0.2.

    Args:
        in_features: dimensionality of the input embedding.
        out_features: number of output logits.
        layer: hidden-layer widths; ``None`` (or ``[]``) means a single linear layer.
        activation_fn: activation module inserted after each hidden linear layer.
    """

    def __init__(self, in_features: int, out_features: int, layer=None,
                 activation_fn: torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fix: original used a mutable default argument (layer=[]); use a None
        # sentinel so instances never share the default list object.
        layer = [] if layer is None else layer
        self.activation = activation_fn
        if not layer:
            self.head = torch.nn.Linear(in_features, out_features)
        else:
            modules = []
            next_input = in_features
            for width in layer:
                modules.append(torch.nn.Linear(next_input, width))
                # The same activation module instance is reused at every depth;
                # that is fine for stateless activations such as ReLU/SiLU.
                modules.append(self.activation)
                modules.append(torch.nn.Dropout(0.4))
                next_input = width
            # Replace the last hidden dropout with a lighter one (0.2)
            # just before the output projection.
            modules[-1] = torch.nn.Dropout(0.2)
            modules.append(torch.nn.Linear(next_input, out_features))
            self.head = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits for input embeddings ``x``."""
        return self.head(x)
class RARP_VAN(L.LightningModule):
    """LightningModule: frozen-architecture VAN-B2 image encoder plus a binary
    classification head, trained with BCE-with-logits and an optional L1 penalty
    on the head's parameters.

    Args:
        van_model: path to a state-dict checkpoint for the VAN-B2 encoder.
        lr: Adam learning rate.
        clasiffier_layers: hidden widths for the classification head
            (``None``/``[]`` -> single linear layer).
        lambda_L1: weight of the L1 penalty applied to the classifier head
            during training; ``None`` disables the penalty.
    """

    def __init__(
        self,
        van_model: str,
        lr: float = 1e-4,
        clasiffier_layers=None,
        lambda_L1: float = 1.31E-04
    ):
        super().__init__()
        # Fix: original used a mutable default argument (clasiffier_layers=[]);
        # use a None sentinel instead. Behavior for all callers is unchanged.
        clasiffier_layers = [] if clasiffier_layers is None else clasiffier_layers
        self.lr = lr
        self.lambda_L1 = lambda_L1
        #self.save_hyperparameters(ignore=["van_model"])
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
        # num_classes=0 strips the encoder's classifier -> feature extractor.
        self.van_encoder = van.van_b2(pretrained=False, num_classes=0)
        # NOTE(review): torch.load on a trusted local checkpoint; consider
        # weights_only=True on newer torch versions -- confirm compatibility.
        self.van_encoder.load_state_dict(torch.load(van_model))
        self.image_emb = 512  # VAN-B2 output embedding size
        self.clasiffier = RARP_NVB_Classification_Head(self.image_emb, 1, clasiffier_layers, torch.nn.SiLU(True))
        self.lossFN_clasiffier = torch.nn.BCEWithLogitsLoss()

    def _calc_L1(self, params):
        """Return lambda_L1 * sum(|p|) over the given parameters."""
        l1 = 0
        for p in params:
            l1 += torch.sum(torch.abs(p))
        return self.lambda_L1 * l1

    def forward(self, data):
        """Encode images and return raw (pre-sigmoid) classification logits."""
        data = data.float()
        img_features = self.van_encoder(data)
        pred = self.clasiffier(img_features)
        return pred

    def _shared_step(self, batch, val_step: bool = False):
        """Compute loss, labels, and sigmoid probabilities for one batch.

        The L1 penalty on the classifier head is added only for training
        steps (val_step=False) and only when lambda_L1 is set.
        """
        data, label = batch
        label = label.float()
        prediction = self(data)
        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)
        # BCEWithLogitsLoss consumes raw logits, not the sigmoid output.
        loss = self.lossFN_clasiffier(prediction, label)
        if not val_step and self.lambda_L1 is not None:
            loss = loss + self._calc_L1(self.clasiffier.parameters())
        return loss, label, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, pred_labels = self._shared_step(batch, False)
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(pred_labels, true_labels)
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, pred_labels = self._shared_step(batch, True)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(pred_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Loss is intentionally ignored at test time; only metrics are logged.
        _, true_labels, predicted_labels = self._shared_step(batch, True)
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Plain Adam over all parameters (encoder + head)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return [optimizer]
if __name__ == "__main__":
    # Build the 5-fold cross-validation index from the raw dataset directory.
    Dataset = Loaders.RARP_DatasetCreator(
        "./DataSet_Ando_All_no20Crop",
        FoldSeed=505,
        createFile=True,
        SavePath="./DataSet_AndoAll20_crop",
        Fold=5,
        removeBlackBar=False,
    )
    # Precomputed per-channel mean/std (0-255 pixel scale) -- presumably
    # computed offline over the training split; TODO confirm provenance.
    Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    Dataset.CreateFolds()
    setup_seed(2023)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # --- experiment configuration ---
    Fold = 0                 # which CV fold to tune on
    InitResize = (512, 512)  # resize applied before train-time augmentation
    numWorkers = 5
    MaxEpochs = 100
    LogFileName = "log_X21"
    rootFile = Dataset.CVS_File.parent.parent / f"fold_{Fold}"
    # NOTE(review): this callback list is not used by the Optuna objective
    # below (it builds its own); presumably kept for the commented-out
    # final test run at the bottom of the file.
    checkPtCallback = [
        callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
        callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ]
    # Train-time augmentation pipeline, moved to the GPU as an nn.Sequential.
    traintransform = torch.nn.Sequential(
        transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomErasing(0.8, value="random"),
        transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=5),
        transforms.GaussianBlur(5),
        transforms.RandomCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)
    # Deterministic eval pipelines: resize -> center crop -> normalize.
    valtransform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)
    testtransform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)
    # DatasetFolder over per-class directories of pre-extracted .npy files,
    # loaded via the project helper defs.load_file_tensor.
    trainDataset = torchvision.datasets.DatasetFolder(
        str(rootFile / "train"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=traintransform
    )
    valDataset = torchvision.datasets.DatasetFolder(
        str(rootFile / "val"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=valtransform
    )
    testDataset = torchvision.datasets.DatasetFolder(
        str(rootFile / "test"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=testtransform
    )
def objective(trial):
batch = trial.suggest_int("batch", 8, 32, step=2)
lr = trial.suggest_float("lr", 4.10e-6, 1e-3, log=True)
l1 = trial.suggest_float("l1", 1e-6, 1e-3, log=True)
Train_DataLoader = DataLoader(
trainDataset,
batch_size=batch,
num_workers=numWorkers,
shuffle=True,
drop_last=True,
pin_memory=True,
persistent_workers=numWorkers>0
)
Val_DataLoader = DataLoader(
valDataset,
batch_size=batch,
num_workers=numWorkers,
shuffle=False,
pin_memory=True,
persistent_workers=numWorkers>0
)
Model = RARP_VAN("van_b2_teacher_98.pth", clasiffier_layers=[128, 8], lr=lr, lambda_L1=l1)
trainer = L.Trainer(
deterministic=True,
accelerator='gpu',
devices=1,
logger=CSVLogger(save_dir=f"./{LogFileName}", name="Tune"),
log_every_n_steps=5,
callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc"), callbk.EarlyStopping(monitor="val_loss", mode="min", patience=7)],
max_epochs=MaxEpochs,
)
hyperParams = dict(batch=batch, lr=lr, l1=l1)
trainer.logger.log_hyperparams(hyperParams)
print("Train Phase")
trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)
return trainer.callback_metrics["val_acc"].item()
pruner = optuna.pruners.SuccessiveHalvingPruner()#MedianPruner()
sampler = optuna.samplers.GPSampler(seed=2023)
study = optuna.create_study(direction="maximize", pruner=pruner, sampler=sampler)
study.optimize(objective, n_trials=100)
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:")
trial = study.best_trial
print(f" Value: {trial.value}")
print(f" Paramas: ")
for key, val in trial.params.items():
print(f" {key}: {val}")
#trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")