WGAN-GP:进阶的WGAN

前言

​   从GAN到WGAN,我们通过使用新的距离(推土机距离)来衡量 P g P_g Pg P r P_r Pr的到底有多接近,并且使用了权重裁剪来使得我们的生成器满足约束,但是作者也在论文中提到了这是一种暴力的手段,在现实的实验过程中,我们也发现Critic的权重会逐渐向裁剪因子靠拢,出现两极化的现象(如下图 c = 0.01),然而本身 c c c的取值很难有一个界定,我们必须不断地调节这个参数来寻找使得网络收敛到最佳,这本身是困难的,而且在实验中,我发现WGAN的收敛速度远低于原始GAN,虽然不会出现mode collapse和生成器无法训练的现象。

WGAN-GP:进阶的WGAN_第1张图片

​ 所以,WGAN 的作者提出了改进的算法WGAN-GP, 并且对之前的工作进行了总结,原文如下:

​ Generative Adversarial Networks (GANs) are powerful generative models, but suffer from training instability. The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but sometimes can still generate only poor samples or fail to converge. We find that these problems are often due to the use of weight clipping in WGAN to enforce a Lipschitz constraint on the critic, which can lead to undesired behavior. We propose an alternative to clipping weights: penalize the norm of gradient of the critic with respect to its input. Our proposed method performs better than standard WGAN and enables stable training of a wide variety of GAN architectures with almost no hyper-parameter tuning, including 101-layer ResNets and language models with continuous generators. We also achieve high quality generations on CIFAR-10 and LSUN bedrooms.

WGAN:权重裁剪好吗?

  • 无法拟合高维数据的真实分布

    ​   在下图中第二行第一列,WGAN-GP生成数据的critic值基本都分布在8个高斯附近,这和输入的样本信息(8个高斯)是一致的。但是第一行第一列中,WGAN-GP生成数据的critic值就没有如此优良特性。

WGAN-GP:进阶的WGAN_第2张图片

  • 梯度消失和梯度爆炸

    ​   如前言图中所示,实验中真实的数据分布会趋向于裁剪因子的两端,如下图中,作者也给出了不同权重裁剪因子下的梯度更新情况,如果当我们给Discriminator设置一个较大的裁剪因子时,会在Discriminator的BP过程中出现梯度爆炸的情况,相反,裁剪因子过小会出现梯度消失的现象。如果在极端情况下,如下图右上角中,会出现最后的权重被暴力裁剪到二值,这说明我们最后在优化一个非常简单的网络,这是非常棘手的。而通过改进后的权重惩罚使得梯度在BP后的各个层中,都保持比较平稳的更新。

WGAN-GP:进阶的WGAN_第3张图片

  • 收敛速度缓慢

    ​   我们用同样的优化器(RMSProp)和学习率训练一个模型,使用WGAN权重裁剪,另一个模型与Adam和更高的学习率。即使使用相同的优化器,我们的方法比权值剪裁收敛得更快,得分也更好。使用Adam可以进一步提高性能。

    WGAN-GP:进阶的WGAN_第4张图片

WGAN-GP

  • 整体算法

    WGAN-GP:进阶的WGAN_第5张图片

  • 改进点

    ​   WGAN-GP 在WGAN的基础上,去除了权重裁剪项。但是为了使得critic满足1-Lipschtiz约束,在Critic的loss中使用了梯度惩罚,而这样的方法是等同的,因为,如果一个可微的函数D满足1-Lipschtiz约束,那么它等同于该函数D相对于所有的输入的梯度的范数小于等于1。但是,我们的D在对输入(包含生成样本和真实样本)进行求梯度导时,发现这是非常困难的,因为输入的维度太高了,难以计算。

    WGAN-GP:进阶的WGAN_第6张图片
    ​   所以,作者在这一块采用了随机插值。先在生成数据的分布和真实数据的分布各取一个点,然后,连接两个点,在这条直线进行随机抽样,这样的采样效果在实验上保持较好的性能。理论上,我们在采样选取的点正好是我们 p g p_g pg向真实分布 p r p_r pr过渡的那一部分,而这一部分也对我们有影响的那一部分数据。

WGAN-GP:进阶的WGAN_第7张图片

​   除此之外,作者使用了双边惩罚(Two-sides penalty),也即修改惩罚项使得它对梯度范数大于和小于1进行,实验上双边惩罚比单边惩罚可以更快的收敛并且能达到最优。

WGAN-GP:进阶的WGAN_第8张图片

注意在Critic 中不能使用BN,推荐使用LB,优化器可以使用Adam,并且有较好的收敛性。

No critic batch normalization Most prior GAN implementations use batch normalization in both the generator and the discriminator to help stabilize training, but batch normaliz-ationchanges the form of the discriminator’s problem from mapping a single input to a single output tomapping from an entire batch of inputs to a batch of outputs [23]. Our penalized training objectiveis no longer valid in this setting, since we penalize the norm of the critic’s gradient with respectto each input independently, and not the entire batch. To resolve this, we simply omit batch normalization in the critic in our models, finding that they perform well without it. Our method works with normalization schemes which don’t introduce correlations between examples. In particular, we recommend layer normalization [3] as a drop-in replacement for batch normalization.

实验

改动点

  • Gradient penalty

    def cal_grad_penalty(critic, real_samples, fake_samples):
        """计算critic的惩罚项"""
    
        # 定义alpha
        alpha = t.Tensor(np.random.randn(real_samples.size(0), 1, 1, 1)).to(device)
    
        # 从真实数据和生成数据中的连线采样
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True).to(device)
        d_interpolates = critic(interpolates) #  输出维度:[B, 1]
        
        fake = t.autograd.Variable(t.Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = fake.to(device)
    
        # 对采样数据进行求导
        gradients = t.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]  # 返回一个元组(value, )
    
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean().to(device)
    
        return gradient_penalty
    
  • Loss

    critic

    grad_penalty = cal_grad_penalty(critic, real_imgs.data, fake_imgs.data)
    c_loss = -t.mean(c_real) + t.mean(c_fake) + opt.lambda_gp * grad_penalty
    
  
**generator**
  
  ```python
  g_loss = -t.mean(critic(fake_imgs))

可视化

  • 权值和梯度(critic)

    ​   从下图中可以观察到,在layer 1 和layer 8中均为出现梯度消失的梯度爆炸的现象。

    WGAN-GP:进阶的WGAN_第9张图片 WGAN-GP:进阶的WGAN_第10张图片
    layer 1 wight layer 1 grad
    WGAN-GP:进阶的WGAN_第11张图片 WGAN-GP:进阶的WGAN_第12张图片
    layer 8 wight layer 8 grad
  • Loss

    WGAN-GP:进阶的WGAN_第13张图片
    critic loss
    WGAN-GP:进阶的WGAN_第14张图片
    generator loss
  • 生成图片效果

iter = 50 iter =150
iter =650 iter =1100

结论

​   作者通过改变Critic的1-Lipschtiz约束, 使得GAN的收敛更加稳定,并且在其他网络结构上也表现出较好的性能,使得GAN的训练不在变得很难。本次,只是在DCGAN的网络结构上进行复现,虽然,未能生成较高质量的图片,最近在打算试一试ResNet网络,看看效果怎么样。代码后续上传。

参考

paper

WGAN-GP方法介绍

GAN-Pytorch

WGAN-GP(改进的WGAN)介绍

DCGAN、WGAN、WGAN-GP、LSGAN、BEGAN原理总结及对比

你可能感兴趣的:(GAN)