import torch
import torchmetrics.classification
import torchvision
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from lightning.pytorch import seed_everything
import lightning.pytorch.callbacks as callbk
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
import Loaders
import defs
import lightning as L
import van
import copy
class UNet_VAN_1(torch.nn.Module):
    """U-Net-style binary segmentation head over a pretrained VAN-B1 encoder.

    Skip features are captured from the last block of each encoder stage via
    forward hooks; the decoder upsamples and fuses them, and the result is
    resized back to the input resolution with bicubic interpolation before the
    final 1x1 projection.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): `in_channels` is accepted but never used -- the
        # pretrained VAN encoder fixes the input at 3 channels; confirm.
        self.hooks = []
        self.feature_maps = []
        self.encoder_base = van.van_b1(pretrained=True, num_classes=0)
        # Tap the final block of each encoder stage for skip connections.
        self.list_blocks = [
            self.encoder_base.block1[-1],
            self.encoder_base.block2[-1],
            self.encoder_base.block3[-1],
            self.encoder_base.block4[-1],
        ]
        self._register_encoder_hooks()
        # Decoder widths mirror the encoder stage widths implied by the cat
        # sizes below (64/128/320/512 -- presumably VAN-B1's dims; confirm).
        self.upConv_0 = torch.nn.ConvTranspose2d(512, 320, kernel_size=2, stride=2)
        self.decoder_0 = self._conv_block(640, 256)
        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder_1 = self._conv_block(256, 128)
        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder_2 = self._conv_block(128, 64)
        self.last_conv = torch.nn.Conv2d(64, out_channels, kernel_size=1)

    def _conv_block(self, in_ch, out_ch):
        """Two 3x3 conv + ReLU layers that keep spatial size."""
        layers = [
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        ]
        return torch.nn.Sequential(*layers)

    def _hook_fn(self, module, input, output):
        # Hooks fire in encoder execution order, so feature_maps ends up
        # ordered shallow -> deep.
        self.feature_maps.append(output)

    def _register_encoder_hooks(self):
        for block in self.list_blocks:
            handle = block.register_forward_hook(self._hook_fn)
            self.hooks.append(handle)

    def forward(self, x):
        self.feature_maps = []  # reset per call; the encoder pass refills it
        _, _, h, w = x.shape
        self.encoder_base(x)  # side effect: hooks populate self.feature_maps
        *skips, out = self.feature_maps  # three skips + bottleneck
        stages = (
            (self.upConv_0, self.decoder_0),
            (self.upConv_1, self.decoder_1),
            (self.upConv_2, self.decoder_2),
        )
        for (up, fuse), skip in zip(stages, reversed(skips)):
            out = fuse(torch.cat((up(out), skip), dim=1))
        # The decoder output is still below input resolution; restore (h, w)
        # before projecting to the output channels.
        out = torch.nn.functional.interpolate(
            out, size=(h, w), mode="bicubic", align_corners=False
        )
        return self.last_conv(out)
class UNet_RN18(torch.nn.Module):
    """U-Net-style binary segmentation head over a pretrained ResNet-18 encoder.

    Five intermediate encoder outputs (conv1 and layer1..layer4) are captured
    via forward hooks; four decoder stages fuse them with skip connections and
    one extra transposed-conv stage brings the map back to input resolution.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): `in_channels` is accepted but never used -- the
        # pretrained ResNet stem fixes the input at 3 channels; confirm.
        self.hooks = []
        self.feature_maps = []
        self.encoder_base = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        # Drop the classification head; the hooks below already capture
        # everything the decoder needs.
        self.encoder_base.fc = torch.nn.Identity()
        self.list_blocks = [
            self.encoder_base.conv1,
            self.encoder_base.layer1,
            self.encoder_base.layer2,
            self.encoder_base.layer3,
            self.encoder_base.layer4,
        ]
        self._register_encoder_hooks()
        # Transposed convolutions double spatial size at each decoder stage.
        self.upConv_0 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upConv_1 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upConv_2 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upConv_3 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.upConv_extra = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        # Fusion blocks consume (upsampled + skip) channels.
        self.decoder_0 = self._conv_block(512, 256)
        self.decoder_1 = self._conv_block(256, 128)
        self.decoder_2 = self._conv_block(128, 64)
        self.decoder_3 = self._conv_block(32 + 64, 32)  # conv1 skip has 64 ch
        self.decoder_extra = self._conv_block(16, 16)
        self.last_conv = torch.nn.Conv2d(16, out_channels, kernel_size=1)

    def _conv_block(self, in_ch, out_ch):
        """Two 3x3 conv + ReLU layers that keep spatial size."""
        layers = [
            torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        ]
        return torch.nn.Sequential(*layers)

    def _hook_fn(self, module, input, output):
        # Hooks fire in encoder execution order, so feature_maps ends up
        # ordered shallow -> deep.
        self.feature_maps.append(output)

    def _register_encoder_hooks(self):
        for block in self.list_blocks:
            handle = block.register_forward_hook(self._hook_fn)
            self.hooks.append(handle)

    def forward(self, x):
        self.feature_maps = []  # reset per call; the encoder pass refills it
        self.encoder_base(x)  # side effect: hooks populate self.feature_maps
        *skips, out = self.feature_maps  # conv1, layer1..layer3 skips + layer4
        stages = (
            (self.upConv_0, self.decoder_0),
            (self.upConv_1, self.decoder_1),
            (self.upConv_2, self.decoder_2),
            (self.upConv_3, self.decoder_3),
        )
        for (up, fuse), skip in zip(stages, reversed(skips)):
            out = fuse(torch.cat((up(out), skip), dim=1))
        # Final skip-free upsample: the conv1 skip sits at half resolution,
        # so one more 2x stage restores the input size.
        out = self.decoder_extra(self.upConv_extra(out))
        return self.last_conv(out)
class UNet(torch.nn.Module):
def _conv_block(self, in_ch, out_ch):
return torch.nn.Sequential(
torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True)
)
def __init__(self, in_channels:int = 3, out_channels:int = 1, *args, **kwargs):
super().__init__(*args, **kwargs)
self.encoder_0 = self._conv_block(in_channels, 64) #0
self.encoder_1 = self._conv_block(64, 128) #1
self.encoder_2 = self._conv_block(128, 256) #2
self.encoder_3 = self._conv_block(256, 512) #3
self.pooling = torch.nn.MaxPool2d(kernel_size=2, stride=2)
self.bottleneck = self._conv_block(512, 1024)
self.upConv_0 = torch.nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.upConv_1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.upConv_2 = torch.nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.upConv_3 = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.decoder_0 = self._conv_block(1024, 512) #0
self.decoder_1 = self._conv_block(512, 256) #1
self.decoder_2 = self._conv_block(256, 128) #2
self.decoder_3 = self._conv_block(128, 64) #3
self.last_conv = torch.nn.Conv2d(64, out_channels, kernel_size=1)
def forward(self, x):
encoder_l1 = self.encoder_0(x) #3 -> 64
encoder_l2 = self.encoder_1(self.pooling(encoder_l1)) #64 -> 128
encoder_l3 = self.encoder_2(self.pooling(encoder_l2)) #128 -> 256
encoder_l4 = self.encoder_3(self.pooling(encoder_l3)) #256 -> 512
btlneck = self.bottleneck(self.pooling(encoder_l4)) #512 -> 1024
decoder_l4 = self.upConv_0(btlneck) #1024 -> 512
decoder_l4 = torch.cat((decoder_l4, encoder_l4), dim=1)
decoder_l4 = self.decoder_0(decoder_l4) #(512 + 512) -> 512
decoder_l3 = self.upConv_1(decoder_l4) #512 -> 256
decoder_l3 = torch.cat((decoder_l3, encoder_l3), dim=1)
decoder_l3 = self.decoder_1(decoder_l3) #(256 + 256) -> 256
decoder_l2 = self.upConv_2(decoder_l3) #256 -> 128
decoder_l2 = torch.cat((decoder_l2, encoder_l2), dim=1)
decoder_l2 = self.decoder_2(decoder_l2) #(128 + 128) -> 128
decoder_l1 = self.upConv_3(decoder_l2) #128 -> 64
decoder_l1 = torch.cat((decoder_l1, encoder_l1), dim=1)
decoder_l1 = self.decoder_3(decoder_l1) #(64 + 64) -> 64
return self.last_conv(decoder_l1)
class RARP_NVB_Model(L.LightningModule):
    """LightningModule wrapping UNet_VAN_1 for binary mask segmentation.

    Optimizes BCE-with-logits with Adam and tracks binary IoU (Jaccard)
    separately for the training and validation loops.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = UNet_VAN_1(in_channels=3, out_channels=1)
        self.lr = 1e-4
        self.lossFN = torch.nn.BCEWithLogitsLoss()
        # Separate metric objects so train/val accumulations never mix.
        self.train_IoU = torchmetrics.classification.BinaryJaccardIndex()
        self.val_IoU = torchmetrics.classification.BinaryJaccardIndex()

    def _compute_iou(self, preds, targets):
        """Manual IoU at a 0.5 threshold (not used by the Lightning loops)."""
        pred_fg = preds > 0.5
        true_fg = targets > 0.5
        intersection = torch.sum(pred_fg & true_fg)
        union = torch.sum(pred_fg | true_fg)
        # Epsilon guards against division by zero when both masks are empty.
        return (intersection / (union + 1e-6)).item()

    def forward(self, data):
        # Cast to float32 before the backbone (inputs may arrive as uint8).
        return self.model(data.float())

    def _shared_step(self, batch):
        """Run one batch; return (loss, target mask, sigmoid probabilities)."""
        img, mask = batch
        target = mask.float().unsqueeze(1)  # add channel dim -> (B, 1, H, W)
        logits = self(img)
        loss = self.lossFN(logits, target)
        probs = torch.sigmoid(logits)
        return loss, target, probs

    def training_step(self, batch, batch_idx):
        loss, target, probs = self._shared_step(batch)
        self.train_IoU.update(probs, target)
        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc_IoU", self.train_IoU, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, target, probs = self._shared_step(batch)
        self.val_IoU.update(probs, target)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_acc_IoU", self.val_IoU, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
def setup_seed(seed):
    """Seed all RNGs used by the pipeline for a reproducible run.

    NOTE(review): `seed_everything` is documented to seed Python, NumPy and
    torch itself, which would make the three explicit calls above it redundant
    (but harmless) -- confirm before removing any of them; the final RNG state
    depends on call order.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # workers=True also propagates seeds into DataLoader worker processes.
    seed_everything(seed, workers=True)
    # Trade cuDNN autotuning speed for deterministic convolution results.
    torch.backends.cudnn.deterministic = True
def Denorlalize(img: torch.Tensor, std, mean):
    """Undo per-channel normalization and BGR->RGB flip for display.

    (Name kept as-is -- including the typo -- for caller compatibility.)

    Args:
        img: (C, H, W) tensor, e.g. the output of A.Normalize + ToTensorV2.
        std: per-channel std used for normalization (broadcast over (H, W, C)).
        mean: per-channel mean used for normalization (same broadcasting).

    Returns:
        (H, W, C) float ndarray in RGB channel order, clipped to [0, 1].
    """
    # detach().cpu() so tensors on GPU or with requires_grad=True also work;
    # a no-op for plain CPU tensors, so behavior is otherwise unchanged.
    img_np = img.detach().cpu().numpy().transpose((1, 2, 0))
    # NOTE(review): clipping to [0, 1] assumes the de-normalized values are in
    # unit scale; the module-level mean/std here look like 0-255 stats --
    # confirm which scale callers pass.
    img_np = np.clip(std * img_np + mean, 0, 1)
    # Channels presumably arrive BGR (cv2 loading); flip to RGB for plotting.
    return img_np[..., ::-1].copy()
def visualize_augmentations(dataset, idx=0, samples=5):
    """Plot `samples` random augmentations of a single dataset item.

    Operates on a deep copy of `dataset` whose transform pipeline has
    Normalize/ToTensorV2 stripped out, so the raw augmented image and mask can
    be shown directly with matplotlib.
    """
    dataset = copy.deepcopy(dataset)
    display_tfms = [
        t for t in dataset.transform
        if not isinstance(t, (A.Normalize, ToTensorV2))
    ]
    dataset.transform = A.Compose(display_tfms)
    _, axes = plt.subplots(nrows=samples, ncols=2, figsize=(10, 24))
    for row in range(samples):
        # Each access re-runs the stochastic pipeline on the same item.
        image, mask = dataset[idx]
        img_ax, mask_ax = axes[row, 0], axes[row, 1]
        img_ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        img_ax.set_title("Augmented image")
        img_ax.set_axis_off()
        mask_ax.imshow(mask, interpolation="nearest")
        mask_ax.set_title("Augmented mask")
        mask_ax.set_axis_off()
    plt.tight_layout()
    plt.show()
# Per-channel normalization statistics on a 0-255 scale (channel order
# presumably BGR, matching the cv2-based loading used in this file --
# confirm against how these stats were computed).
mean, std = ([30.38144216, 42.03988769, 97.8896116], [40.63141752, 44.26910074, 50.29294373])
setup_seed(2023)
# NOTE(review): `device` is never used below -- the Lightning Trainer selects
# devices via accelerator/devices; kept for compatibility.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batchSize = 16
numWorkers = 0  # single-process data loading (persistent_workers stays False)
rootFile = Path("../dataset/ROI/")
# Training pipeline: fixed 256x256 resize, geometric + photometric jitter,
# then normalization and conversion to a CHW tensor.
traintransform = A.Compose(
    [
        A.Resize(256, 256, interpolation=cv2.INTER_CUBIC),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        # NOTE(review): rotation can be applied twice (ShiftScaleRotate above
        # also rotates up to 30 degrees) -- confirm this stacking is intended.
        A.Rotate(limit=(-30,30), p=0.9),
        A.RandomGamma(p=0.7),
        A.RandomFog(0.1, p=0.75),
        A.RandomToneCurve(p=0.75),
        A.Normalize(mean, std),
        ToTensorV2()
    ]
)
# Validation pipeline: deterministic resize + normalization only.
valtransform = A.Compose(
    [
        A.Resize(256, 256, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean, std),
        ToTensorV2()
    ]
)
# Folder datasets over .tiff frames; the project loader builds the binary
# masks on the fly (create_mask=True).
trainDataset = Loaders.RARP_DatasetFolder_RoiExtractor(
    str(rootFile/"train"),
    loader=defs.load_Img,
    extensions="tiff",
    transform=traintransform,
    create_mask=True
)
valDataset = Loaders.RARP_DatasetFolder_RoiExtractor(
    str(rootFile/"val"),
    loader=defs.load_Img,
    extensions="tiff",
    transform=valtransform,
    create_mask=True
)
# Training loader: shuffled each epoch. pin_memory=True added for consistency
# with the validation loader (and faster host->GPU transfer); it is a no-op
# on CPU-only machines.
Train_DataLoader = DataLoader(
    trainDataset,
    batch_size=batchSize,
    num_workers=numWorkers,
    shuffle=True,
    pin_memory=True,
    persistent_workers=numWorkers>0
)
# Validation loader: fixed order so metrics are comparable across epochs.
Val_DataLoader = DataLoader(
    valDataset,
    batch_size=batchSize,
    num_workers=numWorkers,
    shuffle=False,
    pin_memory=True,
    persistent_workers=numWorkers>0
)
Model = RARP_NVB_Model()
trainer = L.Trainer(
    deterministic=True,  # pairs with setup_seed for reproducible runs
    # NOTE(review): hardcoded GPU -- the `device` fallback computed above
    # suggests CPU support was intended; confirm target machines have CUDA.
    accelerator='gpu',
    devices=1,
    logger=TensorBoardLogger(save_dir="./logs_debug"),
    log_every_n_steps=5,
    # NOTE(review): checkpoints are selected on val_acc_IoU (mode='max') but
    # the filename embeds val_loss -- confirm that mismatch is intentional.
    callbacks=[callbk.ModelCheckpoint(monitor="val_acc_IoU", filename="RARP-{epoch}-{val_loss:.4f}", save_top_k=5, mode='max')],
    max_epochs=50,
)
print("Train Phase")
trainer.fit(Model, train_dataloaders=Train_DataLoader, val_dataloaders=Val_DataLoader)