GAN学习

开始学习GAN生成对抗网络相关知识,将要点和心得总结于此。

文章目录

    • 起源
    • 主要思想
    • 特点
    • 训练技巧
    • 应用场景
    • 其他
    • GAN及其改进
      • GAN
      • DCGAN
      • WGAN和WGAN-gp
      • LSGAN
      • cGAN
      • pix2pix
      • CycleGAN

起源

GAN,全名 Generative Adversarial Networks,即生成式对抗网络,是2014年Lan Goodfellow的论文《Generative Adversarial Nets》中提出的一种新的方法,是一种无监督学习模型,通过学习样本分布让算法生成类似分布的图片。

主要思想

GAN的主要灵感来源于博弈论中零和博弈的思想。通过生成网络G(Generator)和判别网络D(Discriminator)不断博弈,进而使G学习到数据的分布,根据一定的映射规则从一段随机数中生成逼真的图像。
G是一个生成网络,输入为一个随机的噪声,输出为的生成图像。
D是一个判别网络,输入为一张图片,输出为真实图片的概率,范围为0-1。
训练过程中,G的目标就是尽量生成真实的图片去欺骗D。而D的目标就是尽量辨别出G生成的假图像。这样,G和D构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点
G的梯度更新信息来自判别器D,而不是来自数据样本。

特点

GAN 的优点:

  1. GAN是一种生成式模型,相比较其他生成模型(玻尔兹曼机和GSNs)只用到了反向传播,而不需要复杂的马尔科夫链。
  2. 相比其他所有模型, GAN可以产生更加清晰,真实的样本
  3. GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域
  4. 相比于变分自编码器, GANs没有引入任何决定性偏置( deterministic bias),变分方法引入决定性偏置,因为他们优化对数似然的下界,而不是似然度本身,这看起来导致了VAEs生成的实例比GANs更模糊
  5. 相比VAE, GANs没有变分下界,如果鉴别器训练良好,那么生成器可以完美的学习到训练样本的分布.换句话说,GANs是渐进一致的,但是VAE是有偏差的
  6. GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,不管三七二十一,只要有一个的基准,直接上判别器,剩下的就交给对抗训练了。

GAN的缺点:

  1. 训练GAN需要达到纳什均衡,有时候可以用梯度下降法做到,有时候做不到.我们还没有找到很好的达到纳什均衡的方法,所以训练GAN相比VAE或者PixelRNN是不稳定的,但我认为在实践中它还是比训练玻尔兹曼机稳定的多
  2. GAN不适合处理离散形式的数据,比如文本
  3. GAN存在训练不稳定、梯度消失、模式崩溃(model collapse)的问题(目前已解决)

模式崩溃(model collapse):生成的数据多样性不足。原GAN论文中提出的loss函数经过变换后为KL散度项,KL散度不具有对称性,即KL(A|B)≠KL(B|A)。GAN学习_第1张图片
故在优化过程中loss对于两种错误的惩罚不同,第一种错误表示样本中包含的数据没有被生成,即缺乏多样性,惩罚微小;第二种错误表示生成的数据在样本中不存在 ,即缺乏准确性,惩罚巨大。由于不平衡的惩罚导致生成器宁可多生成一些重复但是正确的样本,也不愿意去生成多样性的样本,因为那样一不小心就会产生第二种错误。这种现象就是大家常说的collapse mode。

训练技巧

  1. 输入规范化到(-1,1)之间,最后一层的激活函数使用tanh。
  2. 使用wassertein GAN的损失函数。
  3. 使用batch norm 或者instance norm 或者weight norm。
  4. 避免使用Relu和pooling层,可以使用Leaky-Relu激活函数以减少稀疏梯度的可能性。
  5. 梯度下降算法选用Adam,学习率初始参考值1e-4。
  6. 给判别网络D输入端增加高斯噪声(正则化)。

应用场景

  1. GAN本身是一种生成式模型,最常见的是图片生成。
  2. GAN在分类领域也占有一席之地。替换判别器为一个分类器,做多分类任务,生成器辅助分类器训练。
  3. GAN可以和强化学习结合,例如seq-GAN。
  4. GAN在图像风格迁移,图像降噪修复,图像超分辨率都有比较好的结果,详见pix-2-pix GAN 和cycle GAN。
  5. 目前也有研究者将GAN用在对抗性攻击上,就是训练GAN生成对抗文本,有针对或者无针对的欺骗分类器或者检测系统。

GAN应用汇总
常见GAN变体及实现

其他

为什么GAN中的优化器不常用SGD

  1. SGD容易震荡,容易使GAN训练不稳定,
  2. GAN的目的是在高维非凸的参数空间中找到纳什均衡点,GAN的纳什均衡点是一个鞍点,但是SGD只会找到局部极小值,因为SGD解决的是一个寻找最小值的问题,GAN是一个博弈问题。

为什么GAN不适合处理文本数据

  1. 文本数据相比较图片数据来说是离散的,因为对于文本来说,通常需要将一个词映射为一个高维的向量,最终预测的输出是一个one-hot向量,假设softmax的输出是(0.2, 0.3, 0.1,0.2,0.15,0.05)那么变为onehot是(0,1,0,0,0,0),如果softmax输出是(0.2, 0.25, 0.2, 0.1,0.15,0.1 ),one-hot仍然是(0, 1, 0, 0, 0, 0),所以对于生成器来说,G输出了不同的结果但是D给出了同样的判别结果,并不能将梯度更新信息很好的传递到G中去,所以D最终输出的判别没有意义。
  2. 另外就是GAN的损失函数是JS散度,JS散度不适合衡量不想交分布之间的距离。(WGAN虽然使用wassertein距离代替了JS散度,但是在生成文本上能力还是有限,GAN在生成文本上的应用有seq-GAN,和强化学习结合的产物)

GAN及其改进

GAN

GAN学习_第2张图片

如上图所示,生成对抗网络会训练并更新判别分布(即 D,蓝色的虚线),更新判别器后就能将数据真实分布(黑点组成的线)从生成分布 P_g(G)(绿色实线)中判别出来。下方的水平线代表采样域 Z,其中等距线表示 Z 中的样本为均匀分布,上方的水平线代表真实数据 X 中的一部分。向上的箭头表示映射 x=G(z) 如何对噪声样本(均匀采样)施加一个不均匀的分布 P_g。(a)考虑在收敛点附近的对抗训练:P_g 和 P_data 已经十分相似,D 是一个局部准确的分类器。(b)在算法内部循环中训练 D 以从数据中判别出真实样本,该循环最终会收敛到 D(x)=P_data(x)/(P_data(x)+P_g(x))。(c)随后固定判别器并训练生成器,在更新 G 之后,D 的梯度会引导 G(z)流向更可能被 D 分类为真实数据的方向。(d)经过若干次训练后,如果 G 和 D 有足够的复杂度,那么它们就会到达一个均衡点。这个时候 P_g=P_data,即生成器的概率密度函数等于真实数据的概率密度函数,也即生成的数据和真实数据是一样的。在均衡点上 D 和 G 都不能得到进一步提升,并且判别器无法判断数据到底是来自真实样本还是伪造的数据,即 D(x)= 1/2。

GAN学习_第3张图片
具体算法实现
GAN学习_第4张图片
参考资料:
机器之心GitHub项目:GAN完整理论推导与实现
KL散度、JS散度以及交叉熵对比
Generative Adversarial Nets(译)

DCGAN

将GAN与CNN相结合,将原论文中的MLP网络更换为CNN网络,改善了对图片的生成与判别效果。
主要贡献是:
为GAN的训练提供了一个很好的网络拓扑结构。
表明生成的特征具有向量的计算特性。
使用的CNN结构如下
GAN学习_第5张图片
判别器几乎是和生成器对称的。整个网络没有pooling层和上采样层,实际上是使用了带步长(fractional-strided)的卷积代替了上采样,以增加训练的稳定性。
DCGAN能改进GAN训练稳定的原因主要有:

  • 使用步长卷积代替上采样层,卷积在提取图像特征上具有很好的作用,并且使用卷积代替全连接层。
  • 生成器G和判别器D中几乎每一层都使用batchnorm层,将特征层的输出归一化到一起,加速了训练,提升了训练的稳定性。(生成器的最后一层和判别器的第一层不加batchnorm)
  • 在判别器中使用leaky-ReLU激活函数,而不是ReLU,防止梯度稀疏,生成器中仍然采用ReLU,但是输出层采用tanh
  • 使用adam优化器训练,学习率推荐为0.0002

参考资料:
DCGAN、WGAN、WGAN-GP、LSGAN、BEGAN原理总结及对比
DCGAN在TF上实现

WGAN和WGAN-gp

Wasserstein距离

WGAN
GAN定义的损失函数具有一定的缺陷,具体表现为:训练不稳定,存在collapse mode情况,并且在训练目标不明确的问题。针对这些缺陷,WGAN提出了用Wasserstein距离替代JS散度,用于计算真实样本与生成样本之间的差异。将判别器的作用从判断样本是否为真转化为计算真实样本与生成样本之间Wasserstein距离,通过不断减少这个距离可以优化生成器。Wasserstein距离也是一个明确的指示标志,表面了当前模型训练情况,Wasserstein距离越小说明生成样本与真实样本越接近。此外,由于Wasserstein距离具有对称性,还(基本上)解决了collapse mode情况。
具体算法如下
GAN学习_第6张图片
WGAN-gp
WGAN的梯度裁剪的方法具有一定的弊端,会产生下图左侧情况
GAN学习_第7张图片
容易引起梯度消失或者梯度爆炸情况,因此使用惩罚系数代替梯度裁剪。具体方法是对损失函数加入梯度惩罚项,当梯度大于1时进行惩罚,保证Lipschitz连续性限制。该惩罚项的梯度位置在真实样本与生成样本中的连线中随机采样某一点,然后计算D(x)并求梯度,最后计算与1的距离。具体形式为:
WGAN-gp
令人拍案叫绝的Wasserstein GAN
W-GAN系 (Wasserstein GAN、 Improved WGAN)
WGAN在TF上实现

LSGAN

将GAN的损失函数更换为最小二乘损失函数,其目的与WGAN类似,即JS散度具有不对称性和范围(0-1),因此不能拉近真实分布和生成分布之间的距离,使用最小二乘可以将图像的分布尽可能的接近决策边界。LSGAN损失函数定义如下:
m i n D J ( D ) = m i n D 1 2 E x ∼ P r [ D ( x ) − a ] 2 + 1 2 E z ∼ P z [ D ( G ( x ) ) − b ] 2 m i n G J ( G ) = m i n G 1 2 E z ∼ P z [ D ( G ( x ) ) − c ] 2 \underset{D}{min}J(D)=\underset{D}{min}\frac{1}{2}E_{x\sim P_{r}}[D(x)-a]^{2}+\frac{1}{2}E_{z\sim P_{z}}[D(G(x))-b]^{2}\\ \underset{G}{min}J(G)=\underset{G}{min}\frac{1}{2}E_{z\sim P_{z}}[D(G(x))-c]^{2} DminJ(D)=Dmin21ExPr[D(x)a]2+21EzPz[D(G(x))b]2GminJ(G)=Gmin21EzPz[D(G(x))c]2作者设置a=c=1,b=0。
参考资料:GAN——LSGANs(最小二乘GAN)

cGAN

GAN的训练为无监督训练,生成的图片具有随机性。为了得到可控的结果,在生成器G与判别器D中均加入给定条件y。这里标签与生成图片进行堆叠送入判别器。
GAN学习_第8张图片
目标函数如下:
CGAN损失函数

参考资料:CGAN论文笔记
详解GAN代码之搭建并详解CGAN代码

pix2pix

在CGAN基础上的改进。为了使生成的图片更接近训练图片,加入和L1损失,其损失函数定义如下:
pix2pix
pix2pix
pix2pix损失函数
网络结构使用了U-Net结构,能够减少Encoder-Decoder过程中对于原始信息的丢失,其原理如下:GAN学习_第9张图片
将判别器改变为局部判别器(Patch-D),即将图像分为固定大小的部分送入判别器。
优点:

  1. 输入变小,计算量小,训练速度快。
  2. 生成器G是全卷积结构,对图像尺度没有限制;Patch-D对图像大小也没有限制,这样整个网络对图像大小没有限制,增加了框架的扩展性。

参考资料:Pix2Pix-基于GAN的图像翻译
Image-to-Image Translation in Tensorflow

CycleGAN

你可能感兴趣的:(深度学习相关)