生成对抗网络(Generative Adversarial Network,GAN)在2014年被Ian Goodfellow等人首次提出,此后迅速流行,成为热门的深度学习模型。Gan能生成非常逼真的图像、画作、音乐,近些年不乏有利用Gan生成的画作斩获大奖的情况,如前段时间引起争议的由AI生成的《空间歌剧院》在科罗拉多州博览会(Colorado State Fair)的美术比赛中,获得了第一名,如下图所示。
图1 利用AI生成的画作《空间歌剧院》 |
---|
尽管借助AI的力量去和人类竞争,目前仍存在很大的争议,但不可否认的是AI已经在这一领域展现了很强大的应用前景。
GAN的原理非常简单,其利用两个模型,其中一个不断生成“假”数据,另外一个模型判断前一个生成的“假”数据,如果能够骗过判别模型,则说明生成了可以以假乱真的数据。
在GAN中,A是生成器,负责生成“假”数据,B是判别器,负责判断A生成的数据质量,其是一个博弈的过程。
生成器: 接受一个随机噪声向量x作为输入,生成一个张量G(x)
判别器: 接受一个张量作为输入,输出其真假
以图像为例,GAN的整个训练过程如下:
(1) 生成器接受随机噪声,并生成假图像
(2)判别器接受假图像和真图像组合的数据,学习如何判别真假图像
(3)生成器生成新的图像,并使用判别器来判别真假,同时通过判别器结果来判别此次造假的的水平。
(4)重复步骤(1)~(3)
原则上讲生成器并无特定的模型,只要能够生成图像的模型即可,但目前考虑到模型的训练一般选择神经网络,因为可以和判别器一同训练。生成器负责生成一副图片,当然此时的图片为噪声,类似于下图,细看啥也不是,但这不重要,因为GAN中,生成器不需要任何真数据!,是的,你没看错,不管它生成的是什么样的数据,它的老师判别器会告诉他这副图的真假,换句话讲,下图太假了,老师一眼就辨别出来了。
图2 生成器生成的图片 |
---|
生成器的代码如下:
import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np
import tqdm
from IPython.display import clear_output
L = keras.layers
LATENT_DIM = 100 # 潜在空间的维度
IMAGE_SHAPE = (28, 28, 1) # 输出图像的尺寸
# 生成器
generate_net = [
L.Input(shape=(LATENT_DIM, )),
L.Dense(256),
L.LeakyReLU(alpha=0.2),
L.BatchNormalization(momentum=0.8),
L.Dense(512),
L.LeakyReLU(alpha=0.2),
L.BatchNormalization(momentum=0.8),
L.Dense(1024),
L.LeakyReLU(alpha=0.2),
L.BatchNormalization(momentum=0.8),
L.Dense(np.prod(IMAGE_SHAPE), activation='tanh'),
L.Reshape(IMAGE_SHAPE)
]
generate = keras.models.Sequential(generate_net)
generate.summary()
注意
1、我们用到LeakyReLU激活函数,这是GAN中常用的激活函数
2、归一化方式常用的有Batch Normalization(BN),Instance Normalization(IN),Spectral Normalization(SN)
3、生成器的最后一层激活函数一般用tanh函数
GAN中判别器的作用就是判断生成的数据的水平,要判断真假,所以要先训练判别器,类似于你要预测房价,肯定要先去学习(拟合)房价数据,因此,GAN中判别器的训练中一部分是真实数据,这部分考虑到大家下载复制复现代码方便,我们用MNIST数据集,MNIST数据集可以通过tensorflow直接下载,比较方便。有真数据还不行,那必须喂给模型假数据,不然怎么学习真假对吧,那假数据从那来呢?对!生成器,我们生成器不正好可以生成假数据嘛,在这种思路下,我们就可以构造我们的判别器了。
图2 MNIST数据集 |
---|
判别器的代码如下:
# 判别器
discriminator_net = [
L.Input(shape=IMAGE_SHAPE),
L.Flatten(),
L.Dense(512),
L.LeakyReLU(alpha=0.2),
L.Dense(256),
L.LeakyReLU(alpha=0.2),
L.Dense(1, activation='sigmoid')
]
discriminator = keras.models.Sequential(discriminator_net)
discriminator.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])
discriminator.summary()
前面我们已经得知GAN由生成器和判别器两部分构成,且我们已经搭建好了生成器和判别器,把他们一组和即是GAN,但为什么要分开搭建,原因是GAN的训练是不断迭代的一个过程,要分开训练生成器和判别器,生成器生成的好不好要判别器判断,此时应该是冻结判别器的权重,只更新生成器的权重,因为生成器的目标是不断提升“造假”的能力。
GAN的模型如下
adversarial_net = generate_net + discriminator_net
# 冻结判别器的权重
for layer in discriminator_net:
layer.trainable = False
adversarial = keras.models.Sequential(adversarial_net)
adversarial.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])
adversarial.summary()
GAN的训练是一个不断调试不断优化的过程,需要大量的经验,推荐一些新手在训练GAN时看一些其他博主建议的小技巧——训练GAN的小技巧
训练代码如下:
# 数据可视化
def sample_images(batch):
rows, columns = 3, 10
sample_count = rows * columns
plt.figure(figsize=(columns, rows))
# 使用生成器生成图像
noise = np.random.normal(0, 1, (sample_count, LATENT_DIM))
gen_imgs = generate.predict(noise)
# 生成器图像张量的范围从【-1,1】改为【0,1】
gen_imgs = 0.5 * gen_imgs + 0.5
index = 0
for row in range(rows):
for column in range(columns):
image = np.reshape(gen_imgs[index], [28, 28])
plt.subplot(rows, columns, index+1)
plt.imshow(image, cmap='gray')
plt.axis('off')
index += 1
plt.tight_layout()
plt.show()
return gen_imgs
# 训练
def train(batch=30000, batch_size=32):
# 读取数据,无需标签
(image_set, _), (_, _) = keras.datasets.mnist.load_data()
# 数据归一化
image_set = image_set / 127.5 - 1.
# 数据格式转换
image_set = image_set.reshape(len(image_set), 28, 28, 1)
# 准备batch_size同样大小的真假标签
valid = np.ones((batch_size))
fake = np.zeros((batch_size))
# 利用tqdm生成迭代器
batch_list = tqdm.trange(batch)
for batch in batch_list:
# 生成器生成图像
idx = np.random.randint(0, image_set.shape[0], batch_size)
imgs = image_set[idx]
# 生成噪声数据并作为生成器的输入
noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
# 使用生成器生成图像
gen_imgs = generate.predict(noise)
# 训练判别器
d_state_real = discriminator.train_on_batch(imgs, valid)
d_state_fake = discriminator.train_on_batch(gen_imgs, fake)
d_state = 0.5 * np.add(d_state_real, d_state_fake)
# 训练生成器
adv_state = adversarial.train_on_batch(noise, valid)
# 更新进度条后缀文本,用于输出训练进度
state = f"[D loss:{d_state[0]:.4f} acc: {d_state[1]:.4f}]" \
f"[A loss:{adv_state[0]:.4f} acc: {adv_state[1]:.4f}"
batch_list.set_postfix(state=state)
if (batch + 1) % 50 == 0:
clear_output(wait=True)
_ = sample_images(batch)
train()
训练过程中我们设定每个50个batch输出生成器生成的图片
初始图片
迭代2000次后生成的图片
迭代10000次生成的图片
迭代30000次生成的图片
整个训练过程的GIF图如下:
GAN的训练过程是一个动态过程,每个批次是新的开始,不会有简单的梯度下降过程,而是一个不断对抗平衡的过程,类似于minimax,我们要的是最小的判别器损失,最大的生成器误差,因此GAN的训练需要一些技巧,比如
1、一开始无需分类精度很高的判别器
2、初始学习率要小,否则下降过快或者Max过大不利于GAN的拟合
3、生成器和迭代器无需训练相同次数,比如可以生成器训练1次,判别器训练5次
4、迭代次数需要不断微调,迭代次数过小有可能生成的图像效果一般,迭代次数过大也会导致生成的图像效果一般,很多人会疑问迭代次数过大为什么会导致生成的图像效果一般,因为判别器每次训练更新权重用到的是生成器生成的假数据和真实数据,后期生成器生成的数据已经非常逼真了,而判别器学习到仍然判定为假,因此反而会导致生成器又开始生成很假的数据,如该博主利用GAN生成动漫头像文章点击这,迭代200次的图片如下
当其迭代750次后出现了上面提到的问题,由于生成器生成了图像质量非常高,但判别器仍然判定为假,导致生成器开始产生反向作用,如下所示
GAN的评价一直是一个难题,早期人们通过肉眼判定生成的图像质量,但不可否认的是这种评价方式明显存在缺陷,2016年来,GAN的评价方式开始如雨后春笋般展现,目前比较流行的是:
Inception Score(IS)通过利用谷歌图像分类模型nception Net来衡量模型生成图像的清晰度和多样性,Inception Score越高,表示模型越好。
Frechet Inception距离(FID)通过对比真实样本和生成样本在Inception V3模型上的抽象特征的差异来评估生成样本和真实样本的差异,FID越小,表示模型越小。
要查看我的其他博客,点击这里