SeqNet论文笔记

本文默认读者对GAN有基本的了解,对以下公式:

\underset{D}{max} \underset{ G }{min} E_{x \sim P_{data}(x)}[-logD(x)]+E_{z \sim P_{z}(z)}[-log(1-D(G(z)))](1)

了然于胸,其中D代表Discriminator,G代表Generator,P_{data} ( x )表示真实数据的密度函数,P_{z} ( z )一般为噪声的密度函数。GAN在模拟连续变量的分布中表现得不错,但无法直接应用于离散变量,因为Generator往往最终通过softmax函数输出一个关于所有离散点的概率向量,无法生成one-hot形式输出,足够好的D可以轻易的区分出合成数据和真实数据。而如果加入one-hot(argmax(*))这种函数,将导致不可导,使得G无法被训练,另外argmax函数并没有真实的模拟多项分布。

为解决以上问题,大神们提出了很多不同的方案,比如在《GANS for Sequences of discrete Element with the Gumbel-softmax Distribution》一文中,先阐述了使用Gumbel Max可以代替依照概率采样过程,而保留了带优化参数(重参数法),然后为解决不可导问题,将Gumbel Max 替代为 Gumbel Softmax ,引入退火策略,模拟Gumbel Max的效果。

而在《Sequence Generative Adversarial Nets with Policy Gradient》一文中,作者并没有直接的在公式(1)上优化G ,即先完整的生成文本序列X,再将序列送入G。而且是采用了强化学习的思想,结合D,计算每个输出文字的action reward,并使其得模型的回报的数学期望最大。这是什么意思?为什么这样做就能避免求导问题了呢?别急,先看完怎么做,再看为什么。以下,Y_{i:j}=( y_i ,...,y_j )表示G按顺序输出的i到j的文字序列,γ表示全体文字集合。Q^{G_\theta}_{D_\psi } ( S,a )表示在当前状态S下,采取行动a(下一个输出文字)的回报,其实这里的上角标应该是G_\beta,表示我们在用蒙特卡洛搜索时所采用的policy,但我们一般默认policy就是当前的G_\theta。那么第t个输出的回报期望为

J_\theta= E_{Y_{1:t-1} \sim G_\theta}[ \sum_{y_t \in \gamma}G_\theta ( y_t|Y_{1:t-1} )Q^{G_\theta} _{D_\psi}( Y_{1:t-1},y_t )]

关于参数求导得

\nabla J_\theta= E_{Y_{1:t-1} \sim G_\theta}[\sum_{y_t \in \gamma} \nabla_\theta G_\theta ( y_t|Y_{1:t-1} )Q^{G_\theta} _{D_\psi}( Y_{1:t-1},y_t )]= E_{Y_{1:t} \sim G_\theta}[\nabla_\theta log(G_\theta ( y_t|Y_{1:t-1} ))Q^{G_\theta} _{D_\psi}( Y_{1:t-1},y_t )](2)

E_{Y_{1:t} \sim G_\theta}中应该也有参数,为什么这里忽略了,原文中说在提供的资料里有更多的推导过程,在此暂不深究。同时注意到Q^{G_\theta} _{D_\psi}( Y_{1:t-1},y_t )项在求导时也应该被看做常数)

进而我们给出式(2)的无偏估计

\frac{1}{T} \sum_{t=1}^TE_{Y_{1:t} \sim G_\theta}[ \nabla_\theta log(G_\theta ( y_t|Y_{1:t-1} ))Q^{G_\theta} _{D_\psi}( Y_{1:t-1},y_t )(3)

其中 Y_{1:0}定义为S_0,即初始状态。

给出以上公式后,看明白的朋友会发现,(3)中并没有明确给出E[·]的采样方法。文中只是简要的说到” the expectation E[·] can be approximated by sampling methods”,别急

紧接着原文给出了大致的算法步骤,我们重点关注对G的训练的部分,可以看到训练过程大致分为两步:

1. Generate a sequence Y1:T = (y1,...,yT ) ∼ Gθ

2. Update generator parameters via policy gradient

第一步通过当前G_\theta生成一个序列(猜测这里采用了一些不可导的采样方法,比如Gumbel-max,或者直接依概率随机选择)。第二步,通过公式(3)计算导数,然后使用如adam 等方法优化。如此一来就看清楚了,为什么SeqNet可以解决离散GAN的求导问题,可以感性的理解为,传统的GAN,将simple(G_\theta)理解为一个关于参数θ函数,再送入D,反向求导训练,自然会遇到采样函数不可导的问题。而SeqNet把simple(G_\theta)先固化为常数(表现为E_{Y_{i:t}\sim G_\theta}的近似),再通过强化学习理论构造可微的待优化函数,进行求导训练。

你可能感兴趣的:(自然语言处理,GAN,自然语言处理,NLP,SeqGAN)