最近在做文本生成任务,例如诗歌生成,问题生成,摘要生成等,使用了Bart模型,CPT模型,mt5模型,t5模型等。生成模型是基于Seq-to-Seq(Encoder-Decoder)结构,输入的文本经过Encoder编码得到向量再输入到Decoder解码生成一个文本。Encoder和Decoder会使用多层transformer 以及self-attention,需要注意Encoder和Decoder中的attention mask使用,在Decoder中的self-attention当前时刻t的词只能关注到时刻0到t-1的词,无法关注到时刻t+1的词,是因为使用了在计算self-attention的时候加入了矩阵MASK M M M的右上角被mask为 − ∞ -\infty −∞,从而使得Decoder进行更好的生成。Decoder中的MASK M M M矩阵是由UNIfied pre-trained Language Model(UniLM)提出来的,对应于UniLM中的Seq-to-Seq LM。下面介绍Seq-to-Seq LM原理以及在bart中的使用。
Seq-to-Seq LM是UniLM提出来的,其中UniLM中有三种LM方式分别为Unidirectional LM,Bidirectional LM和Sequence-to-Sequence LM。如下图1所示。Unidirectional LM包括left-to-right and right-to-left LM,例如left-to-right LM当前词计算attention的时候只能使用当前词前面的词,后面的词无法使用到。
在 Seq-to-Seq LM中的self-attention计算中加入了一个MASK 矩阵 M M M,这个MASK 矩阵 M M M的右上角的元素是 − ∞ -\infty −∞,左下角的元素为0。
1. Multi-Layer Transformer
输入向量 { x i } i = 1 ∣ x ∣ \{x_{i}\}_{i=1}^{|x|} {xi}i=1∣x∣,第0层的transformer的输出状态 H 0 H^{0} H0记为 H 0 = [ x 1 , x 2 , … , x ∣ x ∣ ] H^{0} = [x_{1}, x_{2}, \dots, x_{|x|}] H0=[x1,x2,…,x∣x∣]。 L L L层的Transformer的结果记为 H l = T r a n s f o r m e r l H l − 1 H^{l} = Transformer_{l} H^{l-1} Hl=TransformerlHl−1, l ∈ [ 1 , L ] l\in[1, L] l∈[1,L]。 l l l层的self-attention计算过程如下:
Q = H l − 1 W l Q K = H l − 1 W l K V = H l − 1 W l V M i j = { 0 , a l l o w t o a t t e n d − ∞ , p r e v e n t f r o m a t t e n d i n g A l = s o f t m a x ( Q K T d k + M ) V Q = H^{l -1}W_{l}^{Q} \\ K = H^{l -1}W_{l}^{K} \\ V = H^{l -1}W_{l}^{V} \\ M_{ij}=\left\{ \begin{aligned} 0, & & allow to attend \\ -\infty, & & prevent from attending \end{aligned} \right. \\ A_{l} = softmax(\frac{QK^{T}}{\sqrt{d_{k}}} + M)V Q=Hl−1WlQK=Hl−1WlKV=Hl−1WlVMij={0,−∞,allowtoattendpreventfromattendingAl=softmax(dkQKT+M)V
其中 H l − 1 ∈ R ∣ x ∣ × d h H^{l-1}\in R^{|x|\times d_{h}} Hl−1∈R∣x∣×dh, W l Q , W l K , W l V ∈ R d h × d k W_{l}^{Q},W_{l}^{K},W_{l}^{V}\in R^{d_{h}\times d_{k}} WlQ,WlK,WlV∈Rdh×dk, M ∈ R ∣ x ∣ × ∣ x ∣ M\in R^{|x|\times|x|} M∈R∣x∣×∣x∣
已UniLM的输入句子 S 1 S_{1} S1为[SOS] t1 t2 [EOS], 输出的句子 S 2 S_{2} S2为t3 t4 t5 [EOS],将 S 1 S_{1} S1和 S 2 S_{2} S2拼接后得到句子[SOS] t1 t2 [EOS] t3 t4 t5 [EOS]输入模型,MASK M M M如下:
上图中矩阵MASK M M M右上角阴部部分的元素均为 − ∞ -\infty −∞,在self-attention中的softmax的时候加上MASK M M M使得self-attention关注不到 S 2 S_{2} S2的句子,使得模型有更好的生成能力。下面介绍MASK M M M 在bart中的使用。
1. Bart中的self-attention
Bart中的self-attention计算和上面介绍的UniLM中的attention计算方式一样,在attention计算softmax的时候加入了MASK M M M矩阵,在代码中用attention mask代替,代码如下:
2. 在Bart中的Encoder中的MASK M M M
在Encoder中的MASK M M M 是根据attention mask 计算得到的,attention mask大小为 [ b a t c h s i z e , s e q l e n g t h ] [batch_size, seq_length] [batchsize,seqlength]有1和0构成。在encoder中的attention mask矩阵进行了扩维变为 [ b a t c h s i z e , 1 , s e q l e n g t h , s e q l e n g t h ] [batch_size, 1, seq_length, seq_length] [batchsize,1,seqlength,seqlength],把attention mask中的元素为0的变为一个很大的负数,元素原来是1的位置处的元素为0,避免attention mask中元素为0的位置处的padding token对句子间的词之间的相关联程度的影响,相关代码如下:
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
结果如下:
3. Bart中的Decoder中的MASK M M M
在Encoder中的MASK M M M 是根据attention mask 计算得到的,attention mask大小为 [ b a t c h s i z e , s e q l e n g t h ] [batch_size, seq_length] [batchsize,seqlength]有1和0构成。在encoder中的attention mask矩阵进行了扩维变为 [ b a t c h s i z e , 1 , s e q l e n g t h , s e q l e n g t h ] [batch_size, 1, seq_length, seq_length] [batchsize,1,seqlength,seqlength],在把attention mask中的元素为0设置为 ∞ \infty ∞或者为一个负无穷大的数,相关代码如下
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf"))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
decoder中 MASK M M M 结果如下:
在生成模型中decoder中的self-attention计算加入MASK M M M矩阵,使得模型具有更好的生成能力。如有错误,欢迎指证。