一文看懂生成对抗网络

(关注'AI新视野'公众号,发送‘资料’二字,免费获取50G人工智能视频教程!)

image

人工智能目前的核心目标应该是赋予机器自主理解我们所在世界的能力。对于人类来说,我们对这个世界所了解的知识可能很快就会忘记,比如我们所处的三维环境中,物体能够交互,移动,碰撞;什么动物会飞,什么动物吃草等等。这些巨大的并且不断扩大的信息现在是很容易被机器获取的,问题的关键是怎么设计模型和算法让机器更好的去分析和理解这些数据中所蕴含的宝藏。

Generative models(生成模型)现在被认为是能够实现这一目标的最有前景的方法之一。Generative models通过输入一大堆特定领域的数据进行训练(比如图像,句子,声音等)来使得模型能够产生和输入数据相似的输出。这一直觉的背后可以由下面名言阐述。

“What I cannot create, I do not understand.” —Richard Feynman

生成模型由一个参数数量比训练数据少的多神经网络构成,所以生成模型为了能够产生和训练数据相似的输出就会迫使自己去发现数据中内在的本质内容。训练Generative models的方法有几种,在这里我们主要阐述其中的Adversarial Training(对抗训练)方法。

Adversarial Training

上文说过Adversarial Training是训练生成模型的一种方法。为了训练生成模型,Adversarial Training提出一种Discriminative Model(判别模型)来和生成模型产生对抗,下面来说说Generative modelsDiscriminative Model 是如何相互作用的。

  • 生成模型的目标是模仿输入训练数据, 通过输入一个随机噪声来产生和训练数据相似的样本;
  • 判别模型的目标就是判断生成模型产生的样本和真实的输入样本之间的相似性。

其中生成模型和判别模型合起来的框架被称为GAN网络。通过下图我们来理清判别模型和生成模型之间的输入输出关系:生成模型通过输入随机噪声 产生合成样本;而判别模型通过分别输入真实的训练数据和生成模型的训练数据来判断输入的数据是否真实。

[图片上传失败...(image-6e8d00-1566991139458)]

描述了GAN的网络结构,但它的优化目标是什么?怎么就可以通过训练使得生成模型能够产生和真实数据相似的输出?优化的目标其实很简单,简单来说就是:

  • 判别模型努力的想把真实的数据预测为1,把生成的数据预测为0
  • 而生成模型的奋斗目标则为‘我’要尽力的让判别模型对‘我’生成的数据预测为1,让判别模型分不清‘我’产生的数据和真实数据之间的区别,从而达到‘以假乱真’的效果。

下面用形式化说明下如果训练GAN网络, 先定义一些参数:

参数 含义
输入随机噪声 的分布
未知的输入样本的数据分布
生成模型的输出样本的数据分布,GAN的目标就是要

训练判别模型 的目标:

  1. 对每一个输入数据 要使得 最大;
  2. 对每一个输入数据 要使得 最小。

训练生成模型 的目标是来产生样本来欺骗判别模型 , 因此目标为最大化 ,也就是把生成模型的输出输入到判别模型,然后要让判别模型预测其为真实数据。同时,最大化 等同于最小化 ,因为 的输出是介于0到1之间的,真实数据努力预测为1,否则为0。

所以把生成模型和判别模型的训练目标结合起来,就得到了GAN的优化目标:

总结一下上面的内容,GAN启发自博弈论中的二人零和博弈,在二人零和博弈中,两位博弈方的利益之和为零或一个常数,即一方有所得,另一方必有所失。GAN模型中的两位博弈方分别由生成模型和判别模型充当。生成模型G捕捉样本数据的分布,判别模型是一个二分类器,估计一个样本来自于训练数据(而非生成数据)的概率。G和D一般都是非线性映射函数,例如多层感知机、卷积神经网络等。生成模型的输入是一些服从某一简单分布(例如高斯分布)的随机噪声z,输出是与训练图像相同尺寸的生成图像。向判别模型D输入生成样本,对于D来说期望输出低概率(判断为生成样本),对于生成模型G来说要尽量欺骗D,使判别模型输出高概率(误判为真实样本),从而形成竞争与对抗。

GAN实现

一个简单的一维数据GAN网络的tensorflow实现:genadv_tutorial
其一维训练数据分布如下所示,是一个均值-1, 的正态分布。

image

我们结合代码和上面的理论内容来分析下GAN的具体实现,判别模型的优化目标为最大化下式,其中 表示判别真实数据, 表示对生成的数据进行判别, 其中 和 是共享参数的, 也就是说是同一个判别模型。

对应的python代码如下:

batch=tf.Variable(0)
obj_d=tf.reduce_mean(tf.log(D1)+tf.log(1-D2))
opt_d=tf.train.GradientDescentOptimizer(0.01)
              .minimize(1-obj_d,global_step=batch,var_list=theta_d)

为了优化 , 我们想要最大化 (成功欺骗 ),因此 的优化函数为:

对应的python代码:

batch=tf.Variable(0)
obj_g=tf.reduce_mean(tf.log(D2))
opt_g=tf.train.GradientDescentOptimizer(0.01)
              .minimize(1-obj_g,global_step=batch,var_list=theta_g)

定义好优化目标后,下面就是训练的主要代码了:

# Algorithm 1, GoodFellow et al. 2014
for i in range(TRAIN_ITERS):
    x= np.random.normal(mu,sigma,M) # sample minibatch from p_data
    z= np.random.random(M)  # sample minibatch from noise prior
    sess.run(opt_d, {x_node: x, z_node: z}) # update discriminator D
    z= np.random.random(M) # sample noise prior
    sess.run(opt_g, {z_node: z}) # update generator G

下面是实验的结果,左图是训练之间的数据,可以看到生成数据的分布和训练数据相差甚远;右图是训练后的数据分析,生成数据和训练数据的分布接近了很多,且此时判别模型的输出分布在0.5左右,说明生成模型顺利的欺骗到判别模型。

image

DCGAN

GAN的一个改进模型就是DCGAN。这个网络的生成模型的输入为一个100个符合均匀分布的随机数(通常被称为code),然后产生输出为64x64x3的输出图像(下图中 ), 当code逐渐递增时,生成模型输出的图像也逐渐变化。下图中的生产模型主要由反卷积层构成, 判别模型就由简单的卷积层组成,最后输出一个判断输入图片是否为真实数据的概率 。

image

下图为随着迭代次数,DCGAN产生图像的变化过程。

[图片上传失败...(image-5b9e2f-1566991139458)]

训练好网络之后,其中的生成模型和判别模型都有其他的作用。一个训练好的判别模型能够用来对数据提取特征然后进行分类任务。通过输入随机向量生成模型可以产生一些非常有意思的的图片,如下图所示,当输入空间平滑变化时,输出的图片也在平滑转变。

image

GAN的训练及其改进

上面使用GAN产生的图像虽然效果不错,但其实GAN网络的训练过程是非常不稳定的。
通常在实际训练GAN中所碰到的一个问题就是判别模型的收敛速度要比生成模型的收敛速度要快很多,通常的做法就是让生成模型多训练几次来赶上生成模型,但是存在的一个问题就是通常生成模型和判别模型的训练是相辅相成的,理想的状态是让生成模型和判别模型在每次的训练过程中同时变得更好。判别模型理想的minimum loss应该为0.5,这样才说明判别模型分不出是真实数据还是生成模型产生的数据。

Improved GANs

Improved techniques for training GANs这篇文章提出了很多改进GANs训练的方法,其中提出一个想法叫Feature matching,之前判别模型只判别输入数据是来自真实数据还是生成模型。现在为判别模型提出了一个新的目标函数来判别生成模型产生图像的统计信息是否和真实数据的相似。让 表示判别模型中间层的输出, 新的目标函数被定义为 , 其实就是要求真实图像和合成图像在判别模型中间层的距离要最小。这样可以防止生成模型在当前判别模型上过拟合。

InfoGAN

到这可能有些同学会想到,我要是想通过GAN产生我想要的特定属性的图片改怎么办?普通的GAN输入的是随机的噪声,输出也是与之对应的随机图片,我们并不能控制输出噪声和输出图片的对应关系。这样在训练的过程中也会倒置生成模型倾向于产生更容易欺骗判别模型的某一类特定图片,而不是更好的去学习训练数据的分布,这样对模型的训练肯定是不好的。InfoGAN的提出就是为了解决这一问题,通过对输入噪声添加一些类别信息以及控制图像特征(如mnist数字的角度和厚度)的隐含变量来使得生成模型的输入不在是随机噪声。虽然现在输入不再是随机噪声,但是生成模型可能会忽略这些输入的额外信息还是把输入当成和输出无关的噪声,所以需要定义一个生成模型输入输出的互信息,互信息越高,说明输入输出的关联越大。

下面三张图片展示了通过分别控制输入噪声的类别信息,数字角度信息,数字笔画厚度信息产生指定输出的图片,可以看出InfoGAN产生图片的效果还是很好的。

image

image

image

其他应用

GAN网络还有很多其他的有趣应用,比如下图所示的根据一句话来产生对应的图片,可能大家都有了解karpathy大神的看图说话, 但是GAN有能力把这个过程给反过来。

image

还有下面这个“图像补全”, 根据图像剩余的信息来匹配最佳的补全内容。

image

还有下面这个图像增强的例子,有点去马赛克的意思,效果还是挺不错的:-D。

image

总结

颜乐存说过,2016年深度学习领域最让他兴奋技术莫过于对抗学习。对抗学习确实是解决非监督学习的一个有效方法,而无监督学习一直都是人工智能领域研究者所孜孜追求的“终极目标”之一。

参考

Generative Adversarial Networks

Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

Improved Techniques for Training GANs

InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

你可能感兴趣的:(一文看懂生成对抗网络)