多头注意力(Multi-head Attention)机制是当前大行其道的Transformer、BERT等模型中的核心组件,但我一直没懂其内部到底是怎么做的,PyTorch提供的接口的众多参数也弄不清有什么用。今天抽个时间,结合论文和PyTorch源码,深入学习一下。
仅为个人理解,如有错误敬请指正!
PyTorch中的Multi-head Attention可以表示为:
MultiheadAttention ( Q , K , V ) = Concat ( head 1 , ⋯ , head h ) W O \text{MultiheadAttention}(Q, K, V) = \text{Concat}(\text{head}_1, \cdots, \text{head}_h)W^O MultiheadAttention(Q,K,V)=Concat(head1,⋯,headh)WO
其中
head i = Attention ( Q , K , V ) \text{head}_i=\text{Attention}(Q, K, V) headi=Attention(Q,K,V)
也就是说:Attention的每个头的运算,是对于输入的三个东西 Q , K , V Q, K, V Q,K,V进行一些运算;多头就是把每个头的输出拼起来,然后乘以一个矩阵 W O W^O WO进行线性变换,得到最终的输出。
以最常见的文本编码来举例,假如我们要对下面这句话进行建模:
我买了一个西瓜,它很甜。
对于我们人来说,在阅读一段话时,我们会有自己关注的重点,这些重点往往包含更多更有用的信息,而其他的一些文字则相对来说没有那么重要。例如对于上面这个句子,假如我们关注买的是什么东西、它好不好吃,那么我们的重点就会关注于“西瓜”和“甜”的文字上。我们也希望编码器有这样的能力,将这段话输入到编码器中后,希望得到每个词的上下文语义的向量表示,这个向量也能够像我们人一样,更多地包含重要的信息,也就是将重要信息的权重增大,不重要的信息的权重减小,这就是Attention机制做的事情。
上文公式中, Q , K , V Q, K, V Q,K,V分别表示query、key和value。key把query作为依据,经过计算,得到对于每个key的attention权重,这些权重加起来等于1,最终结果为利用attention权重,对value进行加权求和。
Q , K , V Q, K, V Q,K,V三者再各应用中不一定相同,但是对于Transformer中的Self-Attention,三者取值都一样,就是取输入文本的embedding表示。
具体是怎么计算的,可以参考这篇文章的“1.3 Self-Attention”章节,说得很清楚。
PyTorch近些版本已经提供了MultiheadAttention的类,可以直接调用,它的API请自行查看这里。但是它们的文档并没有写得很详细,导致具体参数不知道该怎样用。
我分析了一下源码(基于PyTorch 1.6.0,其它版本可能大同小异),下面用不规范但方便理解的伪代码重述其具体实现,抽丝剥茧,可能个别细节跟源代码不一样,但计算过程和表述的内容是一样的。看完下面这些内容,就应该能对各参数的作用有大致的理解。
类的构造函数的参数:
forward函数的参数:
输出:
# === 1. 将输入进行线性变换 ===
if not bias:
b_q = None
b_k = None
b_v = None
q = linear(key, W_q, b_q) # (tgt_len, batch, emb)
# 其中W_q和b_q分别是权重矩阵和bias向量,维度分别为(emb, emb), (emb,),如果bias设为False则没有bias
k = linear(key, W_k, b_k) # (src_len, batch, emb)
# 其中W_q和b_q分别是权重矩阵和bias向量,维度分别为(emb, emb_k), (emb,),如果bias设为False则没有bias
v = linear(key, W_v, b_v) # (src_len, batch, emb)
# 其中W_q和b_q分别是权重矩阵和bias向量,维度分别为(emb, emb_v), (emb,),如果bias设为False则没有bias
# === 2. 处理add_bias_kv ===
if add_bias_kv:
src_len_1 = src_len + 1
k = torch.cat((k, bb_k), dim=0) # (src_len_1, batch, emb)
# 其中bb_k是bias向量,维度为(emb,),在concat的时候repeat成(1, batch, emb)再拼接
v = torch.cat((v, bb_v), dim=0) # (src_len_1, batch, emb)
# 其中bb_v是bias向量,维度为(emb,),在concat的时候repeat成(1, batch, emb)再拼接
if attn_mask is not None:
attn_mask = pad(attn_mask) # 用0在最末维度的末尾补充一个数,变成(..., tgt_len, src_len_1)
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask) # 用0在最末维度的末尾补充一个数,变成(batch, src_len_1)
else:
src_len_1 = src_len
# === 3. 处理add_zero_attn ===
if add_zero_attn:
src_len_2 = src_len_1 + 1
k = torch.cat((k, z_k), dim=0) # (src_len_2, batch, emb)
# 其中z_k是零向量,纬度为(1, batch, emb)
v = torch.cat((v, z_v), dim=0) # (src_len_2, batch, emb)
# 其中z_v是零向量,纬度为(1, batch, emb)
if attn_mask is not None:
attn_mask = pad(attn_mask) # 用0在最末维度的末尾补充一个数,变成(..., tgt_len, src_len_2)
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask) # 用0在最末维度的末尾补充一个数,变成(batch, src_len_2)
else:
src_len_2 = src_len_1
# === 4. q、k相乘 ===
q = q.view(tgt_len, batch * num_heads, head_dim).transpose(0, 1) # (batch * num_heads, tgt_len, head_dim)
# 将q转换形状,方便后续操作,其中head_dim * num_heads == emb
k = k.view(src_len_2, batch * num_heads, head_dim).transpose(0, 1) # (batch * num_heads, src_len_2, head_dim)
# 将v转换形状,方便后续操作,其中head_dim * num_heads == emb
attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # (batch * num_heads, tgt_len, src_len_2)
attn_output_weights = attn_output_weights * (head_dim ** -0.5)
# 将Attention值缩放,除以每一个头的维度的平方根
# === 5. 应用attn_mask ===
attn_mask = transform(attn_mask) # 用unsqueeze、repeat等方式,将输入的attn_mask转换为(batch * num_heads, tgt_len, src_len_2)
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
else:
attn_output_weights += attn_mask
# === 6. 应用key_padding_mask ===
attn_output_weights = attn_output_weights.view(batch, num_heads, tgt_len, src_len_2) # (batch, num_heads, tgt_len, src_len_2)
key_padding_mask = transform(key_padding_mask) # 用unsqueeze、repeat等方式,将输入的key_padding_mask转换为(batch, num_heads, tgt_len, src_len_2)
attn_output_weights = attn_output_weights.masked_fill(key_padding_mask, float('-inf'))
attn_output_weights = attn_output_weights.view(batch * num_heads, tgt_len, src_len_2) # (batch * num_heads, tgt_len, src_len_2)
# === 7. Softmax ===
attn_output_weights = softmax(attn_output_weights, dim=-1) # (batch * num_heads, tgt_len, src_len_2)
# === 8. Dropout ===
attn_output_weights = dropout(attn_output_weights) # (batch * num_heads, tgt_len, src_len_2)
# === 9. 乘以v ===
v = v.view(src_len_2, batch * num_heads, head_dim).tranpose(0, 1) # (batch * num_heads, src_len, head_dim)
attn_output = torch.bmm(attn_output_weights, v) # (batch * num_heads, tgt_len, head_dim)
# === 10. 计算输出 ===
attn_output = attn_output.transpose(0, 1).view(tgt_len, batch, embed) # (tgt_len, batch, embed)
attn_output = linear(attn_output, W_o, b_o) # (tgt_len, batch, embed)
# 其中W_o和b_o分别是权重矩阵和bias向量,维度分别为(emb, emb), (emb,)
# === 11. 输出 ===
attn_output_weights = attn_output_weights.view(batch, num_heads, tgt_len, src_len_2) # (batch, num_heads, tgt_len, src_len_2)
attn_output_weights = attn_output_weights.sum(dim=1) / num_heads # (batch, tgt_len, src_len_2)
return attn_output, attn_output_weights
看完源代码,说说我对实际场景下,PyTorch的API到底该怎么用。(基于PyTorch 1.6.0,其他版本可能大同小异)
在实例化MultiheadAttention时:
在使用时:
完。