目录
1. 简介
2. Seq2Seq
2.1 Encoder
2.2 Decoder
3. Seq2Seq with Attention
3.1 Decoder
4. Train
5. Decoding
5.1 理论
5.2 实例
6 总结
Seq2Seq的基本结构是encoder-decoder,这个模型的目标是生成一个完整的句子。这个模型曾经使得谷歌翻译有较大幅度的提升,下面就以机器翻译为例子,来描述详述这个模型。
注:学习此模型需要有LSTM深度学习模型相关基础。
Seq2Seq框架依赖于encoder-decoder。 encoder对输入序列进行编码,而decoder生成目标序列。
在encoder中输入hao are you ,每个单词,都被映射成一个维的词向量,在这个例子中,输入将被转化成,经过LSTM后,我们可以得到每一个词对应的隐状态,,和代表这个句子的向量,在这里,。
现在我们已经得到了代表句子的向量,这里我们将使用这个向量,输入到另一个LSTM单元,以特殊字符作为起时字符,得到目标序列。
当时间步等0时:
:Encoder输出的句子向量
:特殊词,代表起时位置,作为当前时间步骤的输入
:当前时间步骤的隐状态。,隐层的维度
:词表中,每个词的得分。,词表的大小
:函数(其实就是矩阵,w 和 b),
:经过归一化后得到在词表上的概率分布,,词表的大小
:中最大概率词的索引。int值。
当时间步等于1时:
与时间步等0不同的时,LSTM的输入
,隐状态的输入从e变成上一个时间步的隐状态
,词也变成上一个时间步预测的词。
一直到预测到了特殊字符,才停止。
上面的方法其实就是做了这么一个转换:
通常来说,seq2seq 加入attention机制后,会使得模型的能力所以提高。模型在解码阶段时可以关注对encoder序列的特定部分,而不是仅仅依赖于代表整个句子的向量。
加入attention机制后,encoder的过程不变,decoder过程发生相应的变化
:是上一个时间步的隐层输入。
:当前时间步的隐层输入,也是上一个时间步的输出。
:是context vec,叫做上下文向量,是对encoder的output求加权和的结果,,是LSTM隐层的维度
,,, 在2.1 已经做了说明,这里完全相同,下面看是怎么得到的
:是encoder是时步为的隐层;
:当前时间步骤隐层的输入;
:decoder当前时间步对encoder时间步为关注度的得分;
encoder每个时间步骤得分的向量
是进行softmax 归一化的后的值,
:在decoder时间步骤为时刻,对encoder的output求加权和的结果。
而对于函数,通常有以下几种选择,但是不限于以下三种,什么运算效果好,用什么运算。
回顾例子,目标是进行翻译,将“how are you” 翻译成 "comment vas tu"
如果在训练阶段,decoder的过程中,将t-1时间步预测的词,作为t时间步的输入词,很有可能在某一步预测错误,后面的序列将会全部乱掉,导致错误积累,并且使得模型无法在正确的输入分布中进行,会导致模型训练缓慢,甚至无法进行下去,为了加快处理速度。一个技巧是 输入token序列:,并且预测对应位置的下一个token。
decoder模型,每一个时间步的输出是词表上的一个概率,是词表的大小,对于给定的目标序列,,我们可以计算出整个句子的概率:
这里是指decoder第t和时间步上,生成第个单词的概率,我们要使得这个这个概率在目标序列上最大化,等价于使得:
最小化,我们定义式子18这个作为损失函数。
再具体的例子中,我们的目标就是最小化:
这里的损失函数其实就是交叉熵损失(Cross Entropy)
这里主要是说明解码过程,不是解码器
在解码的过程中,采用一种贪婪的模式,将上一步预测的最后可能的词,作为输入,传入到下一步。但是这种方法,一旦在一步发生错误,就可能会造成整个解码序列的错乱,为了尽可可能降低(目前并不能消除)这个风险,采用一种Beam Search的方法,我们的目标不是得到当前时间步上的最高的分,而是得到前个的最高得分。
那么对于在时间步上的解码假设集合一共组,下角标代表时间步,上角标代表top_k的第k个word。
那么是如何从在时刻得到候选集合呢?
注意:这里时从词表中选取的词汇,词表一共个词,因为这里将会是一个非常重要的点,与下一篇指针网络有所不同。
假设,假设,假设decoder一共就三个词可选,。
那么在一共有2种输出,在时,认为此时模型的,,将输入模型,得到的输出是的是。一共6个,即为候选集合:
从中挑选出时刻最高2个词,再从中挑出最高的2个
组成
然后后再从中挑选整个句子得分最高的个。得到
下面说明的计算方法:
目标,从中挑选个最大的句子。代表中第个句子的得分,。
这里之所以是要除以句子长度,是因为句子的长度会印象得分,我们以作为例子
如果不开3次方,那么句子的长度越大,连乘的概率越小,那么在做最终的预测时,模型预测出的结果将会偏向于预测较短的句子
得到后,知道预测到结束字符
从上至下,分别讲述了seq2seq模型的基本结构,和attention机制,并且介绍了这一类模型如何进行训练,如何进行生成。博客的内容如果不足之处,欢迎批评指正。
说明,这一篇博客,没有做项目代码的实现,因为下一篇博客,的指针生成网络会包含seq2seq+attention,并且有实现代码。以及结果的展示。
实现的博客链接(注:不是单纯的实现了seq2seq,是实现了一个基于seq2seq的模型,指针生网络,论文地址)