Deep Reinforcement Learning For Sequence to Sequence Models

这篇论文是一篇综述性质的文章吧,研究了现有的Seq2Seq模型的应用和不足,以及如何通过不同的强化学习方法解决不足,写的深入具体,mark一下。

本文的顺序是对文章的一个总结,并不是文章真实的组织顺序。

论文链接:https://www.paperweekly.site/papers/1973

代码链接:https://github.com/yaserkl/RLSeq2Seq

1、Seq2Seq模型

1. Seq2Seq模型简单回顾

Seq2Seq是解决序列问题的一种通用算法框架,在文章摘要、标题生成、对话系统、语音识别、图像转文本等领域都有广泛的应用,模型结构如下图所示:

训练时,Encoder的输入是训练数据中的真实序列,Decoder也是训练数据中的真实序列,Decoder阶段每个时刻t的输出经过softmax之后得到选择每个单词的概率,并选择交叉熵损失函数作为模型损失指导模型的训练,交叉熵损失函数如下图所示:

预测阶段,Decoder时刻t的输入是t-1时刻输出的概率最大的单词。这一阶段可以通过beam-search的方法选择一条最合适的序列作为输出。同时,选择ROUGE , BLEU, METEOR, CIDEr等方法,对模型的预测结果进行评价。

1.2 Seq2Seq模型的应用

文章总结了Seq2Seq的在各个领域的应用:

1.3 Seq2Seq模型存在的问题

文章中指出Seq2Seq模型存在以下两个问题:

exposure bias:这个可以简单理解为,一步错步步错,只要Decoder某个时刻的输出是错误的,那会导致后面整个序列都是错误的。

mismatch in training and evaluating:在训练阶段选择的是交叉熵损失进行模型的训练,在预测阶段,选择ROUGE等方法来评估模型,这就导致了mismatch的问题,即交叉熵损失最小的模型并不一定在ROUGE评估中效果最好,通过ROUGE等方法评估的最好的模型,并不一定能使交叉熵损失最小。

2、强化学习算法对于改进Seq2Seq模型的一些思路

2.1 Seq2Seq中的强化学习

在Seq2Seq方法中强化学习的几个关键要素是如下定义的:
State:强化学习发生在Decoder阶段,时刻t的State定义为前面已经选择的t-1个单词和当前模型的输入
Action:Action是根据某种策略选择一个单词作为时刻t的输出
Reward:奖励考虑立即的奖励和未来的奖励,这里的奖励可以理解为当生成完整个句子之后,通过ROUGE等评估方法得到的反馈。

强化学习中的几个重要的公式如下:

Q函数和状态价值函数

优势函数
为什么要引入优势函数,这里主要是淡化state的影响,我们不再去关心state本身的好坏,而只关心在当前state下选取的action的好坏。优势函数计算如下:

优势函数刻画了当前状态下,每个动作对这个特定状态的好坏。V(s)是所有动作的奖励的期望值,因此优势函数大于0的动作所能获得的奖励高于期望值,而优势函数小于0的动作所能获得的奖励值小于期望值。

强化学习的目标

强化学习的目标可以是下面三个中的任意一个:

即最大化期望奖励,最大化每个时刻的优势函数,最大化每个时刻的Q值。

2.2 Policy Gradient

策略梯度的方法的优化目标是我们刚才提到的第一个,即最大化期望奖励,那么结合Seq2Seq是怎么做的呢?将文中提到的一下碎片化的东西整理一下,可以得到以下的过程:

Pre-train模型
为了提高模型的收敛速度,我们需要首先预训练模型,预训练模型使用交叉熵损失函数。

每次训练采样N个序列
在每次训练的时候,基于当前的模型采样得到N个完整的序列。采样基于Decoder每个时刻t的输出经softmax后的结果,并作为下一时刻的输入。原文中一开始说每次训练指采样一个序列,但是这对模型来说Variance非常大,因为不同的序列得到的Reward差别很大,模型的方差自然也很大。

通过ROUGE等方法得到reward并训练
为了保证训练和预测时模型的一致性,我们通过ROUGE等方法得到这批序列的reward,并使用如下的损失函数进行模型的训练:

上面的损失函数的意思即,我们希望加大能够得到更大reward的词出现的概率,减小得到较低reward的词出现的概率。

另一个需要注意的点,文章中提到,减去一个rb(rb一般通过这一批Sample的平均reward得到)可以减小模型的方差Variance,是否减去这个rb是不会影响期望的损失的。为什么会减小Variance呢?这就好比优势函数的概念,我们忽略state本身好坏的影响,而只关注在这个state下,action的好坏。

Policy Gradient算法缺点
使用该方法可以得到一个无偏的反馈,但是需要采样完整的序列才可以得到奖励,因此收敛速度可能非常慢,因此考虑时间差分的方法,即在每一步采样之后都能够得到一个反馈,可以考虑Actor-Critic方法和DQN等方法。

2.3 Actor-Critic Model

使用Actor-Critic的最大化目标是优势函数最大化,即:

使用AC方法无需进行一个完整序列的采样,可以在每一步通过Critic来估计我们可能得到的奖励。上面的式子中,我们期望的是优势函数最大化,优势函数计算如下:

可以近似的认为是下面的式子:

因此,我们的Critic估计的其实是一个状态的价值,即V(s)。Critic对状态价值V(s)的估计通过监督学习的方式进行训练:

vi是通过采样以及ROGUE评估的到的当前状态的真实价值,即:

因此,Actor-Critic的过程如下所示:

使用AC的方法,我们可以减小模型的方差,因为我们是在每一步都有一个奖励的反馈,但是这个反馈是有偏,因为我们没有观测一个完整的序列而得到奖励。

当模型收敛后,我们认为我们的Actor已经是最优的策略了。那么在预测阶段,我们就可以采用beam-search或者是greedy的方法,根据Encoder的输入来得到Decoder的输出。

2.4 Actor-Critic with Q-Learning

使用该方法的最优化目标是Q值最大化,通过Q值的反馈来更新Actor的策略,即:

此时Critic预估的是Q值,其通过下面的式子进行训练:

式子中第一项是q-eval值,即预估的Q值,第二项是q-target,即真实的Q值,我们希望二者越相近越好。真实的Q值如何计算呢,在上图式子中的第二行给出了答案,即立即获得的奖励,加上下一个状态所能获得的最大奖励的折现。

可以看到,基于Q函数和上面的AC方法有两点的不同。首先,基于Q函数只需要得到当前的立即奖励,而另一种方法需要完整的采样,然后计算多轮奖励的折现和,作为状态价值V(s)。另外,基于Q函数的Critic估计的是状态-动作对的Q值,而另一种方法估计的是状态价值V(s)

Deep-Q-learning的具体过程如下:

3、其他注意的点

3.1 Deep-Q-learning的改进

文中主要提到了双网络结构Double-DQN,加入优势函数的Dueling-DQN,以及通过经验池来减少数据中相关性等等。

不同的强化学习方法有缺点如下:

3.2 RL+Seq2Seq的应用

文中总结了不同应用的Seq2Seq和RL的结合,其状态,动作以及价值分别是什么。如下表所示:

本文只是对文章内容的一个简单理解,大家如果感兴趣的话,可以阅读原文,也欢迎大家与我讨论!

你可能感兴趣的:(Deep Reinforcement Learning For Sequence to Sequence Models)