GAN:WGAN-GP-带有梯度惩罚的WGAN

论文:https://arxiv.org/pdf/1704.00028.pdf

代码:GitHub - igul222/improved_wgan_training: Code for reproducing experiments in "Improved Training of Wasserstein GANs"

发表:2017

GAN:WGAN-GP-带有梯度惩罚的WGAN_第1张图片

WGAN三部曲的终章-WGAN-GP

摘要

WGAN在稳定训练GANs方面有一定的进展,但依然存在生成样本质量低、难以收敛等问题。主要原因是:采用了weight clipping。本文作者提出了gradient penalty (GP)来替代 w-c,有效的解决了WGAN存在的缺陷。同时本文也是第一个在很深的网络上(res101)成功训练GANS.

GAN:WGAN-GP-带有梯度惩罚的WGAN_第2张图片

weight clipping缺陷:模型建模能力弱化,以及梯度爆炸或消失。

GAN:WGAN-GP-带有梯度惩罚的WGAN_第3张图片

权重约束的难点

作者发现WGAN中的权重裁剪会导致优化困难,即使优化成功,也可能导致判别器具有病态的值表面。作者尝试了其他的权重约束方案:L2 norm clipping、weight normlization、以及L1和L2 权重衰减,都存在相似的问题,并不能解决问题

作者同时发现在WGAN中:判别器中增加BN可以一定程度上缓解上述问题,但随着网络的加深,WGAN依然会面临难以收敛的困境。

权重分布问题

WGAN在训练过程中保证判别器的所有参数处于[-c, +c]的范围内,约束了判别器对相似样本有相似的结果。实际训练需求是希望判别器尽可能拉开真假样本的分数差,而weight-clipping限制了网络的参数范围,使得最优的策略是尽可能让所有参数拉开,要么取最大值c,要么取最小值-c。而g-p 的权重数值分布就比较正常。

GAN:WGAN-GP-带有梯度惩罚的WGAN_第4张图片GAN:WGAN-GP-带有梯度惩罚的WGAN_第5张图片

GAN:WGAN-GP-带有梯度惩罚的WGAN_第6张图片

梯度回传问题

c-p另一个问题就是会导致梯度消失或者爆炸,如下图。判别器通常是一个多层网络,设想一下:

如果weight clipping 阈值设置的很小(比如下图中的c=0.001),每经过一层网络,保留的梯度就变小一点,多层之后,可能就会出现梯度消失的问题。

如果weight clipping 阈值设置的很大(比如下图中的c=0.1),每经过一层网络,保留的梯度就变大一点,多层之后,可能就会出现梯度爆炸的问题。

所以只有设置的不大不小,比如c=0.01(wgan作者推荐的数值),下图中的紫色线,梯度保持相对合理,才能让生成器获得不错的回传梯度。所以这个参数在实际应用中调试不容易把握。

本文提出的 g-p(图中蓝色线),不论判别器深度如何,梯度范数,都保持相对稳定,有效解决梯度消失和梯度爆炸的问题。

GAN:WGAN-GP-带有梯度惩罚的WGAN_第7张图片

GAN:WGAN-GP-带有梯度惩罚的WGAN_第8张图片

梯度惩罚

在原始判别器的损失上增加了一项惩罚,惩罚系数设置为10经过验证,可以在各个框架和数据集上表现不错。

公式在下面, 里面表达的是它在WGAN的loss上加了一个惩罚项,如果判别器的 gradient 的 norm,离 1 越远,那么 loss 的惩罚力度越高。

GAN:WGAN-GP-带有梯度惩罚的WGAN_第9张图片

算法流程

GAN:WGAN-GP-带有梯度惩罚的WGAN_第10张图片

  • 训练 n_critic=5 次判别器,训练1次生成器
  • 训练判别器:
    • 采样一次真实数据x和生成数据\tilde{x}
    • 将真实数据x和生成数据\tilde{x}\varepsilon比例叠加混合,得到\hat{x}
    • \hat{x}输入判别器,得到混合图片数据的梯度,对梯度计算 norm,看看这个 norm 离单位距离 1 有多远(离1越近,惩罚越小)

对于上面第2点,为什么要用真假数据进行一个插值处理?这篇文章的解释: 要求 ‖T‖L ≤ 1 在每一处都成立,所以数据应该是全空间的均匀分布才行, 显然这很难做到。所以作者采用了一个非常机智(也有点流氓)的做法: 在真假样本之间随机插值来惩罚,这样保证真假样本之间的过渡区域满足 1-Lipschitz 约束。

移除判别器中BN

 大多数GANs中在生成器和判别器中均使用BN,目的是稳住训练过程。但WGAN-GP中移除了判别器中的BN操作: 因为WGAN-gp的惩罚项计算中,惩罚的是单个数据的gradient norm,如果使用 batchNorm,就会扰乱这种惩罚,让这种特别的惩罚失效。作者发现移除后效果很好。除了移除BN外,也可以使用Layer normalization 来替代 batch normalization。

GAN:WGAN-GP-带有梯度惩罚的WGAN_第11张图片

实验部分

1:wgan-gp在各种架构和条件下都可以成功训练:有无BN,网络深度等

GAN:WGAN-GP-带有梯度惩罚的WGAN_第12张图片

GAN:WGAN-GP-带有梯度惩罚的WGAN_第13张图片

2:优化器选择:作者重新对比了Adam、RMSProp。发现基于wgan-gp架构,Adam表现的更好一些(这与wgan中是完全相反的)

GAN:WGAN-GP-带有梯度惩罚的WGAN_第14张图片

代码学习

wgan:https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan/wgan.py

wgan-gp:https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py

1:生成器和判别器没有变化 。这个代码里面是没有BN操作的。如果判别器有,最好是移除。

GAN:WGAN-GP-带有梯度惩罚的WGAN_第15张图片

 2:lambda_gp = 10 的参数。同时优化器换回了Adam,作者验证发现Adam还是比RMSprop优化器效果好一些。

GAN:WGAN-GP-带有梯度惩罚的WGAN_第16张图片

 3:梯度惩罚的实现

GAN:WGAN-GP-带有梯度惩罚的WGAN_第17张图片

4:c-p和g-p的判别器实现 

GAN:WGAN-GP-带有梯度惩罚的WGAN_第18张图片

5:生成器实现,没有区别 

GAN:WGAN-GP-带有梯度惩罚的WGAN_第19张图片

参考

1:wgan笔记

2:wgan-gp 

你可能感兴趣的:(生成对抗网络,人工智能,神经网络)