GAN在seq2seq中的应用 Application to Sequence Generation

Improving Supervised Seq-to-seq Model
有监督的 seq2seq ,比如机器翻译、聊天机器人、语音辨识之类的 。
GAN在seq2seq中的应用 Application to Sequence Generation_第1张图片

 

而 generator 其实就是典型的 seq2seq model ,可以把 GAN 应用到这个任务中。

 

RL(human feedback)
训练目标是,最大化 expected reward。很大的不同是,并没有事先给定的 label,而是人类来判断,生成的 x 好还是不好。
GAN在seq2seq中的应用 Application to Sequence Generation_第2张图片
 
简单介绍一下 policy gradient。更新 encoder 和 generator 的参数来最大化 human 函数的输出。最外层对所有可能的输入 h 求和(weighted sum,因为不同的 h 有不同的采样概率);对一个给定的 h,对所有的可能的 x 求和(因为同样的 seq 输入可能会产生不一样的 seq 输出);求和项为 R(h, x)*P_θ (x | h) ,表示给定一个 h 产生 x 的概率以及对应得到的 reward(整项合起来看,就是 reward 的期望)

GAN在seq2seq中的应用 Application to Sequence Generation_第3张图片

 

用 sampling 后求平均来近似求期望:

GAN在seq2seq中的应用 Application to Sequence Generation_第4张图片

 

 

但是 R_θ 近似后并没有体现 θ(隐藏到 sampling 过程中去了),怎么算梯度?先对 P_θ (x | h) 求梯度,然后分子分母同乘 P_θ (x | h) ,而 grad(P_θ (x | h)) / P_θ (x | h) 就等于 grad(log P_θ (x | h)),所以就在 R_θ 原本的近似项上乘一个 grad(log P_θ (x | h))

GAN在seq2seq中的应用 Application to Sequence Generation_第5张图片

 

如果是 positive 的 reward(R(hi, xi) > 0), 更新 θ 后  P_θ (xi | hi) 会增加;反之会减小(所以最好人类给的 reward 是有正有负的)

GAN在seq2seq中的应用 Application to Sequence Generation_第6张图片

 

整个 implement 的过程就如下图所示,注意每次更新 θ 后,都要重新 sampling

GAN在seq2seq中的应用 Application to Sequence Generation_第7张图片

 

RL 的方法和之前所说的 seq2seq model (based on maximum likelihood)的区别

GAN在seq2seq中的应用 Application to Sequence Generation_第8张图片

 

GAN(discriminator feedback)
不再是人给 feedback,而是 discriminator 给 feedback。
GAN在seq2seq中的应用 Application to Sequence Generation_第9张图片

 

 

训练流程。训练 D 来分辨 pair 到底是来自于 chatbot 还是人类的对话;训练 G 来使得固定的 D 给来自 chatbot 的 (c', x~) 高分。

GAN在seq2seq中的应用 Application to Sequence Generation_第10张图片

 

 

仔细想一下,训练 G 的过程中是存在问题的,因为决定 LSTM 在每一个 time step 的 token 的时候实际上做了 sampling (或者取argmax),所以最后的 discriminator 的输出的梯度传不到 generator(不可微)。

GAN在seq2seq中的应用 Application to Sequence Generation_第11张图片

 

怎么解决?

  1. Gumbel-softmax https://casmls.github.io/general/2017/02/01/GumbelSoftmax.html

  首先需要可以采样,使得离散的概率分布有意义而不是只能取 argmax。对于 n 维概率向量 π,其对应的离散随机变量 x π 添加 Gumbel 噪声再采样。
  x π  = argmax(log(π i) + G i)
  其中,G 是独立同分布的标准 Gumbel 分布的随机变量,cdf 为 F(x) = exp(-exp(-x))。为了要可微,用 softmax 代替 argmax(因为 argmax 不可微,所以光滑地逼近),G 可以通过 Gumbel 分布求逆,从均匀分布中生成 G i = -log(-log(U i)),U i ~ U(0, 1) 
   GAN在seq2seq中的应用 Application to Sequence Generation_第12张图片 GAN在seq2seq中的应用 Application to Sequence Generation_第13张图片

 

 

  2. Continuous Input for Discriminator 

  避免 sampling 过程,直接把每一个 time step 的 word distribution 当作 discriminator 的输入。

  GAN在seq2seq中的应用 Application to Sequence Generation_第14张图片 

  这样做有问题吗?明显有,real sentence 的 word distribution 就是每个词 one-hot 的,而 generated sentence 的 word distribution 本质上就不会是 1-of-N,这样 discriminator 很容易就能分辨了,而且判断准则没有在考虑语义了(直接看是不是 one-hot 就行了)。

  GAN在seq2seq中的应用 Application to Sequence Generation_第15张图片

 

  3. Reinforcement Learning

  GAN在seq2seq中的应用 Application to Sequence Generation_第16张图片 

  把 discriminator 的 output 看作是 reward:

    • Update generator to increase discriminator = to get maximum reward    
    • Using the formulation of policy gradient, replace reward  R(c, x) with discriminator output D(c, x)
  
  和典型的 RL 不同的是,discriminator 参数是要 update 的,还是要输入给 discriminator 现在 chatbot 产生的对话和人类的对话,训练 discriminator 来分辨。
   GAN在seq2seq中的应用 Application to Sequence Generation_第17张图片

 

 

 
Unsupervised Seq-to-seq Model
 
Text Style Transfer
用 cycle GAN 来实现,训练两个 GAN,实现两个 domain 的互相转。仍旧要面对 generator 的输出要 sampling 的情况,选择上述第二种解决方案,就是连续化。直接用 word embedding 的向量。
GAN在seq2seq中的应用 Application to Sequence Generation_第18张图片

 

 

也可以用映射到 common space 的方法,sampling 后离散化的问题,可以用一个新的技巧解决:把 decoder LSTM 的 hidden layer 当作 discriminator 的输入,就是连续的了。

GAN在seq2seq中的应用 Application to Sequence Generation_第19张图片

 
 
Unsupervised Abstractive Summarization
 
Unsupervised Translation

 

你可能感兴趣的:(GAN在seq2seq中的应用 Application to Sequence Generation)