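TFGAN, the principles of GANs, and the unconditional GAN were all covered in my previous article, "TFGAN: A Simple, Easy-to-Use, Lightweight Toolkit for Generative Adversarial Networks". This article focuses on using TFGAN to implement a conditional GAN (CGAN).
Environment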
- Python 3.6
- Tensorflow-gpu 1.8.0
Conditional GAN
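A classic unconditional GAN generates the data we want by sampling from a noise distribution, but we cannot control which class the generated data belongs to. The conditional GAN (CGAN) solves this problem.
In a CGAN, the "condition" means that the generated data must not only look realistic but also satisfy a given condition, here belonging to a specified class. As shown in the figure below, the generator receives random noise together with the one-hot encoding of the target class, and the discriminator receives an image together with the one-hot encoding of its class. This lets us steer the generator toward producing data of the class we specify.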
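The textbook way to attach such a condition is concatenation. The plain-Python sketch below shows the idea: a class label becomes a one-hot vector, and that vector is appended to a noise vector to form the generator input. The helper names `one_hot` and `conditional_input` are hypothetical. The 64-dimensional noise and 10 classes only mirror the MNIST setup used later. This is the generic scheme and not necessarily what TFGAN does internally.

```python
import random


def one_hot(label, num_classes):
    """Return a one-hot list with 1.0 at position `label`."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def conditional_input(noise, label, num_classes=10):
    """Concatenate a noise vector with the one-hot code of `label`."""
    return list(noise) + one_hot(label, num_classes)


# A 64-dimensional Gaussian noise vector, as in the MNIST example.
noise = [random.gauss(0.0, 1.0) for _ in range(64)]
z = conditional_input(noise, label=3)
print(len(z))   # 74: 64 noise values followed by 10 label entries
print(z[64:])   # the one-hot part, with the 1.0 at index 3
```

The discriminator is given the same one-hot code next to its image, so it learns to judge whether an image is realistic for that particular class.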
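Implementation
The CGAN uses essentially the same network architecture as the unconditional GAN (UGAN). The main difference is that the one-hot encoding of the class is added to the inputs.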
Generator
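As shown below, the model input `inputs` is a tuple `(noise, one_hot_labels)`. The `tfgan.features.condition_tensor_from_onehot` function fuses the two before they enter the generator's layers. It does this by adding a learned projection of the label to the noise rather than by concatenating the two tensors.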
```python
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers


def conditional_generator(inputs, weight_decay=2.5e-5, is_training=True):
    """Simple generator to produce MNIST images.

    Args:
        inputs: A 2-tuple of Tensors (noise, one_hot_labels).
        weight_decay: The value of the l2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`,
            batch norm uses the exponential moving average collected from
            population statistics.

    Returns:
        A generated image in the range [-1, 1].
    """
    noise, one_hot_labels = inputs
    # Condition the noise on the class labels.
    noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
    with tf.contrib.framework.arg_scope(
            [layers.fully_connected, layers.conv2d_transpose],
            activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
            weights_regularizer=layers.l2_regularizer(weight_decay)):
        with tf.contrib.framework.arg_scope(
                [layers.batch_norm], is_training=is_training,
                zero_debias_moving_mean=True):
            net = layers.fully_connected(noise, 1024)
            net = layers.fully_connected(net, 7 * 7 * 128)
            net = tf.reshape(net, [-1, 7, 7, 128])
            net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
            net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
            # Make sure that the generator output is in the same range as the
            # real images, i.e. [-1, 1].
            net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None,
                                activation_fn=tf.tanh)
            return net
```
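The TensorFlow 1.x `tf.contrib.gan` source appears to implement `condition_tensor_from_onehot` as additive conditioning. It looks up a learned embedding for the label (256-dimensional by default) and maps that embedding with a linear layer to the number of features in the tensor being conditioned. It then adds the result to that tensor. The pure-Python sketch below reproduces only that arithmetic. The function names, the tiny embedding table and the projection weights are hypothetical fixed numbers that stand in for trained variables.

```python
def matvec(vector, matrix):
    """Multiply a row vector by a matrix given as a list of rows."""
    cols = len(matrix[0])
    return [sum(vector[i] * matrix[i][j] for i in range(len(vector)))
            for j in range(cols)]


def condition_from_onehot(tensor, one_hot_label, embedding_table, projection, bias):
    """Additive conditioning: tensor + linear(embedding[argmax(one_hot)])."""
    label_id = max(range(len(one_hot_label)), key=lambda i: one_hot_label[i])
    embedding = embedding_table[label_id]             # embedding lookup
    mapped = matvec(embedding, projection)            # linear layer weights ...
    mapped = [m + b for m, b in zip(mapped, bias)]    # ... plus its bias
    if len(mapped) != len(tensor):
        raise ValueError("projection must match the tensor's feature count")
    return [t + m for t, m in zip(tensor, mapped)]    # add, do not concatenate


# Hypothetical toy sizes: 3 classes, 2-d embedding, 4 features to condition.
embedding_table = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
projection = [[1.0, 2.0, 0.0, 0.0],
              [0.0, 0.0, 3.0, 4.0]]
bias = [0.0, 0.0, 0.0, 0.0]
features = [0.5, 0.5, 0.5, 0.5]

out = condition_from_onehot(features, [0.0, 1.0, 0.0], embedding_table, projection, bias)
print(out)  # [0.5, 0.5, 3.5, 4.5]
```

Because the label is added rather than appended, the conditioned tensor keeps its original size. In this model the generator's first fully connected layer therefore still sees 64 input features.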
Discriminator
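As shown below, the discriminator receives the image `img` and also the matching `one_hot_labels`, which it unpacks from the `conditioning` tuple. Two convolutional layers extract image features. `tfgan.features.condition_tensor_from_onehot` then conditions the flattened features on the class encoding, and the final layers output the logit used for the real/fake decision.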
```python
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers


def conditional_discriminator(img, conditioning, weight_decay=2.5e-5,
                              is_training=True):
    """Discriminator network on MNIST digits.

    Args:
        img: Real or generated MNIST digits. Should be in the range [-1, 1].
        conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
        weight_decay: The L2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`,
            batch norm uses the exponential moving average collected from
            population statistics.

    Returns:
        Logits for the probability that the image is real.
    """
    _, one_hot_labels = conditioning
    with tf.contrib.framework.arg_scope(
            [layers.conv2d, layers.fully_connected],
            activation_fn=tf.nn.relu, normalizer_fn=None,
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)
        # Condition the flattened image features on the class labels.
        net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels)
        with tf.contrib.framework.arg_scope(
                [layers.batch_norm], is_training=is_training):
            net = layers.fully_connected(net, 1024,
                                         normalizer_fn=layers.batch_norm)
        return layers.linear(net, 1)
```
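The 7 × 7 × 128 starting shape in the generator and the size of the flattened features the discriminator conditions on both follow from padding arithmetic. `tf.contrib.layers` convolutions use `'SAME'` padding by default. Under that padding, a stride-`s` transposed convolution multiplies the spatial size by `s`, and a stride-`s` convolution divides it by `s`, rounding up. The short script below traces the shapes through both networks, leaving out the batch dimension. The helper names are only for illustration.

```python
import math


def conv_same(size, stride):
    """Output spatial size of a 'SAME'-padded convolution."""
    return math.ceil(size / stride)


def conv_transpose_same(size, stride):
    """Output spatial size of a 'SAME'-padded transposed convolution."""
    return size * stride


# Generator: fully connected -> reshape to 7x7x128 -> upsample twice -> 1 channel.
g = 7
g = conv_transpose_same(g, 2)   # conv2d_transpose(64, stride=2): 14x14x64
g = conv_transpose_same(g, 2)   # conv2d_transpose(32, stride=2): 28x28x32
g = conv_same(g, 1)             # conv2d(1, stride=1, tanh):      28x28x1
print("generator output:", (g, g, 1))

# Discriminator: 28x28x1 image -> two stride-2 convolutions -> flatten.
d = 28
d = conv_same(d, 2)             # conv2d(64, stride=2):  14x14x64
d = conv_same(d, 2)             # conv2d(128, stride=2): 7x7x128
flat = d * d * 128
print("discriminator flattened features:", flat)  # 6272
```

The generator's output matches the 28 × 28 × 1 MNIST images, which is what allows real and generated batches to go through the same discriminator.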
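Experiments
The CGAN model graph is shown below:
The results are shown below. After 20k training steps, the model can largely generate the class of data we specify.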
Loss:
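Generated samples:
I also tried a generator learning rate of 1e-4. With it, the model produced digits of the correct class within about 1,000 steps, but the images were not sharp enough, and further training led to mode collapse. Using 1e-4 for the first 1,000 steps and switching to 1e-5 afterwards would probably converge better.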
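The two-phase generator learning rate suggested above is a piecewise-constant schedule. Below is a minimal plain-Python sketch of the rule. The function name `piecewise_lr` is hypothetical. The boundary convention, where the earlier rate still applies at exactly step 1000, is chosen to match `tf.train.piecewise_constant`. In the TF 1.x script, the same schedule could presumably be built as `tf.train.piecewise_constant(tf.train.get_or_create_global_step(), [1000], [1e-4, 1e-5])` and passed to `tf.train.AdamOptimizer` as the generator learning rate. This has not been tested with the training loop above.

```python
import bisect


def piecewise_lr(step, boundaries=(1000,), values=(1e-4, 1e-5)):
    """Return the learning rate for `step`; a boundary step keeps the earlier value."""
    if len(values) != len(boundaries) + 1:
        raise ValueError("need exactly one more value than boundaries")
    # bisect_left counts boundaries strictly below `step`, so step 1000 -> index 0.
    return values[bisect.bisect_left(boundaries, step)]


for step in (0, 999, 1000, 1001, 10000):
    print(step, piecewise_lr(step))
```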
The complete CGAN code is as follows:
```python
import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers
from tensorflow.examples.tutorials.mnist import input_data


def float_image_to_uint8(image):
    """Convert float image in [-1, 1) to [0, 255] uint8.

    Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.

    Args:
        image: An image tensor. Values should be in [-1, 1).

    Returns:
        Input image cast to uint8 and with integer values in [0, 255].
    """
    image = (image * 128.0) + 128.0
    return tf.cast(image, tf.uint8)


def provide_data(batch_size, num_threads=1):
    data_dir = "MNIST"
    # Pixel values from input_data are in [0, 1].
    mnist = input_data.read_data_sets(data_dir, one_hot=True)
    train_data = mnist.train.images.reshape(-1, 28, 28, 1) * 255
    train_labels = mnist.train.labels
    # Rescale to [-1, 1) to match the generator's tanh output range.
    train_data = (tf.to_float(train_data) - 128.0) / 128.0
    # Creates a QueueRunner for the pre-fetching operation.
    input_queue = tf.train.slice_input_producer(
        [train_data, train_labels], shuffle=True)
    images, labels = tf.train.batch(
        input_queue,
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=5 * batch_size)
    return images, labels


def conditional_generator(inputs, weight_decay=2.5e-5, is_training=True):
    """Simple generator to produce MNIST images.

    Args:
        inputs: A 2-tuple of Tensors (noise, one_hot_labels).
        weight_decay: The value of the l2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`,
            batch norm uses the exponential moving average collected from
            population statistics.

    Returns:
        A generated image in the range [-1, 1].
    """
    noise, one_hot_labels = inputs
    # Condition the noise on the class labels.
    noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
    with tf.contrib.framework.arg_scope(
            [layers.fully_connected, layers.conv2d_transpose],
            activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
            weights_regularizer=layers.l2_regularizer(weight_decay)):
        with tf.contrib.framework.arg_scope(
                [layers.batch_norm], is_training=is_training,
                zero_debias_moving_mean=True):
            net = layers.fully_connected(noise, 1024)
            net = layers.fully_connected(net, 7 * 7 * 128)
            net = tf.reshape(net, [-1, 7, 7, 128])
            net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
            net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
            # Make sure that the generator output is in the same range as the
            # real images, i.e. [-1, 1].
            net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None,
                                activation_fn=tf.tanh)
            return net


def conditional_discriminator(img, conditioning, weight_decay=2.5e-5,
                              is_training=True):
    """Discriminator network on MNIST digits.

    Args:
        img: Real or generated MNIST digits. Should be in the range [-1, 1].
        conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
        weight_decay: The L2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`,
            batch norm uses the exponential moving average collected from
            population statistics.

    Returns:
        Logits for the probability that the image is real.
    """
    _, one_hot_labels = conditioning
    with tf.contrib.framework.arg_scope(
            [layers.conv2d, layers.fully_connected],
            activation_fn=tf.nn.relu, normalizer_fn=None,
            weights_regularizer=layers.l2_regularizer(weight_decay),
            biases_regularizer=layers.l2_regularizer(weight_decay)):
        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)
        # Condition the flattened image features on the class labels.
        net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels)
        with tf.contrib.framework.arg_scope(
                [layers.batch_norm], is_training=is_training):
            net = layers.fully_connected(net, 1024,
                                         normalizer_fn=layers.batch_norm)
        return layers.linear(net, 1)


def train(batch_size, max_steps, gen_lr, dis_lr, train_log_dir):
    tf.reset_default_graph()
    if not tf.gfile.Exists(train_log_dir):
        tf.gfile.MakeDirs(train_log_dir)

    # Set up the input.
    images, one_hot_labels = provide_data(batch_size)
    noise = tf.random_normal([batch_size, 64])

    with tf.name_scope('model'):
        # Build the generator and discriminator.
        gan_model = tfgan.gan_model(
            generator_fn=conditional_generator,          # defined above
            discriminator_fn=conditional_discriminator,  # defined above
            real_data=images,
            generator_inputs=(noise, one_hot_labels))

    with tf.name_scope('loss'):
        # Build the GAN loss: Wasserstein loss with a gradient penalty.
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
            discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
            gradient_penalty_weight=1.0,
            add_summaries=True)

    with tf.name_scope('train'):
        # Create the train ops, which calculate gradients and apply updates to weights.
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            check_for_unused_update_ops=False,
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the train ops in the alternating training scheme.
    tfgan.gan_train(
        train_ops,
        hooks=[tf.train.StopAtStepHook(num_steps=max_steps)],
        logdir=train_log_dir,
        save_summaries_steps=10)


def test(eval_dir, checkpoint_dir):
    tf.reset_default_graph()
    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)

    # 100 noise vectors with 10 per class: labels 0 (x10), 1 (x10), ..., 9 (x10).
    noises = tf.random_normal([100, 64])
    c = [i for i in range(10) for j in range(10)]
    conditions = tf.one_hot(c, 10)
    random_inputs = (noises, conditions)

    # Use the same variable scope name that tfgan.gan_model uses by default,
    # so the generator variables can be restored from the checkpoint.
    with tf.variable_scope('Generator'):
        images = conditional_generator(random_inputs, is_training=False)

    # Arrange the 100 images into a 10 x 10 grid and write it as a PNG.
    reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
    uint8_images = float_image_to_uint8(reshaped_images)
    image_write_ops = tf.write_file(
        '%s/%s' % (eval_dir, 'conditional_gan.png'),
        tf.image.encode_png(uint8_images[0]))

    tf.contrib.training.evaluate_repeatedly(
        checkpoint_dir,
        eval_ops=image_write_ops,
        hooks=[tf.contrib.training.StopAfterNEvalsHook(1)],
        max_number_of_evaluations=1)


if __name__ == '__main__':
    train(14, 10000, 1e-5, 1e-4, 'cg_logs/')
    test('cg_eval/', 'cg_logs/')
```
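`provide_data` and `float_image_to_uint8` are inverse affine maps between 8-bit pixels and the generator's tanh range. The plain-Python sketch below checks the round trip using hypothetical helper names. `provide_data` maps 0..255 to [-1, 127/128], and `float_image_to_uint8` maps those values back exactly. The sketch also covers the edge case the docstring warns about: a generator output of exactly 1.0 becomes 256, which does not fit in uint8. Here that overflow is emulated as wrap-around to 0, following the docstring, and clipping is shown as a safer alternative.

```python
def to_model_range(pixel):
    """Same mapping as provide_data: [0, 255] -> [-1, 127/128]."""
    return (pixel - 128.0) / 128.0


def to_uint8_wrapping(x):
    """Emulates float_image_to_uint8: truncate, then wrap modulo 256 on overflow."""
    return int(x * 128.0 + 128.0) % 256


def to_uint8_clipped(x):
    """A safer variant that clips to [0, 255] instead of wrapping."""
    return min(255, max(0, int(x * 128.0 + 128.0)))


# Every 8-bit value survives the round trip unchanged.
assert all(to_uint8_wrapping(to_model_range(p)) == p for p in range(256))

print(to_model_range(0), to_model_range(255))         # -1.0 0.9921875
print(to_uint8_wrapping(1.0), to_uint8_clipped(1.0))  # 0 255
```

Because tanh outputs can get arbitrarily close to 1.0, clipping before the cast is a cheap way to avoid white pixels turning black in the saved grid.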