多头注意力机制(Multi-head Attention)及其在PyTorch中的使用方法分析

内容目录

  • 简介
  • 多头注意力机制
    • 注意力机制的直观理解
    • 注意力机制具体是怎么做的
  • PyTorch中的类
    • 伪代码重述
    • 具体怎么用

简介

多头注意力(Multi-head Attention)机制是当前大行其道的Transformer、BERT等模型中的核心组件,但我一直没懂其内部到底是怎么做的,PyTorch提供的接口的众多参数也弄不清有什么用。今天抽个时间,结合论文和PyTorch源码,深入学习一下。

仅为个人理解,如有错误敬请指正!

多头注意力机制

PyTorch中的Multi-head Attention可以表示为:

MultiheadAttention ( Q , K , V ) = Concat ( head 1 , ⋯   , head h ) W O \text{MultiheadAttention}(Q, K, V) = \text{Concat}(\text{head}_1, \cdots, \text{head}_h)W^O MultiheadAttention(Q,K,V)=Concat(head1,,headh)WO

其中

head i = Attention ( Q , K , V ) \text{head}_i=\text{Attention}(Q, K, V) headi=Attention(Q,K,V)

也就是说:Attention的每个头的运算,是对于输入的三个东西 Q , K , V Q, K, V Q,K,V进行一些运算;多头就是把每个头的输出拼起来,然后乘以一个矩阵 W O W^O WO进行线性变换,得到最终的输出。

注意力机制的直观理解

以最常见的文本编码来举例,假如我们要对下面这句话进行建模:

我买了一个西瓜,它很甜。

对于我们人来说,在阅读一段话时,我们会有自己关注的重点,这些重点往往包含更多更有用的信息,而其他的一些文字则相对来说没有那么重要。例如对于上面这个句子,假如我们关注买的是什么东西、它好不好吃,那么我们的重点就会关注于“西瓜”和“甜”的文字上。我们也希望编码器有这样的能力,将这段话输入到编码器中后,希望得到每个词的上下文语义的向量表示,这个向量也能够像我们人一样,更多地包含重要的信息,也就是将重要信息的权重增大,不重要的信息的权重减小,这就是Attention机制做的事情。

注意力机制具体是怎么做的

上文公式中, Q , K , V Q, K, V Q,K,V分别表示query、key和value。key把query作为依据,经过计算,得到对于每个key的attention权重,这些权重加起来等于1,最终结果为利用attention权重,对value进行加权求和。

Q , K , V Q, K, V Q,K,V三者再各应用中不一定相同,但是对于Transformer中的Self-Attention,三者取值都一样,就是取输入文本的embedding表示。

具体是怎么计算的,可以参考这篇文章的“1.3 Self-Attention”章节,说得很清楚。

PyTorch中的类

PyTorch近些版本已经提供了MultiheadAttention的类,可以直接调用,它的API请自行查看这里。但是它们的文档并没有写得很详细,导致具体参数不知道该怎样用。

伪代码重述

我分析了一下源码(基于PyTorch 1.6.0,其它版本可能大同小异),下面用不规范但方便理解的伪代码重述其具体实现,抽丝剥茧,可能个别细节跟源代码不一样,但计算过程和表述的内容是一样的。看完下面这些内容,就应该能对各参数的作用有大致的理解。

类的构造函数的参数:

  • embed_dim:int,即伪代码中的emb,Attention的输入内容的维度,也是内部各矩阵的纬度。
  • num_heads:int,头的数量。
  • dropout:float Optional,Dropout的比例,默认为0.0表示没有神经元单元会被置0。
  • bias:bool Optional,决定在对输入进行线性变换时是否加bias,默认是True。
  • add_bias_kv:bool Optional,如果设为True,则会在内部运算中在k、v的序列长度纬度增加一列随机向量,具体可以看下面的伪代码,默认为False。
  • add_zero_attn:bool Optional,则会在内部运算中在k、v的序列长度纬度增加一列零向量,具体可以看下面的伪代码,默认为False。
  • k_dim:int Optional,即伪代码中的emb_k,单独设置key的维度,默认为None时,取embed_dim。
  • v_dim:int Optional,即伪代码中的emb_v,单独设置value的维度,默认为None时,取embed_dim。

forward函数的参数:

  • query: torch.Tensor, (tgt_len, batch, emb)
  • key: torch.Tensor, (src_len, batch, emb_k),当k_dim为默认None时emb_k == emb
  • value: torch.Tensor, (src_len, batch, emb_v),当v_dim为默认None时emb_v == emb
  • key_padding_mask, (batch, src_len)
  • attn_mask: (tgt_len, src_len) 或者 (batch * num_heads, tgt_len, src_len)

输出:

  • attention_output: torch.Tensor, (tgt_len, batch, emb)
  • attention_output_weights: torch.Tensor, (batch, tgt_len, src_len)
# === 1. 将输入进行线性变换 ===
if not bias:
	b_q = None
	b_k = None
	b_v = None
q = linear(key, W_q, b_q)  # (tgt_len, batch, emb)
# 其中W_q和b_q分别是权重矩阵和bias向量,维度分别为(emb, emb), (emb,),如果bias设为False则没有bias
k = linear(key, W_k, b_k)  # (src_len, batch, emb)
# 其中W_q和b_q分别是权重矩阵和bias向量,维度分别为(emb, emb_k), (emb,),如果bias设为False则没有bias
v = linear(key, W_v, b_v)  # (src_len, batch, emb)
# 其中W_q和b_q分别是权重矩阵和bias向量,维度分别为(emb, emb_v), (emb,),如果bias设为False则没有bias


# === 2. 处理add_bias_kv ===
if add_bias_kv:
    src_len_1 = src_len + 1
    k = torch.cat((k, bb_k), dim=0)  # (src_len_1, batch, emb)
    # 其中bb_k是bias向量,维度为(emb,),在concat的时候repeat成(1, batch, emb)再拼接
    v = torch.cat((v, bb_v), dim=0)  # (src_len_1, batch, emb)
    # 其中bb_v是bias向量,维度为(emb,),在concat的时候repeat成(1, batch, emb)再拼接

    if attn_mask is not None:
        attn_mask = pad(attn_mask)  # 用0在最末维度的末尾补充一个数,变成(..., tgt_len, src_len_1)
    if key_padding_mask is not None:
        key_padding_mask = pad(key_padding_mask)  # 用0在最末维度的末尾补充一个数,变成(batch, src_len_1)

else:
    src_len_1 = src_len


# === 3. 处理add_zero_attn ===
if add_zero_attn:
    src_len_2 = src_len_1 + 1
    k = torch.cat((k, z_k), dim=0)  # (src_len_2, batch, emb)
    # 其中z_k是零向量,纬度为(1, batch, emb)
    v = torch.cat((v, z_v), dim=0)  # (src_len_2, batch, emb)
    # 其中z_v是零向量,纬度为(1, batch, emb)

    if attn_mask is not None:
        attn_mask = pad(attn_mask)  # 用0在最末维度的末尾补充一个数,变成(..., tgt_len, src_len_2)
    if key_padding_mask is not None:
        key_padding_mask = pad(key_padding_mask)  # 用0在最末维度的末尾补充一个数,变成(batch, src_len_2)

else:
    src_len_2 = src_len_1


# === 4. q、k相乘 ===
q = q.view(tgt_len, batch * num_heads, head_dim).transpose(0, 1)  # (batch * num_heads, tgt_len, head_dim)
# 将q转换形状,方便后续操作,其中head_dim * num_heads == emb
k = k.view(src_len_2, batch * num_heads, head_dim).transpose(0, 1)  # (batch * num_heads, src_len_2, head_dim)
# 将v转换形状,方便后续操作,其中head_dim * num_heads == emb
attn_output_weights = torch.bmm(q, k.transpose(1, 2))  # (batch * num_heads, tgt_len, src_len_2)
attn_output_weights = attn_output_weights * (head_dim ** -0.5)
# 将Attention值缩放,除以每一个头的维度的平方根


# === 5. 应用attn_mask ===
attn_mask = transform(attn_mask)  # 用unsqueeze、repeat等方式,将输入的attn_mask转换为(batch * num_heads, tgt_len, src_len_2)
if attn_mask.dtype == torch.bool:
    attn_output_weights.masked_fill_(attn_mask, float('-inf'))
else:
    attn_output_weights += attn_mask


# === 6. 应用key_padding_mask ===
attn_output_weights = attn_output_weights.view(batch, num_heads, tgt_len, src_len_2)  # (batch, num_heads, tgt_len, src_len_2)
key_padding_mask = transform(key_padding_mask)  # 用unsqueeze、repeat等方式,将输入的key_padding_mask转换为(batch, num_heads, tgt_len, src_len_2)
attn_output_weights = attn_output_weights.masked_fill(key_padding_mask, float('-inf'))
attn_output_weights = attn_output_weights.view(batch * num_heads, tgt_len, src_len_2)  # (batch * num_heads, tgt_len, src_len_2)


# === 7. Softmax ===
attn_output_weights = softmax(attn_output_weights, dim=-1)  # (batch * num_heads, tgt_len, src_len_2)


# === 8. Dropout ===
attn_output_weights = dropout(attn_output_weights)  # (batch * num_heads, tgt_len, src_len_2)


# === 9. 乘以v ===
v = v.view(src_len_2, batch * num_heads, head_dim).tranpose(0, 1)  # (batch * num_heads, src_len, head_dim)
attn_output = torch.bmm(attn_output_weights, v)  # (batch * num_heads, tgt_len, head_dim)


# === 10. 计算输出 ===
attn_output = attn_output.transpose(0, 1).view(tgt_len, batch, embed)  # (tgt_len, batch, embed)
attn_output = linear(attn_output, W_o, b_o)  # (tgt_len, batch, embed)
# 其中W_o和b_o分别是权重矩阵和bias向量,维度分别为(emb, emb), (emb,)


# === 11. 输出 ===
attn_output_weights = attn_output_weights.view(batch, num_heads, tgt_len, src_len_2)  # (batch, num_heads, tgt_len, src_len_2)
attn_output_weights = attn_output_weights.sum(dim=1) / num_heads  # (batch, tgt_len, src_len_2)
return attn_output, attn_output_weights

具体怎么用

看完源代码,说说我对实际场景下,PyTorch的API到底该怎么用。(基于PyTorch 1.6.0,其他版本可能大同小异)

在实例化MultiheadAttention时:

  • embed_dim:query的维度,也作为Attention内部的纬度,key和value输入后都会转化为该维度。
  • num_heads:多头注意力机制的头数,要能整除embed_dim。
  • dropout:在计算得到attention权重后,会施以该dropout。
  • bias:模型参数加bias,默认True就行。
  • add_bias_kv:看上面的伪代码可以知道它的用处,但我不知道这样做的目的是什么,请求大神解释。
  • add_zero_attn:看上面的伪代码可以知道它的用处,但我不知道这样做的目的是什么,请求大神解释。
  • kdim:输入的key的维度,默认为None时取embed_dim。
  • vdim:输入的value的维度,默认为None时取embed_dim。

在使用时:

  • query:query向量,必须输入。
  • key:key向量,必须输入。
  • value:value向量,必须输入。
  • kay_padding_mask:用于mask掉序列中pad的位置,维度是(batch_size, src_len),建议输入的dtype为torch.bool,需要注意的是要mask掉的部分(即pad对应的位置)要设为True,否则设为False,跟往常相反
  • attn_mask:用于限制attention中每个位置能看到的内容,比如对于Transformer Decoder中,每一步只能看到它以及它前面的内容,那么这个矩阵就应该为一个下三角矩阵,对角线即对角线下方为False,其他部分为True(表示mask掉)。换一种说法描述:第 i i i行第 j j j列为True,表示在计算关于第 i i i个元素的Attention权重时,看不到第 j j j个元素的内容,即它在计算Attention权重时不被考虑。
  • need_weights:默认为True,返回的结果为(attn_output, attn_output_weights),其中attn_output_weights为多个head的平均权重。如果设为False,则返回的结果为(attn_output, None)。

完。

你可能感兴趣的:(Pytorch,深度学习,深度学习,算法,pytorch)