Transformer中的Mask的使用方法和注意事项

Masks in Transformer

    • transformer 入参
    • [src/tgt/memory]_mask
    • [src/tgt/memory]_key_padding_mask


前言

Transformer 这个架构中,在Encoder 和 Decoder 部分多次用到各种不同的mask,本文着重探讨Encoder 和 Decoder 入参中的 mask 的使用方法。


transformer 入参

首先,我们依照 PyTorch的实现为例子,看一下Transformer在forward过程中的几个入参:

forward(src, tgt, 
	src_mask=None, 
	tgt_mask=None, 
	memory_mask=None, 
	src_key_padding_mask=None, 
	tgt_key_padding_mask=None, 
	memory_key_padding_mask=None)

src 是 encoder 的输入序列。tgt 是 decoder 的输入序列。decoder 的输入包含 (encoder 编码的结果 + tgt)。
除了这两个,后面的入参都是 xxx_mask, 那么mask 是如何表示呢?有两种基本的方法,第一种是BoolTensor (True/False),第二种是 ByteTensor(0/1)。在定义中,为1/True的部分其实会被网络忽略。为0/False的部分会加入网络计算。
为了方便理解比较,我们统一按照Batch Frist的模式来规定输入,即 src, tgt 的shape分别是 (N, S, E), (N, T, E), N表示batch中样本数量,S/T 表示序列长度,E表示Embedding的维度,也可以类比成图像中的通道数。


[src/tgt/memory]_mask

(src/tgt/memory)_mask的shape定义分别为(N, S, S), (N, T, T), (N, T, S) 其实从shape后两维度的定义来看,我们就清楚,这几个mask和attention有关,因为self-attention 和 cross-attention 的shape与这些shape相同。memory 可以理解为Encoder的输出序列。
以一个 3x3 的 src 张量为例子:

[0 0 0]
[0 0 1]
[1 0 0]

那么 i = 0 时的张量,由 (0, 1, 2) 位置的张量计算而来。 i = 1 时的张量,由 (0, 1) 位置的张量计算而来。i = 3 时的张量,由 (2, 3) 位置的张量计算而来。
如何用矩阵运算实现上述操作呢?
假设得到的注意力张量为 attn, 同时得到 value 张量。

attn = attn.masked_fill(mask, 1e-6)
attn = F.softmax(attn, dim=-1)
output = torch.matmul(attn, value)

注意一个细节,attn 在被屏蔽掉的位置填充的是1e-6 而不是0。这时因为如果填充为0的话,可能会出现某个位置全0,在后续的数值运算上可能会不稳定。


[src/tgt/memory]_key_padding_mask

其实从字面意思上就可以理解这几个mask。
因为在整个batch训练的时候,序列可能长短不一。这时候需要做padding,把所有序列统一到一样的长度。所以有必要通过这几个mask标记长度。
src_key_padding_mask 的 shape 是 (N, S)。值为True/1的位置,表明该位置是padding的值,不属于原始序列。
tgt_key_padding_mask 的 shape 是 (N, T)。定义同上。
memory_key_padding_mask 的shape 和 src_key_padding_mask 相同,元素值表示的意义也一样。

你可能感兴趣的:(transformer,深度学习,人工智能)