本文侧重于Pytorch中对self-attention的具体实践,具体原理不作大量说明,self-attention的具体结构请参照下图。
(图中为输出第二项attention output的情况,k与q为key、query的缩写)
本文中将使用Pytorch的torch.nn.MultiheadAttention来实现self-attention.
首先,前三个输入是最重要的部分query、key、value。由图1可知,我们self-attention的这三样东西其实是一样的,它们的形状都是:(L,N,E) 1。
L:输入sequence的长度(例如一个句子的长度)
N:批大小(例如一个批的句子个数)
E:词向量长度
输出的内容很少只有两项:
attn_output
即通过self-attention之后,从每一个词语位置输出来的attention。其形状为(L,N,E),是和输入的query它们形状一样的。因为毕竟只是给value乘了一个weight。
attn_output_weights
即attention weights,形状是(N,L,L),因为每一个单词和任意另一个单词之间都会产生一个weight,所以每一句句子的weight数量是L*L
这里对MultiheadAttention进行一个实例化并传入一些参数,实例化之后我们得到的东西我们就可以向它传入input了。
实例化时的代码:
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
其中,embed_dim是每一个单词本来的词向量长度;num_heads是我们MultiheadAttention的head的数量。
pytorch的MultiheadAttention应该使用的是Narrow self-attention机制,即,把embedding分割成num_heads份,每一份分别拿来做一下attention。
也就是说:单词1的第一份、单词2的第一份、单词3的第一份…会当成一个sequence,做一次我们图1所示的self-attention。
然后,单词1的第二份、单词2的第二份、单词3的第二份…也会做一次
直到单词1的第num_heads份、单词2的第num_heads份、单词3的第num_heads份…也做完self-attention
从每一份我们都会得到一个(L,N,E)形状的输出,我们把这些全部concat在一起,会得到一个(L,N,E*num_heads)的张量。
这时候,我们拿一个矩阵,把这个张量的维度变回(L,N,E)即可输出。
我们把我们刚才实例化好的multihead_attn拿来进行forward操作(即输入input得到output):
attn_output, attn_output_weights = multihead_attn(query, key, value)
mask可以理解成遮罩、面具,作用是帮助我们“遮挡”掉我们不需要的东西,即让被遮挡的东西不影响我们的attention过程。
在forward的时候,有两个mask参数可以设置:
key_padding_mask
每一个batch的每一个句子的长度一般是不可能完全相同的,所以我们会使用padding把一些空缺补上。而这里的这个key_padding_mask是用来“遮挡”这些padding的。
这个mask是二元(binary)的,也就是说,它是一个矩阵和我们key的大小是一样的,里面的值是1或0,我们先取得key中有padding的位置,然后把mask里相应位置的数字设置为1,这样attention就会把key相应的部分变为"-inf". (为什么变为-inf我们稍后再说)
attn_mask
这个mask经常是用来遮挡“正确答案”的:
假如你想要用这个模型每次预测下一个单词,我们每一个位置的attention输出是怎么得来的?是不是要看一遍整个序列,然后每一个单词都计算一个attention weight?那也就是说,你在预测第5个词的时候,你其实会看到整个序列,这样的话你在预测之前不就已经知道第5个单词是什么了,这就是作弊了。
我们不想让模型作弊,因为在真实使用这个模型去预测的时候,我们是没有整个序列的信息的。那么怎么办?那就让第5个单词的attention weight=0吧,即声明:我不想看这个单词,我的注意力一点也别分给它。
如何让这个weight=0:
我们先想象一下,我们目前拥有的attention scores是什么样的?(注:attention_score是attention_weight的初始样子,经过softmax之后会变成attention_weight.
attention_score和weight的形状是一样的,毕竟只有一个softmax的差别)
我们之前提到,attention weights的形状是L*L,因为每个单词两两之间都有一个weight。
如下图所示,我用蓝笔圈出的部分,就是“我想要预测x2”时,整个sequence的attention score情况。我用叉划掉的地方,是我们希望=0的位置,因为我们想让x2、x3、x4的权值为0,即:预测x2的时候,我们的注意力只能放在x1上。
对于其他行,你可以以此类推,发现我们需要一个三角形区域的attention weight=0, 这时候我们的attn_mask这时候就出场了,把这个mask做成三角形即可。
关于mask的题外话:有朋友好奇为什么有的地方看到的图mask了对角线有的没有,应该是因为sequence不同或者训练任务/方式不同,但本质上mask的原理是一样的。我再找一张图帮助大家理解,比如如果加上s(start)和e(end)的话就是类似这样:(图源2, 白色为mask掉的部分)
现在我们来说mask的值。和key_padding_mask不同,我们的attn_mask不是binary的,而是一个“additive mask”。
什么是additive mask呢?就是我们mask上设置的值,会被加到我们原本的attention score上。我们要让三角形区域的weight=0,我们这个三角mask设置什么值好呢?答案是-inf,(这个-inf在key_padding_mask的讲解中也出现了,这里就来说说为什么要用-inf)。我们上面提到了,attention score要经过一个softmax才变成attention_weights.
我们都知道softmax的式子可以表示为3:
σ ( z ) j = e z j ∑ k = 1 K e z k \sigma(z)_j = \frac{e^{z_j}}{\sum_{k=1}^{K}e^{z_k}} σ(z)j=∑k=1Kezkezj
(for j = 1,…,K)
当我们attention score的值设置为-inf (可以看作这里式子里的 z j = − inf z_j=-\inf zj=−inf),于是通过softmax之后我们的attention weight就会趋近于0了,这就是为什么我们这里的两个mask都要用到-inf。
Pytorch官方Documentation ↩︎
https://flashgene.com/archives/63944.html ↩︎
Wikipedia: Softmax函数 ↩︎