Multi Query Attention & Group Query Attention

Multi Query Attention(MQA)在2019年就被提出来了,用于推理加速,但在当时并没有受到很多关注,毕竟一张2080就能跑Bert-base了。随着LLM的大火,MQA所带来的收益得以放大。

思路

Multi Query Attention(MQA)跟Multi Head Attention(MHA)只有一词之差,但其思路非常简单,几乎跟MHA一致:

Multi Query Attention & Group Query Attention_第1张图片

MHA的Query、Key、Value分拆成8个头,每个头进行self-attention运算,而MQA是Query分成8个头,每个头共享一组Key和Value

MHA: Q, K, V = (512, 768), # seq_len, hidden_dim
			拆成8个头:
			Q : (8, 512, 96) 
			k, v: (8, 512, 96)
MQA: 
 Q -> (512, 768) 
 K -> (512, 96)
 v -> (512, 96)
把Q拆成8个头:
Q: (8, 512, 96)
K, V:(512, 96)

代码实现

  • MHA
...
self.Wqkv = nn.Linear( 
            d_model,
            d_model * 3,
            device=device,
        )
...

d_model * 3 拆成3个768维

  • MQA
...
self.Wqkv = nn.Linear( 
            d_model,
            d_model + 2 * self.head_dim,
            device=device,
        )
...

d_model + 2 * self.head_dim 拆成1个768维 + 2个96维

可以看到参数数量大幅减少。

实验结果

实验指标略微降低,但推理加速非常明显。

Multi Query Attention & Group Query Attention_第2张图片

Group Query Attention

Q拆分成8个头,K和V分别拆成4个头,然后对应进行attention运算。


参考

  • Fast Transformer Decoding: One Write-Head is All
    You Need
  • [LLM] multi query attention加速推理解码

你可能感兴趣的:(nlp,MQA)