生成式对抗网络(GAN,Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。学界大牛Yann Lecun 曾说,令他最激动的深度学习进展就是生成式对抗网络。最近正好看了这方面的一些介绍和论文,并用Tensorflow实现了两个小例子,所以写了这篇文章来作个简单的小结。
本文主要分为四个部分:
1.原始的GAN原理介绍;
2.GAN衍生的CGAN、DCGAN、WGAN、LSGAN、BEGAN等原理;
3.应用介绍;
4.tensorflow实现GAN小例子;
一、GAN原理介绍
学习GAN的第一篇论文当然由是 Ian Goodfellow 于2014年发表的 Generative Adversarial Networks(论文下载链接arxiv:[https://arxiv.org/abs/1406.2661] ),这篇论文可谓这个领域的开山之作。
GAN的基本原理其实并不复杂,模型通过框架中两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:
G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。
D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
在训练过程中,生成网络G的目标就是尽可能生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来。这样,G和D构成了一个动态的“博弈过程”。
最后博弈的结果,在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,即达到了一个纳什均衡,因此D(G(z)) = 0.5。此时,模型的收敛目标是生成器能够从随机噪声生成真实数据。
这样我们的目的就达成了:我们得到了一个生成式的模型G,它可以用来生成图片。
以上只是大致说了一下GAN的核心原理,如何用数学语言描述呢?这里直接摘录论文里的公式:
简单分析一下这个公式:
整个式子由两项构成。x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。
D(x)表示D网络判断真实图片是否真实的概率(因为x就是真实的,所以对于D来说,这个值越接近1越好)。而D(G(z))是D网络判断G生成的图片的是否真实的概率。
G的目的:上面提到过,D(G(z))是D网络判断G生成的图片是否真实的概率,G应该希望自己生成的图片“越接近真实越好”。也就是说,G希望D(G(z))尽可能得大,这时V(D, G)会变小。因此我们看到式子的最前面的记号是min_G。
D的目的:D的能力越强,D(x)应该越大,D(G(x))应该越小。这时V(D,G)会变大。因此式子对于D来说是求最大(max_D)
下面这幅图片很好地描述了这个过程:
那么如何用随机梯度下降法训练D和G?论文中也给出了算法:
这里黄色框圈出的部分是我们要特别注意的。第一步我们训练D,D是希望V(G, D)越大越好,所以是加上梯度(ascending)。第二步训练G时,V(G, D)越小越好,所以是减去梯度(descending)。整个训练过程交替进行。
二、GAN衍生的CGAN、DCGAN、WGAN、LSGAN、BEGAN等原理
(1)CGAN:
在原始GAN中,目的是使得生成器能够从随机噪声中生成真实数据,而CGAN(论文下载链接arxiv:https://arxiv.org/pdf/1411.1784.pdf)则更近一层,即给GAN加上条件,指导数据的生成过程,使得生成具有特定性质的样本。以生成MNIST数据集的图像样本来说,原始GAN得到的生成器可以由随机向量生成一张含有数字的图像样本,其中数字可能是0~9中的任意一个,而CGAN则是在生成器输入时添加一个条件y,使得可以生成符合预期数字的图像样本,如生成含有数字1的图像。如图Figure1所示。价值函数变化如下:
(2)DCGAN:
DCGAN(论文下载链接arxiv:https://arxiv.org/abs/1511.06434 )是应用比较广泛的改进结构,基本采用卷积层替代了原始的全连接层,其中在生成器中采用带步长的卷积代替了上采样,极大地提升了GAN训练时的稳定性及生成结果质量。如图所示
GAN的主要问题是训练过程不稳定,而DCGAN改进了其稳定性,原因在于:
(1)几乎每层都使用batchnorm层,将特征层的输出归一化到一起,加速训练,提升训练的稳定性;
(2)判别器中使用Leaky ReLU,防止梯度过度稀疏,生成器则仍然采用 ReLU,但最后输出层采用Tanh;
(3)使用Adam优化器训练,且最佳学习率为0.0002;
(4)使用带步长卷积替代上采样层,卷积在提取图像特征上有较好的作用,并且使用卷积代替全连接层。
(3)WGAN:
为了使得GAN的训练更加稳定,与DCGAN不同的是,WGAN(论文下载链接arxiv:https://arxiv.org/pdf/1701.07875.pdf)主要从损失函数的角度进行改进:
A)判别器最后一层去掉Sigmoid;
B)生成器和判别器的loss不取Log;
C)对更新后的权重强制clip,如[-0.01,0.01],以满足连续性条件;
D)推荐SGD、RMSProp等优化器,不要采用含有动量的优化算法,如Adam。
原始的GAN存在的问题有:判别器越好,生成器梯度消失越严重,生成器loss降不下去;判别器不好,生成器梯度不准,训练不稳定,只有判别器训练得不好不坏才行,但这个尺度很难把握,甚至同一轮训练的不同阶段该尺度都不一样,所以GAN才难以训练。最小化生成器loss函数,会等价于最小化一个不合理的距离度量,使得最小化生成分布与真实分布的KL散度的同时又要最大化两者的JS散度,导致梯度不稳定,同时也会使得生成器宁可多生成一些重复但较为“安全”的样本,也不愿意生成多样性的样本,从而导致模式崩溃,即多样性不足。
下图所示为标准GAN与WGAN对真实样本分布和生成样本分布判别的差异,标注GAN会出现梯度消失的情况,而WGAN则有较好的线性梯度。
WGAN的贡献主要在于从理论上给出了GAN训练不稳定的原因,即交叉熵不适合衡量具有不相交部分的数据之间的距离,转而使用Wassertein距离去衡量生成数据与真实数据之间的距离,理论上解决了训练不稳定的问题;解决了模式崩溃问题,生成结果更加多样;对GAN的训练提供了一个指标,可以采用此指标来衡量GAN训练的好坏,而不像之前那样盲目训练。
(4)LSGAN:
LSGAN(论文下载链接arxiv:https://arxiv.org/pdf/1611.04076.pdf)的主要目的也是采用最小二乘损失函数代替了GAN目标函数的交叉熵,从而解决了GAN训练不稳定和生成图像质量差、多样性不足的问题。
其中a,b,c属于超参数,a,b分别表示生成图片和真实图片的标记,c是生成器为了使判别器认为生成图片为真实样本而定的值,这里设定a=0,b=c=1。
论文主要回答了两个问题:为什么最小二乘损失可以提高生成图片质量;为什么最小二乘损失可以使得GAN训练更稳定。对于第一个问题,论文认为交叉熵作为损失函数,会使得生成器不再优化那些被判别器识别为真实图片的生成图片,即使这些生成图片距离判别器的决策边界仍较远。原因在于生成器只需要完成混淆判别器的目标生成图片即可,而最小二乘损失则在混淆判别器的前提下还得让生成器把距离决策边界较远的生成图片拉向决策边界。对于第二个问题,论文认为Sigmoid交叉熵损失容易达到饱和状态,即梯度为0,而最小二乘只在一个点达到饱和。
(5)BEGAN:
谷歌提出一种新的简单强大的GAN,这是一种新的评价生成器生成质量的方法,不需要太多的训练技巧即可实现快速稳定的训练。以往的GAN及其变体是希望生成器生成的数据分布尽可能地接近真实数据分布,因此研究者们设计了各种损失函数,而BEGAN则不采用这种估计概率分布的方法,即不直接去估计生成分布Pg和真实分布Pdata的差距,而是估计分布的误差分布差距,只要分布之间的误差分布相近,也可以认为这些分布是相近的。
BEGAN主要有3个贡献:
(1)提出了一种新的简单强大的GAN网络结构,使用标准的训练方式也能快速稳定的收敛。
(2)对于生成器和判别器的平衡提出了一种均衡的概念,提供了一个超参数,这个超参数用于平衡图像的多样性和生成质量。
(3)受WGAN启发,提出了一种收敛程度估计。
BEGAN采用自编码器作为判别器;在生成器的设计上,使用Wasserstein距离衍生出的损失去匹配自编码器的损失分布,这是通过传统的GAN目标加上一个用来平衡判别器和生成器的平衡项来实现的;还提出了一个衡量生成样本多样性的超参数Y:生成样本损失的期望与真实样本损失的期望值之比。Y值较低时会导致图像多样性较差,因为此时判别器过于关注对真实图像的自编码。
三、应用介绍
自诞生以来,GAN引起了众多学者的注意,成为近几年的热点研究领域,原因在于其代表的无监督学习范式有着广阔的前景。目前生成式对抗网络已经有了许多成功的应用,如图片生成、文字到图片的合成、图像超分辨率重建、图像修复和纹理合成、风格迁移等,此外,GAN在目标检测、行人识别、重定位等领域也有辅助作用。
因为自己也是刚刚入门学习中,了解的还很片面,感兴趣的小伙伴可以百度搜索进一步深入学习。当然,也欢迎扫描文章末尾的二维码关注公众号“StrongerTang”做交流和分享。
四、tensorflow实现GAN小例子
学习完以上内容以后,本人参考网上分享的代码用tensorflow实现了Mnist数据图像生成等两个简单小例子,考虑到本文篇幅已经较多,故打算另外单独写一篇文章分享。感兴趣的朋友可以后续关注一下。
五、小结
GAN自诞生以来便成为了研究热点,无论是原理还是应用都取得了极大丰富和发展,并且仍在不断向前发展中。因而,本文只是冰山一角的分享,还有很多内容自己也还没有学习,更谈不上分享。只希望自己能够克服爱玩的缺点,多花点心思在学习上,能够和大家一起学习到更多的知识。
最后这里分享一个对GAN总结比较好的链接,感兴趣的小伙伴可以进一步学习。(https://github.com/savan77/The-GAN-World)
注:本文参考众多论文原文及其它网络资料,在此表示感谢。