paper链接: https://arxiv.org/pdf/2109.14382.pdf
本文提出了单元强制操作Vision Transformer(UFO-ViT),这是一种具有线性复杂度的新型SA机制。这项工作的主要方法是从原来的SA。我们分解了SA机构的矩阵乘法消除非线性,没有复杂的线性逼近。仅修改原始SA中的几行代码,所提出的模型在大多数图像分类和密集预测任务上优于基于Transformer的模型。
原始的自注意(SA)机制尽管取得了巨大的成功,但由于 σ ( Q K T ) ∈ R N × N σ(QK^T)∈R^{N×N} σ(QKT)∈RN×N和V的矩阵乘法,其时间和计算复杂度为 O ( n 2 ) O(n^2) O(n2)。这是传统Transformer的缺点之一。对于视觉任务,N与输入分辨率成正比。这意味着如果输入图像的宽度和高度加倍,SA将消耗16倍的计算资源。
1、提出了一种新的约束方案XNorm,它生成一个单元来提取关系特征。该方案可以防止SA依赖于初始化。此外,通过替换softmax函数,消除了SA的非线性。
2、经验表明,UFO-ViT模型具有更快的推理速度和更少的GPU内存需求。对于不同的分辨率,所需的计算资源并没有显著增加。此外,模型中使用的权重与分辨率无关。这对于密集的预测任务(如目标检测和语义分割)是一个有用的特征。大多数密集预测任务需要比预训练阶段更高的分辨率,即基于mlp的结构需要额外的后处理以适应各种分辨率。
本文模型结构如下所示。它混合了卷积层、UFO模块和简单的前馈MLP层。
对于输入 x ∈ R N × C x∈R^{N×C} x∈RN×C,传统SA机制表述如下:
其中A表示注意算子。如果消除了softmax的非线性, σ ( Q K T ) V σ(QK^T)V σ(QKT)V可分解为 O ( N × h + h × N ) O(N × h + h × N) O(N×h+h×N)。本文使用XNorm代替softmax,它允许SA模块首先计算 K T V K^TV KTV。
XNorm的定义如下:
其中 γ γ γ是一个可学习的参数,h是嵌入维数。它是一个常见的l2范数,但它是沿着两个维度应用的: K T V K^TV KTV的空间维度和q的通道维度。因此,它被称为交叉归一化。
使用结合律,键和值首先相乘,然后查询相乘。下图描述了这一点。这两个乘法运算的复杂度都是 O ( h N d ) O(hNd) O(hNd),所以这个过程对N是线性的。
在XNorm中,自注意力的键和值直接相乘。通过线性核的方法生成h个聚类:
XNorm应用于查询和输出。
其中x表示输入。最后,投影权重使用加权和缩放和聚集点积项。
在这个公式中,关系特征是由嵌入块和簇之间的余弦相似度定义的。XNorm将查询和聚类中的每个像素的特征限制为单位向量。这可以防止它们的值通过将它们正则化为有限的长度来抑制关系属性。如果它们具有任意值,则注意区域依赖于初始化。
残差连接,任意一个模块的输出公式如下:
其中n和x分别表示当前层和输入图像的索引。假设x为某物体的位移,n为时间,则上式可以重新定义为:
大多数神经网络是离散的,因此∆t是常数。(为简便起见,设∆t = 1。)残差项表示速度,所以当粒子有单位质量且∆t = 1时,这一项表示权重项。
在物理学中,胡克定律被定义为弹性向量k和位移向量x的点积。弹性力产生谐波势U,是x2的函数。物理上,势能会干扰粒子运动的路径。(想象一个球在抛物线轨道上运动。)
以上公式一般用于近似分子在x≈0处的势能。对于多个分子,可以利用弹性线性度:
XNorm不是规范化,而是约束。这就是为什么被称为单位权重操作,或简称为UFO。
大多数其他归一化方法都无法训练。有趣的是,单一l2范数的应用也表现出较差的性能。所有结果见表3。
import numpy as np
import torch
from torch import nn
from torch.functional import norm
from torch.nn import init
def XNorm(x,gamma):
norm_tensor=torch.norm(x,2,-1,True)
return x*gamma/norm_tensor
class UFOAttention(nn.Module):
'''
Scaled dot-product attention
'''
def __init__(self, d_model, d_k, d_v, h,dropout=.1):
'''
:param d_model: Output dimensionality of the model
:param d_k: Dimensionality of queries and keys
:param d_v: Dimensionality of values
:param h: Number of heads
'''
super(UFOAttention, self).__init__()
self.fc_q = nn.Linear(d_model, h * d_k)
self.fc_k = nn.Linear(d_model, h * d_k)
self.fc_v = nn.Linear(d_model, h * d_v)
self.fc_o = nn.Linear(h * d_v, d_model)
self.dropout=nn.Dropout(dropout)
self.gamma=nn.Parameter(torch.randn((1,h,1,1)))
self.d_model = d_model
self.d_k = d_k
self.d_v = d_v
self.h = h
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, queries, keys, values):
b_s, nq = queries.shape[:2]
nk = keys.shape[1]
q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
kv=torch.matmul(k, v) #bs,h,c,c
kv_norm=XNorm(kv,self.gamma) #bs,h,c,c
q_norm=XNorm(q,self.gamma) #bs,h,n,c
out=torch.matmul(q_norm,kv_norm).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)
out = self.fc_o(out) # (b_s, nq, d_model)
return out
if __name__ == '__main__':
input=torch.randn(50,49,512)
ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
output=ufo(input,input,input)
print(output.shape)
`
``