import torch
import torch.nn as nn
class NOAH(nn.Module):
    """Attention head that pools a (N, C, H, W) feature map into an (N, outplanes) vector.

    A query branch yields per-head attention weights (softmax over spatial
    positions) and a value branch yields per-head features; their weighted
    sum over heads and positions is the output. With ``kv_split`` the input
    channels are split between the two branches by ``key_ratio``; with
    ``head_split`` the 1x1 convolutions are grouped per head.
    """

    def __init__(self, inplanes, outplanes, dropout=0.0, key_ratio=0.5,
                 head_num=1, head_split=True, kv_split=True):
        super(NOAH, self).__init__()
        self.kv_split = kv_split
        self.head_split = head_split
        self.dropout = nn.Dropout(p=dropout)
        self.key_ratio = key_ratio
        self.head_num = head_num
        if kv_split:
            # Split the input channels between the query and value branches.
            self.k_channel = int(inplanes * key_ratio)
            self.v_channel = inplanes - self.k_channel
        else:
            # Both branches consume the full set of input channels.
            self.k_channel = inplanes
            self.v_channel = inplanes
        # Grouped convolutions require per-branch channels divisible by head_num.
        assert self.k_channel % head_num == 0
        assert self.v_channel % head_num == 0
        self.groups = head_num if head_split else 1
        # 1x1 convolutions producing head_num * outplanes attention logits and values.
        self.query = nn.Conv2d(self.k_channel, head_num * outplanes, kernel_size=1,
                               groups=self.groups, stride=1, padding=0)
        self.value = nn.Conv2d(self.v_channel, head_num * outplanes, kernel_size=1,
                               groups=self.groups, stride=1, padding=0)
        # self._init_weight()
    def _init_weight(self):
        # Kaiming initialization for the 1x1 convolutions (not called by default).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    def forward(self, x):
        # (N, C, H, W) -> (N, C, 1, H*W): flatten spatial positions into a sequence.
        x = torch.flatten(x, 2).unsqueeze(dim=-2)
        N, C, _, L = x.shape
        if self.kv_split:
            # Query branch reads the first k_channel channels, value branch the rest.
            a = torch.softmax(self.query(x[:, :self.k_channel]).reshape(N, self.head_num, -1, L), dim=3)
            v = self.value(x[:, self.k_channel:]).reshape(N, self.head_num, -1, L)
        else:
            a = torch.softmax(self.query(x).reshape(N, self.head_num, -1, L), dim=3)
            v = self.value(x).reshape(N, self.head_num, -1, L)
        v = self.dropout(v)
        # Weighted sum over heads (dim 1) and spatial positions (dim 3) -> (N, outplanes).
        x = torch.sum(a * v, dim=(1, 3))
        return x
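
# Minimal usage sketch (illustrative only, not part of the module above). The
# hyper-parameters and input shape are assumptions chosen so that the query and
# value channel counts are divisible by head_num, as the asserts require.
if __name__ == "__main__":
    head = NOAH(inplanes=64, outplanes=10, dropout=0.1, key_ratio=0.5,
                head_num=4, head_split=True, kv_split=True)
    feats = torch.randn(2, 64, 7, 7)   # dummy (N, C, H, W) feature map
    out = head(feats)
    print(out.shape)                   # expected: torch.Size([2, 10])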