摘要:生成对抗网络( Generative Adversarial Networks, GAN)是通过对抗训练的方式来使得生成网络产生的样本服从真实数据分布。在生成对抗网络中,有两个网络进行对抗训练。一个是判别器,目标是尽量准确地判断一个样本是来自于真实数据还是由生成器产生;另一个是生成器,目标是尽量生成判别网络无法区分来源的样本 。两者交替训练,当判别器无法判断一个样本是真实数据还是生成数据时,生成器即达到收敛状态。以上是对生成对抗网络的简单描述,本文将对生成对抗网络的内在原理以及相应的优化机制进行介绍。
文章概览
概率生成模型
-
生成对抗网络
- 生成对抗网络的理论解释
- 生成对抗网络的求解过程
-
生成对抗网络的优化
- fGAN
- WGAN
-
生成对抗网络的实现
- GAN
- CGAN
- WGAN
概率生成模型
概率生成模型,简称生成模型,是指一系列用于随机生成可观测数据的模型。假设在一个连续或离散的空间中,存在一个随机向量服从一个未知的数据分布,。生成模型是根据一些可观测的样本来学习m一个参数化模型来近似未知分布,并可以用这个模型来生成一些样本,使得生成的样本和真实的样本尽可能的相似。对于一个低维空间中的简单分布而言,我们可以采用最大似然估计的方法来对进行求解。假设我们要统计全国人民的年均收入的分布情况,如果我们对每个样本都进行统计,这将消耗大量的人力物力。为了得到近似准确的收入分布情况,我们可以先假设其服从高斯分布,我们比如选取某个城市的人口收入,作为我们的观察样本结果,然后通过最大似然估计来计算上述假设中的高斯分布的参数。
由于服从高斯分布,我们将其带入即可求得最终的近似的分布情况。下面我们对上述过程进行一些拓展,我们从尽可能采样更多的数据,此时可以得到
对该式进行一些变换,可以得到
由此可以看出,最大似然估计的过程其实就是最小化分布和分布之间散度的过程。从本质上讲,所有的生成模型的问题都可以转换成最小化分布和分布之间距离的问题,散度只是其中一种度量方式。
如上所述,对于低维空间的简单分布而言,我们可以显式的假设样本服从某种类型的分布,然后通过极大似然估计来进行求解。但是对于高维空间的复杂分布而言,我们无法假设样本的分布类型,因此无法采用极大似然估计来进行求解,生成对抗网络即属于这样一类生成模型。
生成对抗网络
生成对抗网络的理论解释
在生成对抗网络中,我们假设低维空间中样本服从标准类型分布,利用神经网络可以构造一个映射函数(即生成器)将映射到真实样本空间。我们希望映射函数能够使得分布尽可能接近分布,即与之间的距离越小越好:
由于与的分布都是未知的,所以无法直接求解与之间的距离。生成对抗网络借助判别器来解决这一问题。首先我们分别从与中取样,利用取出的样本训练一个判别器:我们希望当输入样本为时,判别器会给出一个较高的分数;当输入样本为时,判别器会给出一个较低的分数。例如,我们可以将判别器的目标函数定义成以下形式(与二分类的目标函数一致,即交叉熵):
我们希望得到这样一个判别器(固定):
从本质上来看,即表示与之间的散度(具体推导参见李宏毅老师的课程),即:
因此通过构建判别器可以度量与之间的距离,所以可以表示为:
生成对抗网络的求解过程
的求解过程大致如下:
- 初始化生成器和判别器
- 迭代训练
- 固定生成器,更新判别器的参数
- 固定生成器,更新判别器的参数
对上述算法过程进行几点说明:
- 在之前的描述中,表示的是目标函数的期望,但在实际计算过程中是通过采样平均的方式来逼近其期望值。
- 判别器的训练需要重复次的原因是希望能尽可能使得接近最大值,这样才能满足"即表示与之间的散度"这一假设。
- 在更新生成器参数时,这一项可以忽略,因为固定,其相当于一个常数项。
- 在更新生成器参数时,我们使用代替,这样做的目的是加速训练过程。
生成对抗网络的优化
fGAN
通过上面的分析我们可以知道,构建生成模型需要解决的关键问题是最小化和之间的距离,这就涉及到如何对和之间的距离进行度量。在上述GAN的分析中,我们通过构建一个判别器来对和之间的距离进行度量,其中采用的目标函数为:
通过证明可知,其实度量的是和之间的散度。如果我们希望采用其他方式来衡量两个分布之间的距离,则需要对判别器的目标函数进行修改。根据论文fGAN,可以将判别器的目标函数定义成如下形式:
则可以表示为:
取不同表达式时,即表示不同的距离度量方式。
令,取,代入即可得到。
WGAN
自2014年Goodfellow提出以来,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。针对这些问题,Martin Arjovsky进行了严密的理论分析,并提出了解决方案,即WGAN(WGAN的详细解读可参考这篇博客)。
判别器越好,生成器梯度消失越严重。根据上面的分析可知,当判别器训练到最优时,衡量的是与之间的散度。问题就出在这个JS散度上,我们希望如果两个分布之间越接近它们的JS散度越小,通过优化JS散度就能将拉向。这个希望在两个分布有所重叠的时候是成立的,但是如果两个分布完全没有重叠的部分,或者它们重叠的部分可忽略,。在训练过程中,与都是通过采样得到的,在高维空间中两者之间几乎不存在交集,从而导致接近于0,生成器因此也无法得到有效训练。
最小化生成器loss函数,会等价于最小化一个不合理的距离衡量,导致两个问题,一是梯度不稳定,二是collapse mode即多样性不足。假设当前的判别器最优,经过推导可以得到下面等式:
这个等价最小化目标存在两个严重的问题。第一是它同时要最小化生成分布与真实分布的KL散度,却又要最大化两者的JS散度,一个要拉近,一个却要推远!这在直观上非常荒谬,在数值上则会导致梯度不稳定,这是后面那个JS散度项的毛病。第二,即便是前面那个正常的KL散度项也有毛病,因为KL散度不是一个对称的衡量和是有差别的。原始GAN的主要问题就出在距离度量方式上面,Martin Arjovsky提出利用Wasserstein距离来进行衡量。Wasserstein距离相比KL散度、JS散度的优越性在于,即便两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。
由以上算法可以看出,WGAN与原始的GAN在算法实现方面只有四处不同:(1)判别器最后一层去掉sigmoid;(2)生成器和判别器的loss不取log;(3)每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c;(4)不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行。
生成对抗网络的实现
本文实现了几种常见的生成对抗网络模型,包括原始GAN、CGAN、WGAN、DCGAN。开发环境为jupyter lab,所使用的深度学习框架为pytorch,并结合tensorboard动态观测生成器的训练效果,具体代码请参考我的github。
GAN
real_label = torch.ones(batch_size, 1)
fake_label = torch.zeros(batch_size, 1)
# 训练判别器
d_real = D(real_img)
d_real_loss = criterion(d_real, real_label)
z = torch.normal(0, 1, (batch_size, latent))
fake_img = G(z)
d_fake = D(fake_img)
d_fake_loss = criterion(d_fake, fake_label)
optimizer_D.zero_grad()
d_loss = d_real_loss + d_fake_loss
d_loss.backward()
optimizer_D.step()
# 训练生成器
fake_img = G(z)
d_fake = D(fake_img)
g_loss = criterion(d_fake, real_label)
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
CGAN
real_label = torch.ones(batch_size, 1)
fake_label = torch.zeros(batch_size, 1)
z = torch.normal(0, 1, (batch_size, latent))
# 训练判别器
d_real = D(real_img, label)
d_real_loss = criterion(d_real, real_label)
fake_img = G(z, label)
d_fake = D(fake_img, label)
d_fake_loss = criterion(d_fake, fake_label)
optimizer_D.zero_grad()
d_loss = (d_real_loss + d_fake_loss)
d_loss.backward()
optimizer_D.step()
# 训练生成器
fake_img = G(z, label)
d_fake = D(fake_img, label)
g_loss = criterion(d_fake, real_label)
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
WGAN
# 训练判别器
d_real = D(real_img)
#d_real_loss = criterion(d_real, real_label)
d_real_loss = d_real
z = torch.normal(0, 1, (batch_size, latent))
fake_img = G(z)
d_fake = D(fake_img)
#d_fake_loss = criterion(d_fake, fake_label)
d_fake_loss = d_fake
optimizer_D.zero_grad()
#d_loss = d_real_loss + d_fake_loss
d_loss = torch.mean(d_fake_loss) - torch.mean(d_real_loss)
d_loss.backward()
optimizer_D.step()
for p in D.parameters():
p.data.clamp_(-clip_value, clip_value)
# 训练生成器
fake_img = G(z)
d_fake = D(fake_img)
#g_loss = criterion(d_fake, real_label)
g_loss = - torch.mean(d_fake)
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()