GAN的原始模型有很多可以改进的缺点,首当其中就是“模型不可控”。从上面对GAN的介绍能够看出,模型以一个随机噪声为输入。显然,我们很难对输出的结构进行控制。例如,使用纯粹的GAN,我们可以训练出一个生成器:输入随机噪声,产生一张写着0-9某一个数字的图片。然而,在现实应用中,我们往往想要生成“指定”的一张图片。
在GAN上增加一个额外的输入。也就是说,以前我们的生成模型是,现在,我们的生成模型是在一个条件c的控制下产生。而这个c就是我们用来控制模型的额外的输入。
c可以是表示我们意图的一串编码,例如我们想要做0-9的手写数字生成,则c可以是一个10维的one-hot向量。则在训练过程中,我们将这些label加入到训练数据中,从而得到一个按照我们需求产生图片的生成器。
这就是Conditional Generative Adversarial Nets最基本的想法。这里要注意的是,这个c不但附加在了生成器上,同时也附加在了判别器上,相当于给了判别器一个额外的信息:现在这个图片是以条件c生成的?还是以条件c控制下的真正的图片?
原文中有这样一张图,在其他博客中也常见到
对于GAN来说,我们训练的目标是:
而对于Conditional的GAN来说,训练目标只需要变成:
(原文中的公式有误,后面一项的判别器D中忘了加以y为条件的概率)
其实这个改动形象一些表示就是将原来只接受一个输入z的生成器变成接受两个输入(z和y),将原来只接受一个输入x的判别器变成接受两个输入(x和y)。
CGAN代码如下:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
#数据输入
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 64
Z_dim = 100
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
#返回随机值
def xavier_init(size):
in_dim = size[0]
xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
return tf.random_normal(shape=size, stddev=xavier_stddev)
#X代表输入图片,应该是28*28,但是这里没有使用CNN,y是相应的label
""" Discriminator Net model """
X = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, y_dim])
#权重,CGAN的输入是将图片输入与label concat起来,所以权重维度为784+10
D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
#第二层有h_dim个节点
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))
theta_D = [D_W1, D_W2, D_b1, D_b2]
#D网络,这里是一个简单的神经网络,x是输入图片向量,y是相应的label
def discriminator(x, y):
inputs = tf.concat(axis=1, values=[x, y])
D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit)
return D_prob, D_logit
#G网络参数,输入维度为Z_dim+y_dim,中间层有h_dim个节点,输出X_dim的数据
""" Generator Net model """
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
#权重
G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
theta_G = [G_W1, G_W2, G_b1, G_b2]
#G网络
def generator(z, y):
inputs = tf.concat(axis=1, values=[z, y])
G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob)
return G_prob
#噪声产生的函数
def sample_Z(m, n):
return np.random.uniform(-1., 1., size=[m, n])
def plot(samples):
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
return fig
#生成网络,基本和GAN一致
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
#优化式
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))
#训练
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#输出图片在out文件夹
if not os.path.exists('out/'):
os.makedirs('out/')
i = 0
for it in range(1000000):
if it % 1000 == 0:
#n_sample 是G网络测试用的Batchsize,为16,所以输出的png图有16张
n_sample = 16
Z_sample = sample_Z(n_sample, Z_dim)#输入的噪声,尺寸为batchsize*noise维度
y_sample = np.zeros(shape=[n_sample, y_dim])#输入的label,尺寸为batchsize*label维度
y_sample[:, 7] = 1 #输出7
samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample})#G网络的输入
fig = plot(samples)
plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')#输出生成的图片
i += 1
plt.close(fig)
#mb_size是网络训练时用的Batchsize,为100
X_mb, y_mb = mnist.train.next_batch(mb_size)
#Z_dim是noise的维度,为100
Z_sample = sample_Z(mb_size, Z_dim)
#交替最小化训练
_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: Z_sample, y:y_mb})
_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: Z_sample, y:y_mb})
#输出训练时的参数
if it % 1000 == 0:
print('Iter: {}'.format(it))
print('D loss: {:.4}'. format(D_loss_curr))
print('G_loss: {:.4}'.format(G_loss_curr))
print()
生成效果如下:
为了方便理解,本文只用了最简单的神经网络,有时间会使用CNN重写该网络。