import torch
import torch.nn as nn


class NOAH(nn.Module):
    """Multi-head attention pooling layer.

    Flattens a feature map into a sequence of positions, computes
    per-position attention weights (query branch) and value projections
    (value branch) with 1x1 convolutions, and returns the attention-weighted
    sum over all heads and positions — one ``outplanes``-vector per sample.

    Args:
        inplanes: Number of input channels.
        outplanes: Number of output features per sample.
        dropout: Dropout probability applied to the value branch.
        key_ratio: Fraction of input channels routed to the query branch
            when ``kv_split`` is True.
        head_num: Number of attention heads.
        head_split: If True, the 1x1 convolutions are grouped per head.
        kv_split: If True, the input channels are partitioned between the
            query and value branches; otherwise both branches see all
            channels.

    Raises:
        ValueError: If the query/value channel counts are not divisible by
            ``head_num``.
    """

    def __init__(self, inplanes, outplanes, dropout=0.0, key_ratio=0.5,
                 head_num=1, head_split=True, kv_split=True):
        super().__init__()
        self.kv_split = kv_split
        self.head_split = head_split
        self.dropout = nn.Dropout(p=dropout)
        self.key_ratio = key_ratio
        self.head_num = head_num

        if kv_split:
            # Partition the input channels between the two branches.
            self.k_channel = int(inplanes * key_ratio)
            self.v_channel = inplanes - self.k_channel
        else:
            # Both branches consume the full channel dimension.
            self.k_channel = inplanes
            self.v_channel = inplanes

        # Validate explicitly rather than with `assert`, which is silently
        # stripped under `python -O`.
        if self.k_channel % head_num != 0 or self.v_channel % head_num != 0:
            raise ValueError(
                f"k_channel ({self.k_channel}) and v_channel "
                f"({self.v_channel}) must both be divisible by "
                f"head_num ({head_num})"
            )

        self.groups = head_num if head_split else 1
        self.query = nn.Conv2d(self.k_channel, head_num * outplanes,
                               kernel_size=1, groups=self.groups,
                               stride=1, padding=0)
        self.value = nn.Conv2d(self.v_channel, head_num * outplanes,
                               kernel_size=1, groups=self.groups,
                               stride=1, padding=0)

    def _init_weight(self):
        """Kaiming-initialize all Conv2d weights and zero their biases.

        NOTE(review): not called from ``__init__`` in the original (the call
        was commented out there), so invoke it explicitly if desired.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Pool a feature map into one ``outplanes``-vector per sample.

        Args:
            x: Tensor of shape (N, C, *spatial) with C == inplanes; the
                trailing spatial dims are flattened to a length-L sequence.

        Returns:
            Tensor of shape (N, outplanes): attention-weighted sum of the
            value projections over heads and positions.
        """
        # (N, C, *spatial) -> (N, C, 1, L) so the 1x1 convs act per position.
        x = torch.flatten(x, 2).unsqueeze(dim=-2)
        N, C, _, L = x.shape
        if self.kv_split:
            # First k_channel channels feed the query, the rest the value.
            a = torch.softmax(
                self.query(x[:, :self.k_channel]).reshape(N, self.head_num, -1, L),
                dim=3)
            v = self.value(x[:, self.k_channel:]).reshape(N, self.head_num, -1, L)
        else:
            a = torch.softmax(
                self.query(x).reshape(N, self.head_num, -1, L), dim=3)
            v = self.value(x).reshape(N, self.head_num, -1, L)
        v = self.dropout(v)
        # Weighted sum over heads (dim 1) and positions (dim 3).
        x = torch.sum(a * v, dim=(1, 3))
        return x