阅读笔记 - The Devil in Linear Transformer

来源:https://www.researchgate.net/publication/364419868_The_Devil_in_Linear_Transformer
代码:https://github.com/OpenNLPLab/Transnormer


这篇文章的目的是优化线性transformer,线性transformer相对于标准transformer能够将计算复杂度从 降到. 但线性transformer 相对于标准transformer 往往存在着较明显的指标gap。作者分析认为原因有两点:

  • unbounded gradients。无边界梯度,会导致模型在训练时不稳定,收敛不好;
  • attention dilution。注意力稀释,transformer在lower level时应该更关注局部特征,而higher level更关注全局特征,但线性transformer中的attention往往weight 更均匀化,不能聚焦在local区域上,因此称为attention稀释。
    针对于上述两点,作者提出了NormAttention和DiagAttention两个模块,形成NormFormer的结构。

1.The devil in linear attention

我们首先来看一下作者分析的线性transformer存在的两点缺陷的结论是怎么来的。

1.1 Unbounded gradients

在标准的attention结构中

正是这里的 带来的的计算复杂度。而为了解决这个问题目前主要包含两类: 基于pattern的方法和基于kernel的方法。
基于pattern的方式主要是通过一些先验筛选key或query,降低计算复杂度;而基于kernel的方法则是本文提到的线性transformer,通过核函数去取代softmax,从而能够通过矩阵乘法结合律降低计算复杂度。
那么来看一下计算attention时,vanilla和linear transformer的统一形式:

对于vanilla transformer而言, , 对于linear transformer可以表示为 . 于是可以比较一下两者的梯度:
vanilla attention: , 这里推理的时候注意凑
f'(x) = \text{exp}(x) = f(x) \\ \frac{\partial p_{ij}}{\partial s_{ik}} = 1_{j=k}p_{ij} - p_{ij}p_{ik} \\ = \begin{cases} p_{ik} - p_{ij}p_{ik}\in [0, 1/4], &j=k \\ - p_{ij}p_{ik}\in [-1/4, 0],& j\neq k\end{cases}
这里推理的时候只有 时边界值成立,所以最终

linear attention: 线性attention的关键在于, 因此
f'(x) =1 \\ \frac{\partial p_{ij}}{\partial s_{ik}} = \frac{1}{s_{ik}} \big(1_{j=k}p_{ij} - p_{ij}p_{ik}\big) \\ = \frac{1}{s_{ik}}\begin{cases} p_{ik} - p_{ij}p_{ik}, &j=k \\ - p_{ij}p_{ik},& j\neq k\end{cases} 即,.
因为 大小是不确定的,所以相当于linear attention的梯度是无边界的。这就会导致收敛不稳定,收敛难度大等问题。

1.2 Attention dilution

注意力稀释方面,作者直接评估了不同level上,每一个query在邻域内的其他query上的attention的权重占比,这里需要注意的是,query之间是有序的,即对于NLP或者featmap而言,是有固定结构的,才可以这么评估。表示第i个query在其个邻域query上的attention之和,可以看下图,a图中transformer和linear transformer相比,显然linear transformer的聚集度要小很多。这就是所谓的注意力稀释。

image.png

2. architecture

针对于1中的两个问题,有针对性的设计了两个模块。

2.1 NormAttention.

作者提出的解决方案
,
这里的XNorm 可以是Layernorm,也可以是 RMSNorm。注意这里的Q,和K是有激活函数的,公式没写,但图中画了。

文章证明这个做法梯度是有上界的。附录的证明过程有点复杂。

2.2 DiagAttention

这个模块其实就是一种基于pattern的attention,将query按距离划分不重叠的window,每个window内进行 attention的计算。奇怪的是 这里的attention使用的都是vanilla attention。

下图是文章方法TransNormer的结构:


image.png

3. 实验

实验都是在NLP上做的,不大了解,因此不做分析,这里只看下消融实验的结论。

image.png

table8. 表明早期的stage应当更关注局部特征,而后期的stage则应该更关注全局信息。
table9. 早期适合使用blockattn,后期适合使用normattn
table10. FFN中作者对比了FFN和GLU的结果,发现GLU效果会更好一些。
image.png

table11.表明diagattn中的window的大小,这个其实有有点说不通,如果DiagAttn使用的linear attention, block size越大不是attention 稀释的越严重吗? 这个地方DiagAttn使用的应该都是vanilla attention,包括softmax attention和ReLA attention.

4. 结论

本文提出的norm attention其实在很多其他方法中都见过,而且所谓的diag attention使用的还是vanilla attention,并没有把linear attention应用到diag block里,感觉不是很充实。值得学习的是本文中提出的梯度分析的方法。

你可能感兴趣的:(阅读笔记 - The Devil in Linear Transformer)