# RARP / test.py — U-Net segmentation training script (RARP NVB dataset).
# Source note (from web view): @delAguila, 20 May, "13 KB Video Extraf frame".
import torch
import torchmetrics.classification
import torchvision
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import Loaders
import defs
import lightning as L
import van
import copy

class UNet_VAN_1(torch.nn.Module):
    """U-Net-style segmentation head on a pretrained VAN-B1 encoder.

    Skip features are harvested with forward hooks on the last block of each
    encoder stage; the decoder upsamples, concatenates the matching skip,
    and fuses with a double-conv block, then interpolates back to the input
    resolution for the 1x1 prediction head.
    """

    def _conv_block(self, in_ch, out_ch):
        # Classic U-Net "double conv": two 3x3 conv + ReLU layers.
        layers = [
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        ]
        return torch.nn.Sequential(*layers)

    def _hook_fn(self, module, input, output):
        # Forward-hook callback: stash each hooked stage's output as a skip feature.
        self.feature_maps.append(output)

    def _register_encoder_hooks(self):
        # Attach a forward hook to every listed encoder block, keeping the
        # handles so they could be removed later.
        for block in self.list_blocks:
            handle = block.register_forward_hook(self._hook_fn)
            self.hooks.append(handle)

    def __init__(self, in_channels: int = 3, out_channels: int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.hooks = []
        self.feature_maps = []

        # Pretrained VAN-B1 backbone; num_classes=0 drops the classifier head.
        self.encoder_base = van.van_b1(pretrained=True, num_classes=0)

        # Last block of each of the four VAN stages supplies a skip feature;
        # block4's output doubles as the bottleneck.
        self.list_blocks = [
            self.encoder_base.block1[-1],
            self.encoder_base.block2[-1],
            self.encoder_base.block3[-1],
            self.encoder_base.block4[-1],
        ]
        self._register_encoder_hooks()

        # Decoder: transpose-conv upsample, concat skip, double conv.
        # Input channels of each decoder block = upsampled + skip channels.
        self.upConv_0 = torch.nn.ConvTranspose2d(512, 320, kernel_size=2, stride=2)
        self.decoder_0 = self._conv_block(640, 256)

        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder_1 = self._conv_block(256, 128)

        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder_2 = self._conv_block(128, 64)

        self.last_conv = torch.nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Reset the skip store for this pass (hooks append into it).
        self.feature_maps = []
        _, _, h, w = x.shape

        self.encoder_base(x)  # forward pass fires the hooks

        enc1, enc2, enc3, bottleneck = self.feature_maps

        out = bottleneck
        stages = (
            (self.upConv_0, self.decoder_0, enc3),
            (self.upConv_1, self.decoder_1, enc2),
            (self.upConv_2, self.decoder_2, enc1),
        )
        for up, dec, skip in stages:
            out = dec(torch.cat((up(out), skip), dim=1))

        # Upsample to the original input resolution before the 1x1 head.
        out = torch.nn.functional.interpolate(
            out, size=(h, w), mode="bicubic", align_corners=False
        )
        return self.last_conv(out)

class UNet_RN18(torch.nn.Module):
    """U-Net-style segmentation head on an ImageNet-pretrained ResNet-18.

    Skip features are collected with forward hooks on conv1 and layer1-4;
    four decoder stages fuse them, and two extra hook-free stages restore
    the full input resolution (conv1 already runs at stride 2).
    """

    def _conv_block(self, in_ch, out_ch):
        # Classic U-Net "double conv": two 3x3 conv + ReLU layers.
        layers = [
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        ]
        return torch.nn.Sequential(*layers)

    def _hook_fn(self, module, input, output):
        # Forward-hook callback: stash each hooked stage's output as a skip feature.
        self.feature_maps.append(output)

    def _register_encoder_hooks(self):
        # Attach a forward hook to every listed encoder block, keeping the
        # handles so they could be removed later.
        for block in self.list_blocks:
            handle = block.register_forward_hook(self._hook_fn)
            self.hooks.append(handle)

    def __init__(self, in_channels: int = 3, out_channels: int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.hooks = []
        self.feature_maps = []

        # Pretrained ResNet-18 backbone; the classifier head is neutralized
        # so the final Linear does no work (only features matter here).
        self.encoder_base = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        self.encoder_base.fc = torch.nn.Identity()

        # Hooked layers shallow -> deep: conv1 (half resolution) then layer1-4.
        self.list_blocks = [
            self.encoder_base.conv1,
            self.encoder_base.layer1,
            self.encoder_base.layer2,
            self.encoder_base.layer3,
            self.encoder_base.layer4,
        ]
        self._register_encoder_hooks()

        # Upsampling stages: each doubles spatial size and halves channels.
        self.upConv_0 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upConv_3 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        # Final upsample back to the input resolution (undoes conv1's stride).
        self.upConv_extra = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)

        # Fusion blocks; input channels = upsampled + skip channels.
        self.decoder_0 = self._conv_block(512, 256)
        self.decoder_1 = self._conv_block(256, 128)
        self.decoder_2 = self._conv_block(128, 64)
        self.decoder_3 = self._conv_block(32 + 64, 32)
        self.decoder_extra = self._conv_block(16, 16)  # no skip at full resolution

        self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1)

    def forward(self, x):
        # Reset the skip store for this pass (hooks append into it).
        self.feature_maps = []

        self.encoder_base(x)  # forward pass fires the hooks

        e0, e1, e2, e3, out = self.feature_maps

        stages = (
            (self.upConv_0, self.decoder_0, e3),
            (self.upConv_1, self.decoder_1, e2),
            (self.upConv_2, self.decoder_2, e1),
            (self.upConv_3, self.decoder_3, e0),
        )
        for up, dec, skip in stages:
            out = dec(torch.cat((up(out), skip), dim=1))

        # Final skip-free stage restores the input resolution.
        out = self.decoder_extra(self.upConv_extra(out))
        return self.last_conv(out)

class UNet(torch.nn.Module):
    """Vanilla U-Net built from double-conv blocks.

    Four encoder levels (channels 64 -> 512), a 1024-channel bottleneck, and
    a mirrored decoder with skip concatenations; a 1x1 conv produces the
    final per-pixel logits. Spatial size is preserved end to end (inputs
    should be divisible by 16 for the four 2x poolings to invert cleanly).
    """

    def _conv_block(self, in_ch, out_ch):
        # Classic U-Net "double conv": 3x3 conv + ReLU, twice.
        layers = []
        for ci, co in ((in_ch, out_ch), (out_ch, out_ch)):
            layers.append(torch.nn.Conv2d(ci, co, kernel_size=3, padding=1))
            layers.append(torch.nn.ReLU(inplace=True))
        return torch.nn.Sequential(*layers)

    def __init__(self, in_channels: int = 3, out_channels: int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Encoder path: channels double at each level.
        self.encoder_0 = self._conv_block(in_channels, 64)
        self.encoder_1 = self._conv_block(64, 128)
        self.encoder_2 = self._conv_block(128, 256)
        self.encoder_3 = self._conv_block(256, 512)

        # One shared 2x max-pool reused between encoder levels.
        self.pooling = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = self._conv_block(512, 1024)

        # Decoder path: transpose conv halves channels and doubles size;
        # the skip concat restores the channel count for the double conv.
        self.upConv_0 = torch.nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.upConv_1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upConv_2 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upConv_3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

        self.decoder_0 = self._conv_block(1024, 512)
        self.decoder_1 = self._conv_block(512, 256)
        self.decoder_2 = self._conv_block(256, 128)
        self.decoder_3 = self._conv_block(128, 64)

        self.last_conv = torch.nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder: keep every level's output for the skip connections.
        feat = self.encoder_0(x)
        skips = [feat]
        for enc in (self.encoder_1, self.encoder_2, self.encoder_3):
            feat = enc(self.pooling(feat))
            skips.append(feat)

        out = self.bottleneck(self.pooling(feat))

        # Decoder: deepest skip first.
        decoder_stages = zip(
            (self.upConv_0, self.upConv_1, self.upConv_2, self.upConv_3),
            (self.decoder_0, self.decoder_1, self.decoder_2, self.decoder_3),
            reversed(skips),
        )
        for up, dec, skip in decoder_stages:
            out = dec(torch.cat((up(out), skip), dim=1))

        return self.last_conv(out)
    
    
class RARP_NVB_Model(L.LightningModule):
    """LightningModule wrapping UNet_VAN_1 for binary mask segmentation.

    Loss is BCE-with-logits; train/val IoU are tracked with separate
    torchmetrics objects so their accumulators never mix.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.model = UNet_VAN_1(in_channels=3, out_channels=1)

        self.lr = 1E-4
        self.lossFN = torch.nn.BCEWithLogitsLoss()

        self.train_IoU = torchmetrics.classification.BinaryJaccardIndex()
        self.val_IoU = torchmetrics.classification.BinaryJaccardIndex()

    def _compute_iou(self, preds, targets):
        # Manual IoU at a fixed 0.5 threshold (logged metrics use
        # torchmetrics instead; this helper is for ad-hoc checks).
        hit = (preds > 0.5) & (targets > 0.5)
        either = (preds > 0.5) | (targets > 0.5)
        return (torch.sum(hit) / (torch.sum(either) + 1e-6)).item()

    def forward(self, data):
        # Cast to float32 in case the loader yields uint8/float64 tensors.
        return self.model(data.float())

    def _shared_step(self, batch):
        img, mask = batch

        # Mask arrives as (B, H, W); add the channel dim BCE expects.
        mask = mask.float().unsqueeze(1)

        prediction = self(img)
        loss = self.lossFN(prediction, mask)

        # Probabilities for the IoU metrics (loss uses raw logits).
        predicted_labels = torch.sigmoid(prediction)

        return loss, mask, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.train_IoU.update(predicted_labels, true_labels)

        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc_IoU", self.train_IoU, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.val_IoU.update(predicted_labels, true_labels)

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_acc_IoU", self.val_IoU, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        # Plain Adam over the wrapped network's parameters, no scheduler.
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

def setup_seed(seed):
    """Seed every RNG source (torch CPU/CUDA, numpy, Lightning) for reproducibility."""
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)
    # Lightning's helper also seeds Python's `random` and dataloader workers.
    seed_everything(seed, workers=True)
    # Force deterministic cuDNN kernels (may cost some speed).
    torch.backends.cudnn.deterministic = True
    
def Denorlalize (img:torch.Tensor, std, mean):
    """Undo normalization on a CHW tensor and return an HWC array in [0, 1].

    Inverts the (x - mean) / std transform via x * std + mean, clamps to
    [0, 1] for display, and reverses the channel axis (BGR <-> RGB swap).
    Note: the function name's typo is kept — callers depend on it.
    """
    hwc = img.numpy().transpose((1, 2, 0))      # CHW -> HWC
    hwc = np.clip(hwc * std + mean, 0, 1)       # invert normalize, clamp for imshow
    return hwc[..., ::-1].copy()                # reverse channel order

def visualize_augmentations(dataset, idx=0, samples=5):
    """Show `samples` augmented (image, mask) draws of dataset[idx] side by side.

    Operates on a deep copy of the dataset with Normalize/ToTensorV2 stripped
    from its transform, so the raw augmented pixels can be displayed directly.
    Each row is one independent random draw of the augmentation pipeline.
    """
    preview = copy.deepcopy(dataset)
    kept = [t for t in preview.transform if not isinstance(t, (A.Normalize, ToTensorV2))]
    preview.transform = A.Compose(kept)

    _, axes = plt.subplots(nrows=samples, ncols=2, figsize=(10, 24))
    for row in range(samples):
        image, mask = preview[idx]  # re-rolls the random augmentations
        img_ax, mask_ax = axes[row]
        img_ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        mask_ax.imshow(mask, interpolation="nearest")
        img_ax.set_title("Augmented image")
        mask_ax.set_title("Augmented mask")
        img_ax.set_axis_off()
        mask_ax.set_axis_off()

    plt.tight_layout()
    plt.show()
    
# Per-channel pixel statistics (apparently BGR order, 0-255 scale) for Normalize.
# NOTE(review): A.Normalize defaults to max_pixel_value=255 and expects
# mean/std already scaled to [0, 1]; passing 0-255 stats here may double-scale —
# verify against the pixel range defs.load_Img actually returns.
mean, std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
    
setup_seed(2023)  # fixed seed for reproducible runs
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batchSize = 16
numWorkers = 0  # 0 -> data loading happens in the main process
rootFile = Path("../dataset/ROI/")  # dataset root containing train/ and val/ splits

# Training-time pipeline: resize, geometric + photometric jitter, then
# normalization and conversion to a CHW tensor.
traintransform = A.Compose(
    [
        A.Resize(256, 256, interpolation=cv2.INTER_CUBIC), 
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.Rotate(limit=(-30,30), p=0.9),
        A.RandomGamma(p=0.7),
        A.RandomFog(0.1, p=0.75),
        A.RandomToneCurve(p=0.75),
        A.Normalize(mean, std),
        ToTensorV2()
    ]
)

# Validation pipeline: deterministic resize + normalize only.
valtransform = A.Compose(
    [
        A.Resize(256, 256, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean, std),
        ToTensorV2()
    ]
)

# Project dataset class; create_mask=True builds the binary target mask on load.
trainDataset = Loaders.RARP_DatasetFolder_RoiExtractor(
    str(rootFile/"train"),
    loader=defs.load_Img,
    extensions="tiff",
    transform=traintransform,
    create_mask=True
)
valDataset = Loaders.RARP_DatasetFolder_RoiExtractor(
    str(rootFile/"val"),
    loader=defs.load_Img,
    extensions="tiff",
    transform=valtransform,
    create_mask=True
)

Train_DataLoader = DataLoader(
    trainDataset, 
    batch_size=batchSize, 
    num_workers=numWorkers, 
    shuffle=True, 
    persistent_workers=numWorkers>0  # only meaningful when worker processes exist
)

Val_DataLoader = DataLoader(
    valDataset, 
    batch_size=batchSize, 
    num_workers=numWorkers, 
    shuffle=False, 
    pin_memory=True,
    persistent_workers=numWorkers>0
)

Model = RARP_NVB_Model()

# Single-GPU trainer: deterministic ops, TensorBoard logging, and a checkpoint
# callback keeping the 5 best epochs ranked by validation IoU (mode='max').
trainer = L.Trainer(
    deterministic=True,
    accelerator='gpu', 
    devices=1, 
    logger=TensorBoardLogger(save_dir="./logs_debug"),
    log_every_n_steps=5,   
    callbacks=[callbk.ModelCheckpoint(monitor="val_acc_IoU", filename="RARP-{epoch}-{val_loss:.4f}", save_top_k=5, mode='max')],
    max_epochs=50,
)
print("Train Phase")
trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)