原始GAN训练难点之:原始目标函数没意义
其实,GAN 训练之难,更多的源于它 GAN 目标函数自身。GAN 的 D 的目标函数上文已经提过:
而G 的目标函数相应的是:
也就是说,G 需要最小化让 D 识别出自己生成的假样本的概率。但其实,在 GAN 原始论文[2] 中,作者就指出使用如上的 G 的目标函数会给训练造成问题。从形象化的角度来理解,在训练的早期,G 生成的假样本质量还非常差,与真实样本相距过远。这会知道 D 非常容易识别出 G 的假样本,从而使得 D 的训练几乎没有损失,也就没有有效的梯度信息回传给 G 让 G 去优化自己。这样的现象叫做 gradient vanishing,梯度消失问题。
换句话说,无论真实数据分布跟生成样本分布是远在天边,还是近在眼前,只要它们俩没有一点重叠或者重叠部分可忽略,JS散度就固定是常数,而这对于梯度下降方法意味着——梯度为0!此时对于最优判别器来说,生成器肯定是得不到一丁点梯度信息的;即使对于接近最优的判别器来说,生成器也有很大机会面临梯度消失的问题。
但是与不重叠或重叠部分可忽略的可能性有多大?不严谨的答案是:非常大。比较严谨的答案是:当与的支撑集(support)是高维空间中的低维流形(manifold)时,与重叠部分测度(measure)为0的概率为1。
从偏理论的角度来理解,梯度消失的问题实际上更“复杂”一些。想要理解它需要先理解 GAN 的 min-max game 的平衡条件。当 G 和 D 的对抗训练达到平衡时,可以认为取得了最优的 D(和最优的 G),此时最优的 D* 应该是两个分布的比值:
有了最优的 D* 的表达,就可以将它带入原始的 D 的目标函数,从而得到上页 slides 中的等价表达。也就是说,在 GAN 原始论文[2] 中就已经给出了,优化这样一个目标函数等价于优化 JS 散度(因为 2log2 是常数)。
然而,问题就出在了这个 JS 散度上。在论文[1] 中,作者指出当两个分布(比如这里的真实数据分布 P_r 和 生成数据分布 P_g 之间几乎不重合或者重合部分可忽略不计时,JS 散度也是个常数!而这在由神经网络拟合的分布中是非常常见的!也就是说,原始的 GAN 目标函数几乎是常数,所以也就不难理解为什么梯度几乎消失了。
Wasserstein GAN,简称 WGAN 基于的是 Wasserstein Distance,也叫 Earth-Mover Distance,推土机距离。这个距离可以形象的理解为,将一个分布变成另一个分布所需要的消耗。如果用一个直观的理解就是,把一个沙堆推到另一个地方,形成另一堆可能长得不一样的沙堆所需要的“距离”。而这种“转变”并不是唯一的,所以很可能有些“路径”消耗大,有些消耗小。用比较严谨的说法,Wasserstein 距离表示的是“最优规划路径”下的最短距离。Wasserstein距离相比KL散度、JS散度的优越性在于,即便两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。
这个距离有很多优良的性质,其中最最最重要的一条就是它可以在两个分布毫无重叠的情况下依然给出有效的度量。也就是说,用它作为优化目标则不需要担心梯度无意义或者梯度消失的问题。但是 Wasserstein Distance 中的“求下界”的操作无法准确高效计算,所以作者用了 Kantorovich-Rubinstein 对偶将其变换了一下,就得到了 WGAN 的目标函数:
但是这样的变换要求符合一个先决条件,也就是判别器 D 拟合的函数需要是 1-Lipschitz 连续函数。Lipschitz 连续实际上是要求一个连续函数的导函数的绝对值不大于某个常数,也就是说,它限制了一个连续函数的最大局部变动幅度。这个对于由神经网络来拟合的函数来讲,导数可以粗暴理解为神经网络的权重。所以在 WGAN[7] 中他们采取了 weight clipping,梯度剪裁的方式,将“导数”限制在 [-c,c] 范围内。也就是每次更新 D 的参数后,超过这个范围的都拉回来。
同时,他们还发现基于由于现在的 WGAN 中的 D 不再是做二分类任务,而是做一个“回顾”任务去拟合 Wasserstein distance,这个 distance 从实验上发现,与生成的图片质量呈负相关:
梯度裁剪的问题以及改进
虽然WGAN 在实验中展现了自己比原始 GAN 稳定的一面,但它依然遗留了一个问题。也就是通过 weight clipping,梯度裁剪这种方法选择的那个超参 c 对于实验结果的影响有多大?其实这个问题在原始 WGAN[7] 的论文中作者就有讨论。当 c 太大时,会出现梯度爆炸问题;过小也会导致梯度消失问题。
基于此,就有了 Improved WGAN[14] 这篇工作。他们首先分析了,到底基于梯度裁剪方式,会导致什么样的问题。首先在 Section 2.3 中他们证明,用这样的方式实际上也存在一个最优的判别器 D,当达到这个最优判别器时,D 的所有权重都会倾向于等于c,如下:
而这样的一个缺点就是,会导致学出来的网络过于简单,对于复杂函数的拟合能力或者说对于分布的建模能力会明显下降。比如下面的模拟实验中,就可以看到基于梯度裁剪的方法拟合分布,会忽略掉高阶动量。
为此,[14] 提出了一种新的方式去满足 WGAN 目标函数的额外要求,也就是 Lipschitz 连续性。他们指出既然最优的判别器的权重会倾向于一个常数,不如就把这个常数当成“目标”,把当前的权重与这个常数的距离,当成一种惩罚项或者正则项,加入WGAN 的目标函数中:
也因此,这个方法被叫做 gradient penalty,这样的 WGAN 就叫 WGAN-GP [14]。WGAN-GP 比原始 WGAN 的收敛速度更快,训练也更稳定,得到的生成结果的质量也更好。