# RARP / VAN_fine_Optuna.py — VAN-B2 fine-tuning with Optuna hyperparameter search.
# Must be set BEFORE albumentations is (transitively) imported below:
# disables its startup version/update check.
import os
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import van
import lightning as L
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
import Loaders
import numpy as np
import torchmetrics
import defs

import optuna
from optuna.integration import PyTorchLightningPruningCallback

def setup_seed(seed):
    """Seed every RNG in use (numpy, torch CPU/GPU, Lightning) and force
    deterministic cuDNN kernels so runs are reproducible."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    seed_everything(seed, workers=True)
    torch.backends.cudnn.deterministic = True
    
class RARP_NVB_Classification_Head(torch.nn.Module):
    """MLP classification head producing raw (pre-sigmoid) logits.

    With no hidden sizes this is a single ``Linear(in_features, out_features)``.
    Otherwise it stacks ``Linear -> activation -> Dropout(0.4)`` for each hidden
    width, replaces the dropout after the LAST hidden layer with ``Dropout(0.2)``
    (lighter regularization right before the output), then appends the final
    ``Linear`` to ``out_features``.

    Args:
        in_features: size of the input embedding.
        out_features: number of output logits.
        layer: hidden-layer widths; ``None`` or ``[]`` means a single Linear.
        activation_fn: activation module inserted after each hidden Linear
            (the SAME instance is reused at every position; fine for
            stateless activations such as ReLU/SiLU).
    """

    def __init__(self, in_features: int, out_features: int, layer=None,
                 activation_fn: torch.nn.Module = torch.nn.ReLU(), *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Bug fix: the original used a mutable default argument (layer=[]),
        # which is shared across all calls; use a None sentinel instead.
        layer = [] if layer is None else layer

        self.activation = activation_fn

        if not layer:
            self.head = torch.nn.Linear(in_features, out_features)
        else:
            modules = []
            next_input = in_features
            for width in layer:
                modules.append(torch.nn.Linear(next_input, width))
                modules.append(self.activation)
                modules.append(torch.nn.Dropout(0.4))
                next_input = width

            # Swap the trailing Dropout(0.4) for a lighter one before the output.
            modules[-1] = torch.nn.Dropout(0.2)
            modules.append(torch.nn.Linear(next_input, out_features))

            self.head = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Return raw logits for input ``x``."""
        return self.head(x)
        
class RARP_VAN(L.LightningModule):
    """LightningModule for binary classification of RARP frames.

    A VAN-B2 backbone (weights loaded from the `van_model` checkpoint,
    classifier stripped via num_classes=0) produces a 512-d embedding that a
    small MLP head maps to a single logit. Trained with BCE-with-logits plus
    an optional L1 penalty on the head's parameters (training steps only).
    """
    def __init__(
        self, 
        van_model:str,
        lr:float = 1e-4,
        clasiffier_layers = [],
        lambda_L1:float = 1.31E-04
    ):
        """
        Args:
            van_model: path to a VAN-B2 state-dict checkpoint (torch.load-able).
            lr: Adam learning rate.
            clasiffier_layers: hidden-layer widths for the classification head
                ([] -> single Linear layer).
            lambda_L1: coefficient of the L1 penalty on the head's parameters;
                None disables the penalty entirely.
        """
        super().__init__()
        
        self.lr = lr
        self.lambda_L1 = lambda_L1
        
        #self.save_hyperparameters(ignore=["van_model"])
        
        # One metric object per stage so Lightning resets/aggregates them
        # independently across epochs.
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        self.test_acc = torchmetrics.Accuracy('binary')
        self.f1ScoreTest = torchmetrics.F1Score('binary')
                
        # num_classes=0 strips the backbone's own classifier head, so the
        # encoder outputs raw embeddings.
        self.van_encoder = van.van_b2(pretrained=False, num_classes=0)
        self.van_encoder.load_state_dict(torch.load(van_model))
        self.image_emb = 512  # VAN-B2 embedding size fed to the head
                       
        self.clasiffier = RARP_NVB_Classification_Head(self.image_emb, 1, clasiffier_layers, torch.nn.SiLU(True))
        
        self.lossFN_clasiffier = torch.nn.BCEWithLogitsLoss() 
        
    def _calc_L1(self, params):
        """Return lambda_L1 * sum(|p|) over the given parameter iterable."""
        l1 = 0
        for p in params:
            l1 += torch.sum(torch.abs(p))
        return self.lambda_L1 * l1       
        
    def forward(self, data):
        """Encode `data` with the VAN backbone and return the head's raw logits."""
        data = data.float()
                
        img_features = self.van_encoder(data)
        
        pred = self.clasiffier(img_features)
        
        return pred
    
    def _shared_step(self, batch, val_step:bool=False):
        """Common step for train/val/test.

        Returns (loss, true labels, sigmoid probabilities). The L1 penalty on
        the classifier head is added only when val_step is False.
        """
        data, label = batch
        label = label.float()
        
        prediction = self(data)
        
        prediction = prediction.flatten()
        predicted_labels = torch.sigmoid(prediction)
        
        # BCEWithLogitsLoss takes the RAW logits; the sigmoid output above is
        # only used for the threshold-based metrics.
        loss = self.lossFN_clasiffier(prediction, label)
        
        if not val_step:
            loss += self._calc_L1(self.clasiffier.parameters()) if self.lambda_L1 is not None else 0
        
        return loss, label, predicted_labels
        
    def training_step(self, batch, batch_idx):
        loss, true_labels, pred_labels = self._shared_step(batch, False)
        
        self.log("train_loss", loss, on_epoch=True)
        self.train_acc.update(pred_labels, true_labels)
        # Logging the metric object itself lets Lightning compute/reset it per epoch.
        self.log("train_acc", self.train_acc, on_epoch=True, on_step=False)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        # val_step=True: no L1 term, so val_loss is the pure BCE loss.
        loss, true_labels, pred_labels = self._shared_step(batch, True)
        
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.val_acc.update(pred_labels, true_labels)
        self.log("val_acc", self.val_acc, on_epoch=True, on_step=False, prog_bar=True)
        
    def test_step(self, batch, batch_idx):
        # Loss is discarded at test time; only accuracy and F1 are reported.
        _, true_labels, predicted_labels = self._shared_step(batch, True)
        
        self.test_acc.update(predicted_labels, true_labels)
        self.f1ScoreTest.update(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc, on_epoch=True, on_step=False)
        self.log("test_f1", self.f1ScoreTest, on_epoch=True, on_step=False)
        
    def configure_optimizers(self):
        """Plain Adam over ALL parameters (backbone + head), no scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 
        
        return [optimizer]
    
if __name__ == "__main__":

    # Build the 5-fold dataset index (project helper). mean/std are
    # precomputed per-channel statistics for this dataset, assigned here so
    # CreateFolds does not recompute them.
    Dataset = Loaders.RARP_DatasetCreator(
        "./DataSet_Ando_All_no20Crop",
        FoldSeed=505,
        createFile=True,
        SavePath="./DataSet_AndoAll20_crop",
        Fold=5,
        removeBlackBar=False,
    )
    Dataset.mean, Dataset.std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    Dataset.CreateFolds()

    setup_seed(2023)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    Fold = 0                  # only fold 0 is tuned in this script
    InitResize = (512, 512)   # pre-augmentation resize for training images

    numWorkers = 5
    MaxEpochs = 100
    LogFileName = "log_X21"

    rootFile = Dataset.CVS_File.parent.parent/f"fold_{Fold}"
    # NOTE(review): this callback list is never passed to any Trainer below —
    # the Optuna objective builds its own callbacks. Kept for reference.
    checkPtCallback = [
        callbk.ModelCheckpoint(monitor='val_acc', filename="RARP-{epoch}", save_top_k=10, mode='max'),
        callbk.EarlyStopping(monitor="val_loss", mode="min", patience=5)
    ]

    # GPU-resident augmentation pipeline for training: resize, random
    # erasing/affine/blur, random 224 crop, then per-channel normalization.
    traintransform = torch.nn.Sequential(
        transforms.Resize(InitResize, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomErasing(0.8, value="random"),
        transforms.RandomAffine(degrees=(-15, 15), scale=(0.8, 1.2), fill=5),
        transforms.GaussianBlur(5),
        transforms.RandomCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)

    # Deterministic eval pipelines: resize shorter side to 256, center-crop 224.
    valtransform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)

    testtransform = torch.nn.Sequential(
        transforms.Resize(256, antialias=True, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(Dataset.mean, Dataset.std)
    ).to(device)

    # Samples are stored as .npy tensors; defs.load_file_tensor loads them.
    trainDataset = torchvision.datasets.DatasetFolder(
        str(rootFile/"train"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=traintransform
    )

    valDataset = torchvision.datasets.DatasetFolder(
        str(rootFile/"val"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=valtransform
    )

    testDataset = torchvision.datasets.DatasetFolder(
        str(rootFile/"test"),
        loader=defs.load_file_tensor,
        extensions="npy",
        transform=testtransform
    )

    def objective(trial):
        """Optuna objective: train one model with sampled (batch, lr, l1)
        and return the final epoch-level validation accuracy (maximized)."""
        batch = trial.suggest_int("batch", 8, 32, step=2)
        lr = trial.suggest_float("lr", 4.10e-6, 1e-3, log=True)
        l1 = trial.suggest_float("l1", 1e-6, 1e-3, log=True)

        Train_DataLoader = DataLoader(
            trainDataset,
            batch_size=batch,
            num_workers=numWorkers,
            shuffle=True,
            drop_last=True,          # keep batch statistics stable across steps
            pin_memory=True,
            persistent_workers=numWorkers > 0
        )
        Val_DataLoader = DataLoader(
            valDataset,
            batch_size=batch,
            num_workers=numWorkers,
            shuffle=False,
            pin_memory=True,
            persistent_workers=numWorkers > 0
        )

        Model = RARP_VAN("van_b2_teacher_98.pth", clasiffier_layers=[128, 8], lr=lr, lambda_L1=l1)

        trainer = L.Trainer(
            deterministic=True,
            accelerator='gpu',
            devices=1,
            logger=CSVLogger(save_dir=f"./{LogFileName}", name="Tune"),
            log_every_n_steps=5,
            # Pruner watches val_acc; early stopping watches val_loss.
            callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc"), callbk.EarlyStopping(monitor="val_loss", mode="min", patience=7)],
            max_epochs=MaxEpochs,
        )

        hyperParams = dict(batch=batch, lr=lr, l1=l1)
        trainer.logger.log_hyperparams(hyperParams)
        print("Train Phase")

        trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)

        return trainer.callback_metrics["val_acc"].item()

    pruner = optuna.pruners.SuccessiveHalvingPruner()  # alt: MedianPruner()
    sampler = optuna.samplers.GPSampler(seed=2023)
    study = optuna.create_study(direction="maximize", pruner=pruner, sampler=sampler)
    study.optimize(objective, n_trials=100)

    print("Number of finished trials: {}".format(len(study.trials)))
    print("Best trial:")
    trial = study.best_trial
    print(f"   Value: {trial.value}")
    # Fix: output string was misspelled ("Paramas").
    print("   Params: ")
    for key, val in trial.params.items():
        print(f"      {key}: {val}")

    #trainer.test(Model, dataloaders=Test_DataLoader, ckpt_path="best")