Gan生成手写数字

一、GAN

生成对抗网络(Generative Adversarial Network,GAN)在2014年被Ian Goodfellow等人首次提出,此后迅速流行,成为热门的深度学习模型。Gan能生成非常逼真的图像、画作、音乐,近些年不乏有利用Gan生成的画作斩获大奖的情况,如前段时间引起争议的由AI生成的《空间歌剧院》在科罗拉多州博览会(Colorado State Fair)的美术比赛中,获得了第一名,如下图所示。
Gan生成手写数字_第1张图片

图1 利用AI生成的画作《空间歌剧院》

尽管借助AI的力量去和人类竞争,目前仍存在很大的争议,但不可否认的是AI已经在这一领域展现了很强大的应用前景。
GAN的原理非常简单,其利用两个模型,其中一个不断生成“假”数据,另外一个模型判断前一个生成的“假”数据,如果能够骗过判别模型,则说明生成了可以以假乱真的数据。

二、GAN的步骤

在GAN中,A是生成器,负责生成“假”数据,B是判别器,负责判断A生成的数据质量,其是一个博弈的过程。
生成器: 接受一个随机噪声向量x作为输入,生成一个张量G(x)
判别器: 接受一个张量作为输入,输出其真假
以图像为例,GAN的整个训练过程如下:
(1) 生成器接受随机噪声,并生成假图像
(2)判别器接受假图像和真图像组合的数据,学习如何判别真假图像
(3)生成器生成新的图像,并使用判别器来判别真假,同时通过判别器结果来判别此次造假的的水平。
(4)重复步骤(1)~(3)

三、生成器

原则上讲生成器并无特定的模型,只要能够生成图像的模型即可,但目前考虑到模型的训练一般选择神经网络,因为可以和判别器一同训练。生成器负责生成一副图片,当然此时的图片为噪声,类似于下图,细看啥也不是,但这不重要,因为GAN中,生成器不需要任何真数据!,是的,你没看错,不管它生成的是什么样的数据,它的老师判别器会告诉他这副图的真假,换句话讲,下图太假了,老师一眼就辨别出来了。
Gan生成手写数字_第2张图片

图2 生成器生成的图片

生成器的代码如下:

import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np
import tqdm
from IPython.display import clear_output

L = keras.layers
LATENT_DIM = 100  # 潜在空间的维度
IMAGE_SHAPE = (28, 28, 1)  # 输出图像的尺寸

# 生成器
generate_net = [
    L.Input(shape=(LATENT_DIM, )),
    L.Dense(256),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(512),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(1024),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(np.prod(IMAGE_SHAPE), activation='tanh'),
    L.Reshape(IMAGE_SHAPE)
]
generate = keras.models.Sequential(generate_net)
generate.summary()

注意
1、我们用到LeakyReLU激活函数,这是GAN中常用的激活函数
2、归一化方式常用的有Batch Normalization(BN),Instance Normalization(IN),Spectral Normalization(SN)
3、生成器的最后一层激活函数一般用tanh函数

四、判别器

GAN中判别器的作用就是判断生成的数据的水平,要判断真假,所以要先训练判别器,类似于你要预测房价,肯定要先去学习(拟合)房价数据,因此,GAN中判别器的训练中一部分是真实数据,这部分考虑到大家下载复制复现代码方便,我们用MNIST数据集,MNIST数据集可以通过tensorflow直接下载,比较方便。有真数据还不行,那必须喂给模型假数据,不然怎么学习真假对吧,那假数据从那来呢?对!生成器,我们生成器不正好可以生成假数据嘛,在这种思路下,我们就可以构造我们的判别器了。
Gan生成手写数字_第3张图片

图2 MNIST数据集

判别器的代码如下:

# 判别器
discriminator_net = [
    L.Input(shape=IMAGE_SHAPE),
    L.Flatten(),
    L.Dense(512),
    L.LeakyReLU(alpha=0.2),
    L.Dense(256),
    L.LeakyReLU(alpha=0.2),
    L.Dense(1, activation='sigmoid')
]
discriminator = keras.models.Sequential(discriminator_net)
discriminator.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])
discriminator.summary()

五、生成对抗模型

前面我们已经得知GAN由生成器和判别器两部分构成,且我们已经搭建好了生成器和判别器,把他们一组和即是GAN,但为什么要分开搭建,原因是GAN的训练是不断迭代的一个过程,要分开训练生成器和判别器,生成器生成的好不好要判别器判断,此时应该是冻结判别器的权重,只更新生成器的权重,因为生成器的目标是不断提升“造假”的能力。
GAN的模型如下

adversarial_net = generate_net + discriminator_net
# 冻结判别器的权重
for layer in discriminator_net:
    layer.trainable = False

adversarial = keras.models.Sequential(adversarial_net)
adversarial.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])
adversarial.summary()

六、GAN 的训练

GAN的训练是一个不断调试不断优化的过程,需要大量的经验,推荐一些新手在训练GAN时看一些其他博主建议的小技巧——训练GAN的小技巧

首先要明白,GAN的训练是一个交替迭代的过程

前面我们提到生成器最开始的造假水平很垃圾,只是随便给一副噪声图,这部分假数据我们赋予标签0,代表假数据,然后和真实数据集,也就是MNIST,进行组合,形成一个二分类的图像数据集,来训练我们的判别器,训练结束后我们的判别器已经具备了分辨真假数据的能力(二分类)

判别器训练好后,我们开始训练生成器

生成器我们希望生成非常逼真的数据,但生成器生成的好坏,换句话说生成的数据要让判别器判断为1,也就是以假乱真的水平才行,所以一般来说训练生成器的时候其实是生成器+判别器的串接网络,只不过判别器的权重被冻结,类似于只训练生成器

此时我们用生成器生成一副图片,但我们要把这幅图的标签设置为1,也就是真,别急!!!,你没看错,就是要赋予标签1,这幅图被我们的判别器判断后输出0,也就是假,这样前后形成了非常大的误差,明明是假图,生成器却说真,此时生成器就会拼命的调整参数,直到判别器判断为真!这个过程标准的说法就是反向传播,对生成器网络的参数进行大更新!等到后续生成器能够产生出逼真的图片时,反向传播对生成器的参数就是微调,不断优化的一个过程!

交替训练判别器和生成器即可实现GAN的训练

训练代码如下:

# 数据可视化
def sample_images(batch):
    rows, columns = 3, 10
    sample_count = rows * columns
    plt.figure(figsize=(columns, rows))
    # 使用生成器生成图像
    noise = np.random.normal(0, 1, (sample_count, LATENT_DIM))
    gen_imgs = generate.predict(noise)
    # 生成器图像张量的范围从【-1,1】改为【0,1】
    gen_imgs = 0.5 * gen_imgs + 0.5

    index = 0
    for row in range(rows):
        for column in range(columns):
            image = np.reshape(gen_imgs[index], [28, 28])
            plt.subplot(rows, columns, index+1)
            plt.imshow(image, cmap='gray')
            plt.axis('off')
            index += 1
    plt.tight_layout()
    plt.show()
    return gen_imgs

# 训练
def train(batch=30000, batch_size=32):
    # 读取数据,无需标签
    (image_set, _), (_, _) = keras.datasets.mnist.load_data()
    # 数据归一化
    image_set = image_set / 127.5 - 1.
    # 数据格式转换
    image_set = image_set.reshape(len(image_set), 28, 28, 1)
    # 准备batch_size同样大小的真假标签
    valid = np.ones((batch_size))
    fake = np.zeros((batch_size))
    # 利用tqdm生成迭代器
    batch_list = tqdm.trange(batch)
    for batch in batch_list:
        #  生成器生成图像
        idx = np.random.randint(0, image_set.shape[0], batch_size)
        imgs = image_set[idx]

        # 生成噪声数据并作为生成器的输入
        noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
        # 使用生成器生成图像
        gen_imgs = generate.predict(noise)

        # 训练判别器
        d_state_real = discriminator.train_on_batch(imgs, valid)
        d_state_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_state = 0.5 * np.add(d_state_real, d_state_fake)

        # 训练生成器
        adv_state = adversarial.train_on_batch(noise, valid)

        # 更新进度条后缀文本,用于输出训练进度
        state = f"[D loss:{d_state[0]:.4f} acc: {d_state[1]:.4f}]" \
                f"[A loss:{adv_state[0]:.4f} acc: {adv_state[1]:.4f}"
        batch_list.set_postfix(state=state)
        if (batch + 1) % 50 == 0:
            clear_output(wait=True)
            _ = sample_images(batch)

train()

训练过程中我们设定每个50个batch输出生成器生成的图片
初始图片
Gan生成手写数字_第4张图片
迭代2000次后生成的图片
Gan生成手写数字_第5张图片
迭代10000次生成的图片
Gan生成手写数字_第6张图片
迭代30000次生成的图片
Gan生成手写数字_第7张图片
整个训练过程的GIF图如下:

七、GAN的注意事项

GAN的训练过程是一个动态过程,每个批次是新的开始,不会有简单的梯度下降过程,而是一个不断对抗平衡的过程,类似于minimax,我们要的是最小的判别器损失,最大的生成器误差,因此GAN的训练需要一些技巧,比如
1、一开始无需分类精度很高的判别器
2、初始学习率要小,否则下降过快或者Max过大不利于GAN的拟合
3、生成器和迭代器无需训练相同次数,比如可以生成器训练1次,判别器训练5次
4、迭代次数需要不断微调,迭代次数过小有可能生成的图像效果一般,迭代次数过大也会导致生成的图像效果一般,很多人会疑问迭代次数过大为什么会导致生成的图像效果一般,因为判别器每次训练更新权重用到的是生成器生成的假数据和真实数据,后期生成器生成的数据已经非常逼真了,而判别器学习到仍然判定为假,因此反而会导致生成器又开始生成很假的数据,如该博主利用GAN生成动漫头像文章点击这,迭代200次的图片如下

当其迭代750次后出现了上面提到的问题,由于生成器生成了图像质量非常高,但判别器仍然判定为假,导致生成器开始产生反向作用,如下所示

八、GAN的评价

GAN的评价一直是一个难题,早期人们通过肉眼判定生成的图像质量,但不可否认的是这种评价方式明显存在缺陷,2016年来,GAN的评价方式开始如雨后春笋般展现,目前比较流行的是:

1、Inception Score

Inception Score(IS)通过利用谷歌图像分类模型nception Net来衡量模型生成图像的清晰度和多样性,Inception Score越高,表示模型越好。

2、Frechet Inception距离

Frechet Inception距离(FID)通过对比真实样本和生成样本在Inception V3模型上的抽象特征的差异来评估生成样本和真实样本的差异,FID越小,表示模型越小。

九、其他

生成对抗网络有很多种,如GAN 、ACGAN、DCGAN、Pix2Pix等。

要查看我的其他博客,点击这里

你可能感兴趣的:(机器学习,生成对抗网络,深度学习,计算机视觉)