自打关注深度学习这个领域就不时的看到和 Generative Adversarial Network, GAN 相关的东西,也一直非常好奇这个被 LeCun 称为深度学习近年来最大的突破的东西到底是什么样子的。正好在 Udacity 的课堂里遇到了,在完成了通过 GAN 来完成人脸生成的项目后,在这里做一个总结,加深一下对于 GAN 这个网络的理解。为了便于本地试验,这里展示的是利用 MNIST 数据集来训练一个简单的 GAN 来生成手写数字的过程。注意文中代码和示例图片来自 Udacity 深度学习纳米学位课程,版权归 Udacity 所有。
深度神经网络最令人诟病一点就在于其决策过程的不可解释性,你无从知道网络中的单元提取了哪些特征来完成了一项分类或识别任务。比如在图片识别任务中,即便你可以提取隐藏层的 feature map 来可视化出来相应层的情况,其图像在人类看来是抽象而诡异甚至有些惊悚的。这一点其实在我看来是十分正常的,也不应该像很多媒体的解读方式那样过分的夸大,事实上,人脑的加工过程有谁可以可视化出来呢?只不过我们对于人类行为的可预测性是有把握的,所以不像对于新生技术那样容易催生恐惧。
而 GAN 最为聪明之处在于既然人类无法理解网络内部的生成过程,索性不用人脑和人类对于图像的理解方式去理解中间过程,而是用另一个类似结构的神经网络,二者的相互理解过程也就是对抗 Adversarial 的过程。其实现的大致思路是:
作为生成器的一个典型代表,GAN 的一个典型应用是通过模型来生成类似已有数据集的图片来实现数据扩增,因此可以首先建立一个通过多层神经网络实现的生成器,其主要作用是通过对于符合一定分布规律的原始数据进行处理,进而得到一个符合另一特定分布情况的结果图像。这里要求这个网络至少包含一个隐藏层,否则网络就不具有足够的学习和泛化能力,这个网络在 GAN 中被称为生成器 Generator。例如在下面的示例图片中,生成器的输入是符合某个分布特征的随机数字:在后续的代码示例中采用的是 (-1, 1) 之间的均匀分布
在获得了生成器之后,还要建立一个类似结构的可以完成图像识别任务的分类器,其特殊之处在于这个网络的输出层只对输入是来自原始数据集还是由生成器网络生成的结果做一个真假判断,这个网络在 GAN 中称为识别器 Discriminator
在看到代码之前我一直以为 GAN 的实现会比较复杂,但真正看到代码之后就像看到 E = mc2 一样,发现其是如此的简洁,优雅,直观,不得不佩服 Ian Goodfellow 强大的思路。闲话到此为止,网络架构和实现代码如下:
%matplotlib inline
import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# load data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')
# define the model input for both Generator and Discirminator
def model_inputs(real_dim, z_dim):
inputs_real = tf.placeholder(tf.float32, (None, real_dim), name='input_real')
inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
return inputs_real, inputs_z
# define the Generator
def generator(z, out_dim, n_units=128, reuse=False, alpha=0.01):
with tf.variable_scope('generator', reuse=reuse):
# Hidden layer
h1 = tf.layers.dense(z, n_units, activation=None)
# Leaky ReLU
h1 = tf.maximum(alpha * h1, h1)
# Logits and tanh output
logits = tf.layers.dense(h1, out_dim, activation=None)
out = tf.tanh(logits)
return out
# define the Discriminator
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
with tf.variable_scope('discriminator', reuse=reuse):
# Hidden layer
h1 = tf.layers.dense(x, n_units, activation=None)
# Leaky ReLU
h1 = tf.maximum(alpha * h1, h1)
logits = tf.layers.dense(h1, 1, activation=None)
out = tf.sigmoid(logits)
return out, logits
这里之所以要定义这个 variable_scope 是由于在后续的训练中,需要分别更新生成器和判别器的参数,为了提取参数而特别设置的。另外值得注意的是,激活函数需要采用 Leaky ReLU 来保证梯度可以从判别器传回到生成器。
# build the network
tf.reset_default_graph()
# Create our input placeholders
input_real, input_z = model_inputs(input_size, z_size)
# Build the model
g_model = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
# g_model is the generator output
d_model_real, d_logits_real = discriminator(input_real, n_units=d_hidden_size, alpha=alpha)
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True, n_units=d_hidden_size, alpha=alpha)
# Calculate losses
d_loss_real = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
labels=tf.zeros_like(d_logits_real)))
d_loss = d_loss_real + d_loss_fake
g_loss = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
labels=tf.ones_like(d_logits_fake)))
在这里新引入的一个操作是 label smoothing,其目的在于适度的放低要求以促进收敛。而针对损失函数这部分,由于希望判别器将真实数据识别为 1, 而将生成器生成的数据识别为 0,因此需要分别计算这两部分的损失函数。
# Optimizers
learning_rate = 0.002
# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith('generator')]
d_vars = [var for var in t_vars if var.name.startswith('discriminator')]
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
这一段代码非常重要,正式因为选择了间歇性的训练才使得网络的对抗得以实现。
# Size of input image to discriminator
input_size = 784
# Size of latent vector to generator
z_size = 100
# Sizes of hidden layers in generator and discriminator
g_hidden_size = 128
d_hidden_size = 128
# Leak factor for leaky ReLU
alpha = 0.01
# Smoothing
smooth = 0.1
下面代码部分为比较常见的训练代码结构:
batch_size = 100
epochs = 100
samples = []
losses = []
# Only save generator variables
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for e in range(epochs):
for ii in range(mnist.train.num_examples//batch_size):
batch = mnist.train.next_batch(batch_size)
# Get images, reshape and rescale to pass to D
batch_images = batch[0].reshape((batch_size, 784))
batch_images = batch_images*2 - 1
# Sample random noise for G
batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
# Run optimizers
_ = sess.run(d_train_opt, feed_dict={input_real: batch_images, input_z: batch_z})
_ = sess.run(g_train_opt, feed_dict={input_z: batch_z})
# At the end of each epoch, get the losses and print them out
train_loss_d = sess.run(d_loss, {input_z: batch_z, input_real: batch_images})
train_loss_g = g_loss.eval({input_z: batch_z})
print("Epoch {}/{}...".format(e+1, epochs),
"Discriminator Loss: {:.4f}...".format(train_loss_d),
"Generator Loss: {:.4f}".format(train_loss_g))
# Save losses to view after training
losses.append((train_loss_d, train_loss_g))
# Sample from generator as we're training for viewing afterwards
sample_z = np.random.uniform(-1, 1, size=(16, z_size))
gen_samples = sess.run(
generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
feed_dict={input_z: sample_z})
samples.append(gen_samples)
saver.save(sess, './checkpoints/generator.ckpt')
# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
pkl.dump(samples, f)
为了监控训练,可以提取训练过程中的参数来识别训练结果。实际上在学习过程中可以发现 GAN 的训练对于超参数的选择十分敏感,并且在后续的 DCGAN 学习中,作者们甚至通过调整 Adam 中的指数加权平均参数 beta1
来实现较好的训练效果。Ian Goodfellow 在 Andrew Ng 的访谈里也提到自己现在 40% 的时间话在研究如何 Stablize GAN,当时没理解是什么意思,直到自己训练了 DCGAN 之后才知道原来 GAN 的训练对于超参数是如此的敏感。
def view_samples(epoch, samples):
fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
for ax, img in zip(axes.flatten(), samples[epoch]):
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
return fig, axes
rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)
for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
ax.imshow(img.reshape((28,28)), cmap='Greys_r')
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
参考阅读
Tips and tricks to make GANs work
Generative Adversarial Networks for beginners