seq2seq模型是机器翻译中常见模型,由编码器(encoder) + 解码器(decoder)组成,其中编解码器都是由一层或者多层RNN组成。seq2seq模型的目标是将可变长度序列作为输入,并使用固定大小的模型将可变长度序 列作为输出返回。具体的实现并不难,可参看论文,官网也有详细的教程。
seq2seq 解码器的常见问题是,如果我们只依赖于上下文向量来编码整个输入序列的含义,那么我们很可能会丢失信息。尤其是在处理长输入序列时,这极大地限制了我们的解码器的能力。
Bahdanau et al 于2015年提出注意力(Attention)机制,即允许解码器关注输入序列的某些部分,而不是在每一步都使用完全固定的上下文,我们将它称为Local Attention。
Luong et al 2015年 提出了Global Attention 机制,改善了Bahdanau et al. 的基础工作。关键的区别在于: a. Global Attention考虑所有编码器的隐藏状态;
b. 通过Global Attention,我们仅使用当前步的解码器的隐藏状态来计算注意力权重 ;
这里重点介绍Global Attention的实现,以Pytorch为例。
观看上图的Attention Layer 模块,蓝色为编码器各个时间步,红色为解码器时间步。 具体操作为:
step1: 得到所有编码器时刻的隐藏状态的输出hs:维度为 [time_steps, hidden_size] ;
step2: 得到某个时刻的解码器的隐藏状态的输出ht:维度为[1, hidden_size] ;
step3: 通过某种评分函数score_f(), 即score_ti = score_f(ht, hs[i, :],) ,得到第ti 个时间步对应的score;
即ht 与编码器每个时间步的输出的隐藏状态进行 score_f 操作,得到维度为 [time_steps, 1] 的score_t;
step4: weight_score = softmax( score_t) ,进行归一化操作,得到每个时间步的权重。维度为:[timesteps, 1] ;
step5: 将weight_score 作用于 hs, 即对编码器的输出hs 做一个权重平均:得到 context vector,维度为:[1, hidden_size];如下图中的 c1/c2/c3
其实context 与ht 除了可以通过concat进行作用,也可以通过add 结合在一起。
上面介绍了Global Attention的方法步骤,其中step3中的评分函数的选取较为重要,可以通过以下三种方式来计算:
内积方向较为简单:
import torch
#这里先不考虑batch
time_step = 5 #时间步数,encoder阶段有多少个时间步长
hidden_size = 4 #隐藏层大小
en_output = torch.randn((time_step, hidden_size)) #encoder阶段所有的隐藏状态[5,4]
de_hidden = torch.randn((1, hidden_size)) #decoder阶段的某一个time_step(ti)的隐状态[1,4]
#将de_hidden转置,与en_output相乘,得到score; 即为解码器ti时刻的隐藏状态对应的在编码器的所有输出隐藏状态上的权重
score = torch.matmul(en_output, de_hidden.T) #[5,1]
#将该权重 softmax(归一化)
score = F.softmax(score,dim = 0) #[5,1]
#得到词向量context_vector
context_vector = torch.matmul(score.T, en_output) #[1,4]
与Dot相比,General就多了一个 Wa, 这个Wa 主要通过Linear层来实现。
import torch
#这里先不考虑batch
time_step = 5 #时间步数,encoder阶段有多少个时间步长
#一般情况下,两者的hidden_size一致
en_hidden_size = 4 #编码阶段的hidden_size
de_hidden_size = 3 #解码阶段的hidden_size
en_output = torch.randn((time_step, en_hidden_size)) #encoder阶段所有的隐藏状态[5,4]
de_hidden = torch.randn((1, de_hidden_size )) #decoder阶段的某一个time_step(ti)的隐状态[1,3]
atten = nn.Linear(en_hidden_size,de_hidden_size) #wa为 en_hidden_size --> de_hidden_size的之间的转换矩阵参数[en_hidden_size, de_hidden_size]= [4,3]
w = atten(en_output) #[5,3]
#得到 score; 即为解码器ti时刻的隐藏状态对应的在编码器的所有输出隐藏状态上的权重
score = torch.matmul(w, de_hidden.T) #[5,1]
#将该权重 softmax(归一化)
score = F.softmax(score,dim = 0) #[5,1]
#得到词向量context_vector
context_vector = torch.matmul(score.T, en_output) #[1,4]
import torch
#这里先不考虑batch
time_step = 5 #时间步数,encoder阶段有多少个时间步长
hidden_size = 4
en_output = torch.randn((time_step, hidden_size)) #encoder阶段所有的隐藏状态[5,4]
de_hidden = torch.randn((1, hidden_size)) #decoder阶段的某一个time_step(ti)的隐状态[1,4]
#
atten = torch.nn.Linear(hidden_size * 2, hidden_size) ##wa 为 hidden_size*2 --> hidden_size之间的转换矩阵参数[hidden_size*2, hidden_size] = [8,4]
#需要将v加入Parameter中去,以便参与梯度更新和参数学习
v = torch.nn.Parameter(torch.FloatTensor(hidden_size)).view(hidden_size, -1) #[4,1]
#即将de_hidden拼接到每个en_output的每个time_step的列中
concat_en_de = torch.zeros(time_step, hidden_size*2) #[5,8]
for i in range(time_step):
concat_en_de[i,:hidden_size] = en_output[i,:]
concat_en_de[i, hidden_size:] = de_hidden[0,:]
w = torch.tanh(atten(concat_en_de)) #[5,4]
score = torch.matmul(w, v) #[5,1]
#将该权重 softmax(归一化)
score = F.softmax(score,dim = 0) #[5,1]
context_vector = torch.matmul(score.T, en_output) #[1,4]
总结: 网上用的Global Attention多用前两种score方法。一般经验General方法好于Dot方法。通过Attention注意力机制给Decoder RNN加入额外信息,可以显著提高seq2seq的性能。
知乎文章:白裳 — 完全解析RNN, Seq2Seq, Attention注意力机制 讲到seq2seq训练问题,之前一直没有注意这一点。原文如下:
值得一提的是,在seq2seq结构中将Yt作为下一个时刻的输入Xt+1 <= Yt 进网络,那么某一时刻输出Yt错误就会导致后面全错。在训练时由于网络尚未收敛,这种蝴蝶效应格外明显。
为了解决这个问题,Google提出了大名鼎鼎的Scheduled Sampling(即在训练中按照一定概率选择输入Yt-1 或者 t-1 时刻对应的真实值,即标签,如下图),既能加快训练速度,也能提高训练精度。
谢谢前人的分享,受益匪浅!
之前在Deecamp 夏令营 AI 降水预测总结 这篇文章中试验了很多类似seq2seq的方法,但是训练的时候其实并没有注意到这种训练过程中产生的 蝴蝶效应 问题。在以后的工作中需要多加注意。
seq2seq论文:Sequence to Sequence Learning with Neural Networks
Local Attention 论文: NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE
Global Attention 论文:Effective Approaches to Attention-based Neural Machine Translation
Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks
Pytorch官网教程
完全解析RNN, Seq2Seq, Attention注意力机制
真正的完全图解Seq2Seq Attention模型