External attention implicitly captures the correlations among all samples in the dataset. Before getting to it, let's first look at how the ordinary self-attention module operates:
Given an input $F \in \mathbb{R}^{N \times d}$, where $N$ is the number of pixels and $d$ is the number of feature dimensions, vanilla self-attention first linearly projects the input into a query matrix $Q \in \mathbb{R}^{N \times d^{\prime}}$, a key matrix $K \in \mathbb{R}^{N \times d^{\prime}}$, and a value matrix $V \in \mathbb{R}^{N \times d}$. The output is then computed as:
$$
\begin{aligned}
A &= (\alpha)_{i, j} = \operatorname{softmax}\left(Q K^{T}\right) \\
F_{\text{out}} &= A V
\end{aligned}
$$
Here $A \in \mathbb{R}^{N \times N}$ is the attention matrix, and $\alpha_{i, j}$ is the similarity between the $i$-th and the $j$-th pixel.
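For reference, here is a minimal PyTorch sketch of this vanilla self-attention (single head, no scaling; the module and variable names are illustrative, not the paper's code):

import torch
from torch import nn

class VanillaSelfAttention(nn.Module):
    """Single-head self-attention over N pixels, directly following the formula above."""
    def __init__(self, d, d_prime):
        super().__init__()
        self.to_q = nn.Linear(d, d_prime, bias=False)   # F -> Q, (B, N, d')
        self.to_k = nn.Linear(d, d_prime, bias=False)   # F -> K, (B, N, d')
        self.to_v = nn.Linear(d, d, bias=False)         # F -> V, (B, N, d)

    def forward(self, feat):                            # feat: (B, N, d)
        q, k, v = self.to_q(feat), self.to_k(feat), self.to_v(feat)
        attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)   # A: (B, N, N)
        return attn @ v                                  # F_out: (B, N, d)

Note the $(B, N, N)$ attention matrix: both memory and compute grow quadratically with the number of pixels, which is exactly the bottleneck discussed next.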
This is why most previous works apply self-attention to image patches rather than to all pixels: attending over every pixel directly is far too expensive, since the cost grows quadratically with $N$.
Moreover, when the attention map is visualized, most pixels turn out to be strongly correlated with only a few other pixels, so an $N \times N$ attention matrix is highly redundant.
The authors therefore propose to compute attention between the input pixels and an external memory unit $M \in \mathbb{R}^{S \times d}$ instead. Here $(\alpha)_{i, j}$ is the similarity between the $i$-th pixel and the $j$-th row of $M$; $M$ is a learnable parameter that is independent of the input and acts as a memory of the whole training dataset:
$$
\begin{aligned}
A &= (\alpha)_{i, j} = \operatorname{Norm}\left(F M^{T}\right) \\
F_{\text{out}} &= A M
\end{aligned}
$$
In practice, two different memory units are used, $M_k$ as the key and $M_v$ as the value, which increases the capacity of the network. The computation becomes:
$$
\begin{aligned}
A &= \operatorname{Norm}\left(F M_{k}^{T}\right) \\
F_{\text{out}} &= A M_{v}
\end{aligned}
$$
With this formulation the algorithm is linear in the number of pixels, with complexity $\mathcal{O}(dSN)$, where $d$ and $S$ are hyper-parameters. Experiments show that even a small $S$, e.g. 64, already works well. For example, on a $128 \times 128$ feature map ($N = 16384$), replacing the $N \times N$ attention map with an $N \times S$ one at $S = 64$ shrinks it by a factor of $N / S = 256$. The computation can be summarized in the following pseudocode:
# Input: F, an array with shape [B, N, C] (batch size, pixels, channels)
# Parameter: M_k, a linear layer without bias
# Parameter: M_v, a linear layer without bias
# Output: out, an array with shape [B, N, C]
attn = M_k(F) # shape=(B, N, M)
attn = softmax(attn, dim=1)
attn = l1_norm(attn, dim=2)
out = M_v(attn) # shape=(B, N, C)
The Norm step above is a double normalization rather than a single softmax: first a softmax over the pixel dimension, then an $\ell_1$ normalization over the memory dimension:

$$
\begin{aligned}
(\tilde{\alpha})_{i, j} &= F M_{k}^{T} \\
\hat{\alpha}_{i, j} &= \frac{\exp\left(\tilde{\alpha}_{i, j}\right)}{\sum_{k} \exp\left(\tilde{\alpha}_{k, j}\right)} \\
\alpha_{i, j} &= \frac{\hat{\alpha}_{i, j}}{\sum_{k} \hat{\alpha}_{i, k}}
\end{aligned}
$$
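The same double normalization can be written directly on the similarity tensor; this is a small sketch assuming `sim` holds $F M_{k}^{T}$ with shape (B, N, S):

import torch

def double_norm(sim, eps=1e-9):
    """Double normalization: sim has shape (B, N, S) = pixels x memory rows."""
    attn = torch.softmax(sim, dim=1)                     # softmax over the pixel dimension i
    attn = attn / (eps + attn.sum(dim=2, keepdim=True))  # l1-normalize over the memory dimension j
    return attn

The official code below does the same thing, but on tensors of shape (B, S, N), so the two dimensions are swapped there (softmax over dim=-1, l1 norm over dim=1).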
Official code: https://github.com/MenghaoGuo/-EANet
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.batchnorm import _BatchNorm

# The repository passes a normalization layer in; BatchNorm2d is assumed here
# so that the snippet is self-contained.
norm_layer = nn.BatchNorm2d

class External_attention(nn.Module):
    '''
    External attention with double normalization.

    Arguments:
        c (int): The input and output channel number.
    '''
    def __init__(self, c):
        super(External_attention, self).__init__()
        self.conv1 = nn.Conv2d(c, c, 1)

        self.k = 64  # number of memory rows S
        # M_k and M_v are implemented as 1x1 Conv1d layers, i.e. linear layers without bias
        self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)
        self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
        # initialize M_v from the transpose of M_k
        self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)

        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            norm_layer(c))

        # Kaiming-style initialization for the conv layers, constant init for batch norm
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        idn = x                      # keep the input for the residual connection
        x = self.conv1(x)

        b, c, h, w = x.size()
        n = h * w
        x = x.view(b, c, n)          # b, c, n

        attn = self.linear_0(x)      # F M_k^T -> b, k, n
        attn = F.softmax(attn, dim=-1)                          # softmax over the pixel dimension, b, k, n
        attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True))    # l1 norm over the memory dimension, b, k, n
        x = self.linear_1(attn)      # A M_v -> b, c, n

        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn                  # residual connection
        x = F.relu(x)
        return x
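A minimal usage sketch to check shapes (the channel count and feature-map size are arbitrary examples, and it assumes norm_layer resolves to nn.BatchNorm2d as above):

ea = External_attention(c=512)
x = torch.randn(2, 512, 32, 32)   # B=2, C=512, 32x32 feature map
out = ea(x)
print(out.shape)                  # torch.Size([2, 512, 32, 32])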