Transformer 中 Self-attention 的计算复杂度

在 Transformer 中,Multi-head attention 的计算过程是: MultiHeadAttn ( z q , x ) = ∑ m = 1 M W m [ ∑ k ∈ Ω k A m q k ⋅ W m ′ x k ] \text{MultiHeadAttn}(z_q, \mathbb{x}) = \sum_{m=1}^M W_m[\sum_{k\in \Omega_k} A_{mqk} \cdot {W'_m} \mathbb{x}_k] MultiHeadAttn(zq,x)=m=1MWm[kΩkAmqkWmxk].

其中 m m m是 attention head 的索引, W m ′ ∈ R C v × C {W_m}'\in \mathbb{R}^{C_v\times C} WmRCv×C 是输入的映射矩阵, W m ∈ R C × C v {W_m}\in \mathbb{R}^{C\times C_v} WmRC×Cv 是输出的映射矩阵,二者都是可学习的权重( C v = C / M C_v = C/M Cv=C/M)。Attention 权重 A m q k ∝ exp ⁡ { z q T U m T V m x k C v } A_{mqk}\propto \exp\lbrace \frac{z_q^T U_m^T V_m x_k}{\sqrt{C_v}} \rbrace Amqkexp{Cv zqTUmTVmxk},并且 ∑ k ∈ Ω k A m q k = 1 \sum_{k\in \Omega_{k}} A_{mqk}=1 kΩkAmqk=1,其中 U m , V m ∈ R C v × C U_m,V_m \in \mathbb{R}^{C_v\times C} Um,VmRCv×C分别是 query 的映射矩阵和 key 的映射矩阵,也都是可学习权重。设 query 和 key 元素的个数分别是 N q N_q Nq N k N_k Nk. MultiHeadAttn ( z q , x ) \text{MultiHeadAttn}(z_q, \mathbb{x}) MultiHeadAttn(zq,x)的计算复杂度是 O ( N q C 2 + N k C 2 + N q N k C ) O(N_q C^2 + N_k C^2 + N_q N_k C) O(NqC2+NkC2+NqNkC)

  1. 输入是 X ∈ R N × C X\in \mathbb{R}^{N\times C} XRN×C,用 U m , V m ∈ R C v × C U_m,V_m \in \mathbb{R}^{C_v\times C} Um,VmRCv×C分别对 query 和 key 做线性变换,计算得到 Q , K ∈ R N × C Q,K\in \mathbb{R}^{N\times C} Q,KRN×C矩阵。这样,计算 Q Q Q K K K的复杂度就是 O ( N q × C 2 ) O(N_q\times C^2) O(Nq×C2) O ( N k × C 2 ) O(N_k\times C^2) O(Nk×C2).
  2. 然后计算 A m q k ∝ exp ⁡ { z q T U m T V m x k C v } A_{mqk}\propto \exp\lbrace \frac{z_q^T U_m^T V_m x_k}{\sqrt{C_v}} \rbrace Amqkexp{Cv zqTUmTVmxk},复杂度是 O ( N q × N k × C ) O(N_q \times N_k \times C) O(Nq×Nk×C).
  3. A m q k A_{mqk} Amqk x k x_k xk相乘,计算复杂度是 O ( N q × N k × C ) O(N_q \times N_k \times C) O(Nq×Nk×C).
  4. 总体的计算复杂度就是 O ( N q × C 2 + N k × C 2 + N q N k C ) O(N_q\times C^2 + N_k\times C^2 + N_q N_k C) O(Nq×C2+Nk×C2+NqNkC).

在 DETR 中,Transformer encoder 的 query 和 key 元素就是特征图上的像素点,假设输入特征图的宽度和高度分别是 W W W H H H

  1. Encoder 中的 self-attention 的计算复杂度就是 O ( H 2 W 2 C ) O(H^2W^2C) O(H2W2C).
  2. Decoder 包括了 self attention 和 cross attention,输入包括来自于 encoder 的特征图、 N N N个 object queries。
  3. 在 decoder 的 cross attention 中,query 元素来自于 object queries,key 元素来自于 encoder 特征图,从 encoder 提供的特征图上提取 key 元素, N q = N , N k = H × W N_q=N, N_k=H\times W Nq=N,Nk=H×W,计算复杂度是 O ( N k C 2 + N N k C ) = O ( H W C 2 + N H W C 2 ) O(N_kC^2+NN_kC)=O(HWC^2+NHWC^2) O(NkC2+NNkC)=O(HWC2+NHWC2).
  4. 在 decoder 的 self attention 中,object queries 相互作用,query 和 key 元素都来自于 object queries。 N q = N k = N N_q=N_k=N Nq=Nk=N,复杂度就是 O ( 2 N C 2 + N 2 C ) O(2NC^2 + N^2C) O(2NC2+N2C).

引用

  • https://stackoverflow.com/questions/65703260/computational-complexity-of-self-attention-in-the-transformer-model
  • Deformable DETR

你可能感兴趣的:(注意力,transformer,算法,机器学习)