基于深度学习Seq2Seq框架的技术总结

随着互联网经济的普及定位技术的快速发展,人们在日常生活中产生了大量的轨迹数据,例如出租车的GPS数据、快递配送员PDA产生的轨迹数据等。轨迹数据是一种典型的时空数据(Spatial-Temporal Data),是按照时间顺序索引且空间变化的一系列数据点。在时空数据的数据挖掘中,我们也会大量借鉴在自然语言处理等时序数据中发展很成熟的技术。

本篇文章为您带来的是Seq2Seq(Sequence to Sequence)模型的技术总结。将着重介绍三个里程碑式的方法,Sequence to SequenceLearning with Neural Networks、Learning Phrase Presentations using RNNEncoder-Decoder for Statistical Machine Translation、Neural MachineTranslation by Jointly Learning to Align and Translate。

一、Sequenceto Sequence Learning with Neural Networks

本章我们将会介绍Sutskever I.等人于2014年发表在NeurIPS的一篇论文,目前引用量已经超过12000次。最常见的Seq2Seq模型是解码器-编码器(Encoder-Decoder)模型,由于时序数据的序列性质,通常情况下,我们使用RNN(Recurrent Neural Network)在Encoder中得到输入序列的特征向量,再将此特征向量输入Decoder中的另一个RNN模型,逐一生成目标序列的每一个点,图1展示了该模型的基本框架,以德文翻译至英文为例:

基于深度学习Seq2Seq框架的技术总结_第1张图片

图1 Seq2Seq 模型框架

1.Encoder

输入是一句德文的序列,“guten morgen”,输出是翻译的英文序列“good morning”。为了帮助模型区分句子的开头与结尾,我们在每一个句子(序列)的首尾分别添加(start ofsequence)和(end ofsequence) 做标记,由于文本形式的数据无法输入深度学习模型,我们首先将文字通过字典做独热编码,然后把序列输入至嵌入层(黄色部分)和基于RNN的Encoder层(绿色部分)。在每一步中,基于RNN的Encoder的输入由两部分组成,其一是经过嵌入层的单词,其二是RNN模型前一步的隐藏状态,而后输出新的隐藏状态,公式表达如下:

在这里插入图片描述

此处,我们使用通用的RNN框架,在实际操作中可以被替换为LSTM(LongShort-Term Memory)或GRU(Gated RecurrentUnit)等。这里,输入序列可视为,其中,初始化的隐藏状态为0。当我们把序列中所有的文字都经过EncoderRNN处理过后,会得到一个最终的隐藏状态,在图1中,向量代表了整个源输入序列的特征,此向量也会作为DecoderRNN的初始隐藏状态。

2.Decoder

类似于源序列在Encoder中的过程,在每一步,DecoderRNN(蓝色部分)的输入也由两个部分构成,其一隐藏状态,为了方便与EncoderRNN区分,我们用来表示;其二经过Decoder的嵌入层d的单词,公式表达类似于EncoderRNN:

在这里插入图片描述

通过RNN,我们可以得到目标序列的隐藏状态,然而这并不是真正的单词,因此我们需要在此基础上使用一个全连接层(紫色部分)去预测最有可能的单词:

在这里插入图片描述

3.Seq2Seq

我们在DecoderRNN中每一次会得到一个单词,这个预测出来的单词会作为下一步的输入,循环直到目标序列生成完毕。事实上,我们并不会在Decoder的每一步都是用预测出来的单词,因为这样做会出现一步错,步步错的情况。为了提高模型的准确率,我们会随机使用真实的单词作为下一步的输入,这个概率称为teacher forcing ratio。当我们需要验证模型的有效性性时,将teacher forcing ratio设置为0即可,这样就可以保证不出现基准真值泄露的问题。

以上,通过EncoderRNN和DecoderRNN,我们可以得到预测的序列在这里插入图片描述
我们会把预测的序列在这里插入图片描述
与基准真值的序列做比对此计算误差,更新参数来不断的训练模型。


二、Learning Phrase Representation using RNN Encoder-Decoder for Statistical Machine Translation

上一章为大家简单介绍了自然语言处理中最基础的Seq2Seq模型,本章我们将介绍第二种略微复杂一些的基于短语学习的Seq2Seq模型,本章的模型是由Cho, K.等人[2]提出的,目前的引用量超过了11000次。图2展示了模型的基本框架,同前一章类似,我们仍以德文翻译至英文为例。

基于深度学习Seq2Seq框架的技术总结_第2张图片

图2 基于短语学习的Seq2Seq模型框架

1.Encoder

Encoder的实现与第一章没有特别大的区别,除了基础的RNN,LSTM以及GRU都可以作为选择,LSTM与GRU在性能上并没有绝对的优劣之分,需要根据不同的需求做选择。现阶段Pytorch, Tensorflow,Keras发展都非常成熟,内置了非常多常见的框架及函数,用户无需重头搭建,可根据自身需求调节参数以达到最优效果。

2.Decoder

本章的Decoder与前一章有较大区别。在前一章中我们提到,源序列经过Encoder后会得到一个最终的隐藏状态,该隐藏状态z包含了所有源序列的信息,并会成为Decoder生成目标序列的初始隐藏状态。从图1我们可以明显看到,z只被使用了一次,后续的隐藏状态都由Decoder的前一步生成。直觉上来讲,这样的做法没有充分利用源序列的信息。因此,Cho, K.等人做了以下改进,在每一步DecoderRNN输入层及全连接预测单词层加入z,同时,在全连接层预测下一个单词时,不仅加入了Encoder的最终隐藏状态z,还加入了当前单词经过嵌入层之后的结果,隐藏状态和预测值的公式被更新为:

在这里插入图片描述

3.Seq2Seq

Seq2Seq部分的逻辑与前一章相似,Decoder每次预测下一步的单词,由Seq2Seq循环生成目标序列,此部分不再做过多赘述。

三、Neural Machine Translation by Jointly Learning to Align and Translate

前面两章我们为大家介绍了两种Seq2Seq的模型,虽然在第二种模型里,可以很好的提取源序列的信息,并用在每一个目标序列的输出上,但无法避免另一个问题。无论是自然语言里的句子,还是轨迹数据,这些序列中的每一个点更多情况下是受周围或者部分其他点的影响,而不是整个序列。举例来说,一条长度为10公里的轨迹,车辆行驶速度更大概率是受当前位置前后一公里整体行驶速度的影响,而不是更遥远的地方。

因此,在生成目标序列时,更好的办法不是在每一步加入之前源序列的全部信息,而是只关注部分信息。Bahdanau, D.等人同样也在2014年发表一篇影响力深远的论文,他们的亮点是首次在Seq2Seq模型中加入了Attention思想,目前引用量已经超过15000次。

本章中涉及到EncoderAttention Decoder Seq2Seq,其中Encoder与Seq2Seq与前面两章无异,本章将只介绍Attention模型在Decoder中的应用。

基于深度学习Seq2Seq框架的技术总结_第3张图片

图3 Seq2Seq(Attention) 模型框架

1.Attention

对于Attention层来说,通常情况下,我们会计算一个attention向量,它的长度与源序列数据长度一致。在attention向量中,每一个元素的取值范围是0到1,整个向量之和是1。我们会将attention向量与源序列的隐藏状态相乘,得到一个权重向量,公式表达如下:
在这里插入图片描述

我们会在Decoder中的每一步重新计算权重向量,并把这个权重向量用在DecoderRNN的输入以及全连接层的输入,如图3所示(橘色)。

2.Decoder

Decoder层的实现逻辑与第二章很相似,区别在于把原来的替换成了权重向量w,这样就可以保证目标序列在Decoder中不需要关注源序列的全部信息,而是专注于与自身相关的信息。加入Attention,不仅能帮助Decoder在上生成序列时关注有效信息,还可以减少信息过多带来的冗余,一定程度上可以减少参数的更新。

以上就是本篇文章关于Seq2Seq模型的总结,从最基础的Seq2Seq模型,到基于短语学习的Seq2Seq模型,再到基于Attention的Seq2Seq,每一个方法都在不同阶段解决了一定的问题,经过时间的证明,这三篇论文在Seq2Seq领域都有非常大的影响力,是经典之作。
Seq2Seq不仅可以运用在自然语言处理领域,时空数据领域也有会有很多应用,未来JUST也会在这一方面做进一步的挖掘和研究。


往期好文推荐

只看这三点就够:快速了解联邦学习框架!
京东数科七层负载 | HTTPS硬件加速 (Freescale加速卡篇)
京东数科mPaaS:深度解读京东金融App(Android)的秒开优化实践


获取更多技术干货&独家福利,请关注“京东数科技术说”微信公众号
基于深度学习Seq2Seq框架的技术总结_第4张图片

你可能感兴趣的:(时空数据JUST,Seq2Seq,深度学习,自然语言处理,时空数据)