生成对抗网络原理与实战

原创:李孟启

1. 前言

在生成对抗网络(Generative Adversarial Network,简称 GAN)发明之前,变分自编码器被认为是理论完备,实现简单,使用神经网络训练起来很稳定,生成的图片逼近度也较高,但是人眼还是可以很轻易地分辨出真实图片与机器生成的图片。

2014 年,Université de Montréal 大学 Yoshua Bengio(2019 年图灵奖获得者)的学生 Ian Goodfellow 提出了生成对抗网络 GAN,从而开辟了深度学习最炙手可热的研究方向之一。从 2014 年到 2019 年,GAN 的研究稳步推进,研究捷报频传,最新的 GAN 算法在图片生成上的效果甚至达到了肉眼难辨的程度,着实令人振奋。由于 GAN 的发明,Ian Goodfellow 荣获 GAN 之父称号,并获得 2017 年麻省理工科技评论颁发的 35 Innovators Under 35 奖项。图 1 展示了从 2014 年到 2018 年,GAN 模型取得了图书生成的效果,可以看到不管是图片大小,还是图片逼真度,都有了巨大的提升。

生成对抗网络原理与实战_第1张图片

图1 GAN模型2014~2018年的图片生成效果

2. 博弈学实例

接下来,我们将从生活中博弈学习的实例出发,一步步引出 GAN 算法的设计思想和模型结构。我们用一个漫画家的成长轨迹来形象介绍生成对抗网络的思想。考虑一对双胞胎兄弟,分别称为老二 G 和老大 D,G 学习如何绘制漫画,D 学习如何鉴赏画作。还在娃娃时代的两兄弟,尚且只学会了如何使用画笔和纸张,G 绘制了一张不明所以的画作,如图2(a)所示,由于此时 D 鉴别能力不高,觉得 G 的作品还行,但是人物主体不够鲜明。在 D 的指引和鼓励下,G 开始尝试学习如何绘制主体轮廓和使用简单的色彩搭配。一年后,G 提升了绘画的基本功,D 也通过分析名作和初学者 G 的作品,初步掌握了鉴别作品的能力。此时 D 觉得 G 的作品人物主体有了,如图 2(b),但是色彩的运用还不够成熟。数年后,G 的绘画基本功已经很扎实了,可以轻松绘制出主体鲜明、颜色搭配合适和逼真度较高的画作,如图 2©,但是 D 同样通过观察 G 和其它名作的差别,提升了画作鉴别能力,觉得 G 的画作技艺已经趋于成熟,但是对生活的观察尚且不够,作品没有传达神情且部分细节不够完美。又过了数年,G 的绘画功力达到了炉火纯青的地步,绘制的作品细节完美、风格迥异、惟妙惟肖,宛如大师级水准,如图 2(d),即便此时的D 鉴别功力也相当出色,亦很难将 G 和其他大师级的作品区分开来。

上述画家的成长历程其实是一个生活中普遍存在的学习过程,通过双方的博弈学习,相互提高,最终达到一个平衡点。GAN 网络借鉴了博弈学习的思想,分别设立了两个子网络:负责生成样本的生成器 G 和负责鉴别真伪的鉴别器 D。类比到画家的例子,生成器 G就是老二,鉴别器 D 就是老大。鉴别器 D 通过观察真实的样本和生成器 G 产生的样本之间的区别,学会如何鉴别真假,其中真实的样本为真,生成器 G 产生的样本为假。而生成器 G 同样也在学习,它希望产生的样本能够获得鉴别器 D 的认可,即在鉴别器 D 中鉴别为真,因此生成器 G 通过优化自身的参数,尝试使得自己产生的样本在鉴别器 D 中判别为真。生成器 G 和鉴别器 D 相互博弈,共同提升,直至达到平衡点。此时生成器 G 生成的样本非常逼真,使得鉴别器 D 真假难分。

生成对抗网络原理与实战_第2张图片
图2 画家的成长轨迹示意图

在原始的 GAN 论文中,Ian Goodfellow 使用了另一个形象的比喻来介绍 GAN 模型:生成器网络 G 的功能就是产生一系列非常逼真的假钞试图欺骗鉴别器 D,而鉴别器 D 通过学习真钞和生成器 G 生成的假钞来掌握钞票的鉴别方法。这两个网络在相互博弈的过程中间同步提升,直到生成器 G 产生的假钞非常的逼真,连鉴别器 D 都真假难辨。

这种博弈学习的思想使得 GAN 的网络结构和训练过程与之前的网络模型略有不同,下面我们来详细介绍 GAN 的网络结构和算法原理。

3. GAN原理

一个典型的生成对抗网络模型大概如图3所示。

生成对抗网络原理与实战_第3张图片
图3 对抗生成网络模型

我们先来理解下GAN的两个模型要做什么。首先判别模型(鉴别器),就是图3中右半部分的网络,直观来看就是一个简单的神经网络结构,输入就是一副图像,输出就是一个概率值(其实是个二分类问题),用于判断真假使用(概率值大于0.5那就是真,小于0.5那就是假),真假也不过是人们定义的概率而已。其次是生成模型,生成模型要做什么呢,同样也可以看成是一个神经网络模型,输入是一组随机数Z,输出是一个图像,不再是一个数值而已。从图3中可以看到,会存在两个数据集,一个是真实数据集,另一个是假的数据集,那这个数据集就是有生成网络造出来的数据集。根据图3我们再来理解一下GAN的目标是要做什么:

判别网络的目的:就是能判别出来输入的一张图它是来自真实样本集还是假样本集。假如输入的是真样本,网络输出就接近1,输入的是假样本,网络输出接近0,那么很完美,达到了很好判别的目的。

生成网络的目的:生成网络是造样本的,它的目的就是使得自己造样本的能力尽可能强,强到什么程度呢,你判别网络没法判断我是真样本还是假样本。因此辨别网络的作用就是对噪音生成的数据辨别他为假的,对真实的数据辨别他为真的。而生成网络的损失函数就是使得对于噪音数据,经过辨别网络之后的辨别结果是真的,这样就能达到生成真实图像的目的。这里会感觉比较饶,这也是生成对抗网络的难点所在,理解了这点,整个生成对抗网络模型也就理解了。

4.DCGAN实战

这里我们拿DCGAN来举例子,DCGAN是GAN的一个变体,DCGAN就是将CNN和原始的GAN结合到一起,生成网络和鉴别网络都运用到了深度卷积神经网络。DCGAN提高了基础GAN的稳定性和生成结果质量。

该项目使用的是mnist手写字数据集,深度学习框架为tensorflow。你也可以直接跳过下面代码直接git clone本项目,项目的github链接https://github.com/limengqigithub/DCGAN-mnist-master.git。

4.1 DCGAN模型代码

import tensorflow as tf
from tensorflow import keras

# 生成网络
class Generator(keras.Model):

    def __init__(self):
        super(Generator, self).__init__()

        self.n_f = 512
        self.n_k = 4

        # input z vector is [None, 100]
        self.dense1 = keras.layers.Dense(3 * 3 * self.n_f)
        self.conv2 = keras.layers.Conv2DTranspose(self.n_f // 2, 3, 2, 'valid')
        self.bn2 = keras.layers.BatchNormalization()
        self.conv3 = keras.layers.Conv2DTranspose(self.n_f // 4, self.n_k, 2, 'same')
        self.bn3 = keras.layers.BatchNormalization()
        self.conv4 = keras.layers.Conv2DTranspose(1, self.n_k, 2, 'same')
        return

    def call(self, inputs, training=None):
        # [b, 100] => [b, 3, 3, 512]
        x = tf.nn.leaky_relu(tf.reshape(

你可能感兴趣的:(生成对抗网络,人工智能,深度学习,cnn,神经网络)