本次要分享的论文是 A d v e r s a r i a l L e a r n i n g f o r N e u r a l D i a l o g u e G e n e r a t i o n Adversarial\ Learning\ for\ Neural\ Dialogue\ Generation Adversarial Learning for Neural Dialogue Generation,论文链接dialogue-gan,该论文所讲内容和上一篇分享的论文非常类似,都是用 G A N 、 R L GAN、RL GAN、RL 来做生成的,只不过本篇论文讲的是对话生成,有一些不一样的地方,因此再次分析总结下,就不细分析代码了。
###动机
首先还是来讲讲论文所说的动机,传统的 s e q 2 s e q seq2seq seq2seq 方法是以 M L E MLE MLE 作为目标函数,虽然在某些任务上取得了不错的成绩,但是也有一些显而易见的缺点:往往生成的句子是乏味的、通用的(低质量的)、短视的、重复性的。
一个好的模型,其生成的句子应该与人类生成句子真假难辨。因此论文采取了 G A N GAN GAN 的思想方法,但是传统的 G a n Gan Gan 又无法适用于离散的数据上,因此再采用 R L RL RL的方法,和上一篇分享的论文很像,判别器返回一个 r e w a r d reward reward 给生成器,指导其生成什么样的句子。
###模型
生成器:是一个带有 a t t e n t i o n attention attention 机制的 s e q 2 s e q seq2seq seq2seq 模型。
判别器:是一个 H i e r a c h i c a l N e u r a l N e t w o r k Hierachical\ Neural\ Network Hierachical Neural Network,其参考的论文链接hierarchical,以QA任务为例,一对 q u e r y a n s w e r query\ answer query answer 样本,如果该 a n s w e r answer answer 是人类生成的,则该样本标签为1,如果为生成器生成的则标签为0。将 q u e r y 、 a n s w e r query、answer query、answer 分别经过相互独立的 R N N RNN RNN ,得到两者的最后的 s t a t e state state ,然后做 c o n c a t concat concat 操作,作为 c o n t e x t _ i n p u t context\_input context_input,将该 c o n t e x t _ i n p u t context\_input context_input 喂给一个 R N N RNN RNN 做一个二分类的训练。嗯,就是这么简单。
Policy-Gradient-Training:
首先需要知道,判别器中返回的 r e w a r d reward reward 具体是什么信息?论文中是指 x , y {x, y} x,y (其中y 时生成器生成的)在判别器中被识别为真的概率值,也即是打分值作为 r e w a r d reward reward 回传给生成器。
这里需要了解在强化学习的几个重要概念中: s t a t e , a c t i o n , p o l i c y , r e w a r d state,action,policy,reward state,action,policy,reward , s t a t e state state 为现在已经生成的 t o k e n s tokens tokens , a c t i o n action action 是下一个即将生成的 t o k e n token token , p o l i c y policy policy 为 G A N GAN GAN 的生成器, r e w a r d reward reward 为GAN 的判别器所回传的信息。
由强化学习的知识可知,生成器的目标就是使得$maximize\ expected\ end\ reward $,论文中的公式:
J ( θ ) = E y ∼ p ( y ∣ x ) ( Q + ( x , y ) ∣ θ ) J(\theta) = E_{y\sim p(y|x)}(Q_{+}({x,y})|\theta) J(θ)=Ey∼p(y∣x)(Q+(x,y)∣θ)
这里要特别特别注意:上式中的 y y y 并不是 t r u e _ d a t a true\_data true_data,而是生成器生成的!!!也就是论文中所说的通过 p o l i c y s m a p l e policy\ smaple policy smaple 出来的。 只在在 p r e t r a i n pretrain pretrain 时,才用到 t r u e _ d a t a true\_data true_data。
这个 b b b 可视为一个 b a s e l i n e baseline baseline,可以这样理解,如果某一个 a c t i o n action action 的 r e w a r d reward reward 很大,则下次生成该 s e q u e n c e sequence sequence 的几率就会增大,但是如果 r e w a r d reward reward 都是正值呢?例如上面判别器给起打分,其分值都为正值,那么每个 s e q u e n c e sequence sequence 的 r e w a r d reward reward 都是正的,我们希望有个区分,对于 r e w a r d reward reward 较低的 s e q u e n c e sequence sequence 相对来说要抑制他的生成,故减去一个 b a s e l i n e baseline baseline ,使得 r e w a r d reward reward 有正有负。
同样的,上面只是对一个整句进行打分的,由上一篇论文的分析可知,对每步生成的 t o k e n token token 进行打分十分有必要的。
论文中对于 r e w a r d f o r e v e r y s t e p reward\ for\ every step reward for everystep 有两种解决办法:
论文中也倾向于使用蒙特卡洛搜索树方法,虽然比较耗时。
###Teacher Forcing
如果我们随机初始化生成器的话,可能存在以下问题:
一句话,就是生成器可能知道哪些 r e s p o n s e response response 是好的、坏的。但是并不知道怎么去生成好的、符合要求的句子,当遇到某些 t r a i n b a t c h train\ batch train batch 时,生成器生成的 r e s p o n e s respones respones 判别器很容易的就能判断出来,这就导致了生成器的 l o s s loss loss 突然变得很大。训练不够稳定,容易训飞了。
为了缓解上面所说的问题,论文中提出了 T e a c h e r F o r c i n g Teacher\ Forcing Teacher Forcing ,就是给出一些 r e a d _ d a t a read\_data read_data 来指导生成器,告诉生成器哪些 r e s p o n s e response response 是好的,让他学着去生成。其实就是用 s e q 2 s e q seq2seq seq2seq 里面的 M L E l o s s MLE\ loss MLE loss 去纠正生成器。嗯,就是这么简单。相当于上一篇论文中 p r e t r a i n pretrain pretrain 生成器。
###整体流程
###个人总结