Attention Free Transformer(AFT)

Attention Free Transformer(AFT)

paper: An Attention Free Transformer

date: 2021-05

org: Apple

1 Motivation

原本基于dot product self attention Transformer的时间复杂度和空间复杂度都很高。提出了一个新的AFT层来降低transformer的计算量。

Attention Free Transformer(AFT)_第1张图片

2 Method

2.1 Multi-Head Attention回顾

首先回顾一下经典的Multi-Head Attention(MHA),每一个head的计算如下

f i ( X ) = σ ( Q i ( K i ) T d k ) V i ,   s . t .   Q i = X W i Q , K i = X W i K , V i = X W i V , (1) f _ { i } ( X ) = \sigma ( \frac { Q _ { i } ( K _ { i } ) ^ { T } } { \sqrt { d _ { k } } } ) V _ { i } , \ \mathrm { s . t . } \ Q _ { i } = X W _ { i } ^ { Q } , K _ { i } = X W _ { i } ^ { K } , V _ { i } = X W _ { i } ^ { V } , \tag{1} fi(X)=σ(dk Qi(Ki)T)Vi, s.t. Qi=XWiQ,Ki=XWiK,Vi=XWiV,(1)

其中: W i Q    ∈    R d × d k , W i K    ∈    R d × d k , W i V    ∈    R d × d υ W _ { i } ^ { Q } \; \in \; R ^ { d \times d _ { k } } , W _ { i } ^ { K } \; \in \; R ^ { d \times d _ { k } } , W _ { i } ^ { V } \; \in \; R ^ { d \times d _ { \upsilon } } WiQRd×dk,WiKRd×dk,WiVRd×dυ σ \sigma σ是非线性函数,默认为 s o f t m a x softmax softmax。通常情况下 d v = d k , h = d d k d_v = d_k, h = \frac{d}{d_k} dv=dk,h=dkd。假定输入 X ∈ R T × d X \in \mathbb {R}^ {T \times d} XRT×d, 经过 f i f_i fi转化后的输出 f i ( X ) ∈ R T × d v f_i{(X)} \in \mathbb{R} ^{T \times d_v} fi(X)RT×dv。将所有head的结果拼接起来得到最后的输出 R T × d \mathbb{R} ^{T \times d} RT×d

单头Attention的时间复杂度计算:

  • Q K V QKV QKV 的计算,此处有3个矩阵乘法,计算量为 d × d k × T × 3 d \times d_k \times T \times 3 d×dk×T×3, 时间复杂度为: O ( 1 h T d 2 ) \mathcal{O}(\frac{1}{h}Td^2) O(h1Td2)
  • Q K T QK^T QKT的计算,计算量为: d k × T × T d_k \times T \times T dk×T×T, 时间复杂度为: O ( 1 h T 2 d ) \mathcal{O}(\frac{1}{h}T^2d) O(h1T2d)
  • scale 的计算量为: T × T T \times T T×T, 时间复杂度为: O ( T 2 ) \mathcal{O}(T^2) O(T2)
  • softmax的计算量为: T × T T \times T T×T, 时间复杂度为: O ( T 2 ) \mathcal{O}(T^2) O(T2)
  • 最后加权乘法计算量为 d k × T × T d_k \times T \times T dk×T×T,时间复杂度为: O ( 1 h T 2 d ) \mathcal{O}(\frac{1}{h}T^2d) O(h1T2d)

对于MHA,时间复杂度为 O ( T d 2 ) \mathcal{O}(Td^2) O(Td2)

2.2 Attention Free Transofrmer(AFT)

2.2.1 AFT full

第一步和MHA一样,输入 X X X经过三个linear transfer得到 Q K V QKV QKV,3个矩阵, 维度为 R T × d \mathbb{R}^{T \times d} RT×d。AFT引入了一个新的可训练参数矩阵 w ∈ R T × T w \in \mathbb{R}^{T \times T} wRT×T,论文将其称之为可学习的一对一位置偏置(learned pair-wise position biases)。

Attention Free Transformer(AFT)_第2张图片

我们以 y t y_t yt 为视角看每一步的具体流程。

SETP1: w e i g h t e d ( K ( t ) ) \mathrm{weighted}(K^{(t)}) weighted(K(t))。从 w w w t = t t=t t=t的向量, 和 K K K做点乘后以列方向进行 s o f t m a x \mathrm{softmax} softmax。该步骤的计算复杂度为 O ( T × d ) \mathcal{O}(T \times d) O(T×d)

W e i g h t e d ( K ( t ) ) = exp ⁡ ( K + w t ) ∑ i = 1 T exp ⁡ ( k i + w t i ) (2) \mathrm{Weighted}(K^{(t)}) = \frac{\exp (K + w_t ) }{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \tag{2} Weighted(K(t))=i=1Texp(ki+wti)exp(K+wt)(2)

Attention Free Transformer(AFT)_第3张图片

STEP2: 求 A t t e n t i o n ( t ) \mathrm{Attention}^{(t)} Attention(t)矩阵。将q_t用sigmoid变换后,点乘wighted(K)。该步骤的计算复杂度为 O ( T × d ) \mathcal{O}(T \times d) O(T×d)

A t t e n t i o n ( t ) = σ ( q t ) ⊙ W e i g h t e d ( K ( t ) ) = σ ( q t ) ⊙ exp ⁡ ( K + w t ) ∑ i = 1 T exp ⁡ ( k i + w t i ) (3) \mathrm{Attention^{(t)}} = \sigma(q_t) \odot \mathrm{Weighted}(K^{(t)})= \frac{\sigma(q_t) \odot \exp (K + w_t ) }{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \tag{3} Attention(t)=σ(qt)Weighted(K(t))=i=1Texp(ki+wti)σ(qt)exp(K+wt)(3)
Attention Free Transformer(AFT)_第4张图片

STEP3: 计算 y t y_t yt。该步骤的计算复杂度为 O ( T × d ) \mathcal{O}(T \times d) O(T×d)

y t = ∑ i = 1 T ( A t t e n t i o n ( t ) i ⊙ v i ) = ∑ i = 1 T σ ( q t ) ⊙ exp ⁡ ( k i + w t ) ∑ i = 1 T exp ⁡ ( k i + w t i ) ⊙ v i (4) y_t = \sum_{i=1}^{T}(\mathrm{Attention^{(t)}}_i \odot v_i) = \sum_{i=1}^{T} \frac{\sigma(q_t) \odot \exp (k_i + w_t ) }{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \odot v_i \tag{4} yt=i=1T(Attention(t)ivi)=i=1Ti=1Texp(ki+wti)σ(qt)exp(ki+wt)vi(4)

Attention Free Transformer(AFT)_第5张图片

对式(4)稍做变形,可得论文中的计算公式

y t = σ ( q t ) ⊙ ∑ i = 1 T exp ⁡ ( k i + w t ) ⊙ v i ∑ i = 1 T exp ⁡ ( k i + w t i ) (5) y_t = \sigma(q_t)\odot \frac{ \sum_{i=1}^{T}\exp (k_i + w_t ) \odot v_i}{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \tag{5} yt=σ(qt)i=1Texp(ki+wti)i=1Texp(ki+wt)vi(5)

将所有的步骤串起来的流程如下。可以看到AFT其实也用到了attention的思想。但AFT中的Attention Score的计算并没有用到矩阵乘法,只用到了向量点乘。虽整体的计算复杂度仍然是 O ( T 2 d ) \mathcal{O}(T^2d) O(T2d),但计算量已有所下降。

式(4)计算pipeline

Attention Free Transformer(AFT)_第6张图片

式(5)计算pipeline

Attention Free Transformer(AFT)_第7张图片

2.2.1 AFT local

在许多情况下,局部性是一个很重要的归纳偏置(inductive bias),而标准的Transformer的计算中没有引入局部信息。因此,作者提出AFT-local。其形式与AFT-Full一致。区别在于,引入了下式限制

w t , t ′ = { w t , t ′ , i f ∣ t − t ′ ∣ < s 0 , o t h e r w i s e . (6) w_{t, t'} = \begin{cases} w_{t, t'}, \quad \mathrm{if} |t - t'| < s \\ 0, \quad \mathrm{otherwise.}\end{cases} \tag{6} wt,t={wt,t,iftt<s0,otherwise.(6)

式中的 s s s就是定义的局部窗口大小(local window size)。它进一步降低了计算量。变换后的 w w w如下图所示(此时 s = 2 s=2 s=2, 黑色方块为0)。

Attention Free Transformer(AFT)_第8张图片

2.2.2 AFT simple

AFT simple是AFT local当 s = 0 s = 0 s=0时的特殊形式。此时没有位置偏置。可将式5化简为,因为对不同的 t t t ∑ i = 1 T ( s o f t m a x ( K ) ⊙ V ) i \sum_{i=1}^{T} (\mathrm{softmax}(K) \odot V)_{i} i=1T(softmax(K)V)i都是相同的。AFT simple的时间复杂度为 O ( T d ) \mathcal{O}(Td) O(Td)

y t = σ ( q t ) ⊙ ∑ i = 1 T exp ⁡ ( k i ) ⊙ v i ∑ i = 1 T exp ⁡ ( k i ) = σ ( q t ) ⊙ ∑ i = 1 T ( s o f t m a x ( K ) ⊙ V ) i (6) y_t = \sigma(q_t)\odot \frac{ \sum_{i=1}^{T}\exp (k_i) \odot v_i}{\sum_{i=1}^{T} \exp (k_i) } = \sigma(q_t)\odot \sum_{i=1}^{T} (\mathrm{softmax}(K) \odot V)_{i}\tag{6} yt=σ(qt)i=1Texp(ki)i=1Texp(ki)vi=σ(qt)i=1T(softmax(K)V)i(6)

2.2.3 AFT conv

作者进一步将局部性的思想扩展到空间权重共享(如卷积),提出AFT-conv。具体来说,让 w t , t ′ w_{t,t'} wt,t的值仅依赖 t t t t ′ t' t的相对位置。为了考虑参数数量随着 h e a d head head数增加而增长的情况,作者采用了一个设计选择,将 K K K的维度与head数绑定在一起(MHA的思路)。这使得AFT-conv可以采用深度可分离卷积、全局池化和element-wise操作的实现方式。

可以看到与AFT simple相比,AFT conv引入了head思想,并通过1维卷积的计算结果引入局部信息。其形式与式(6)相比分子分母中新增了 c o n v 1 d ( exp ⁡ ( K j ) ⊙ V j ,    exp ⁡ ( w j )   − 1 ) \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) \odot V ^ { j } , \; \exp ( w ^ { j } ) \, - 1 ) conv1d(exp(Kj)Vj,exp(wj)1) c o n v 1 d ( exp ⁡ ( K j ) ,    exp ⁡ ( w j )    − 1 ) \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) , \; \exp ( w ^ { j } ) \; - 1 ) conv1d(exp(Kj),exp(wj)1)。(上标 j j j表示第 j j j个head)。此时的 w w w为conv1d的filter。

y t j = σ q ( q t j ) ⊙ c o n v 1 d ( exp ⁡ ( K j ) ⊙ V j ,    exp ⁡ ( w j )   − 1 ) + ∑ i = 1 T exp ⁡ ( k i j ) ⊙ v i j c o n v 1 d ( exp ⁡ ( K j ) ,    exp ⁡ ( w j )    − 1 ) + ∑ i = 1 T exp ⁡ ( k i j ) (7) y _ { t } ^ { j } = \sigma _ { q } ( q _ { t } ^ { j } ) \odot \frac { \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) \odot V ^ { j } , \; \exp ( w ^ { j } ) \, - 1 ) + \sum _ { i = 1 } ^ { T } \exp ( k _ { i } ^ { j } ) \odot v _ { i } ^ { j } } { \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) , \; \exp ( w ^ { j } ) \; - 1 ) + \sum _ { i = 1 } ^ { T } \exp ( k _ {i } ^ { j } ) } \tag{7} ytj=σq(qtj)conv1d(exp(Kj),exp(wj)1)+i=1Texp(kij)conv1d(exp(Kj)Vj,exp(wj)1)+i=1Texp(kij)vij(7)

从ViT可视化attention map中可以看出(横轴为head, 纵轴为layer)。原本的ViT(左边)的不同层,head的attention map的响应最大区域基本都是中心区域。而用了AFT-conv后,不同层、head的attention都有所不同,有助于模型捕获不同尺度的特征。

Attention Free Transformer(AFT)_第9张图片

3 小结

本文提出了一种Dot Product Attention Free的Transformer,最多能将transofmer的时间复杂度从 O ( T 2 d ) \mathcal{O}(T^2d) O(T2d)降低到 O ( T d ) \mathcal{O}(Td) O(Td)(AFT-simple)。

你可能感兴趣的:(transformer,论文学习,transformer,深度学习,人工智能)