Implementing a Conditional GAN with TFGAN

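TFGAN, the principles behind GANs, and the Unconditional GAN were all covered in an earlier article, "An easy-to-use, lightweight generative adversarial network library: TFGAN". This article focuses on implementing a Conditional GAN model with TFGAN.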

Environment

  • Python 3.6
  • Tensorflow-gpu 1.8.0

Conditional GAN

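A classic Unconditional GAN generates the data we want from random samples of a noise distribution, but it gives us no control over which class the generated data belongs to. The Conditional GAN (CGAN) was designed to solve exactly this problem.

The "condition" in CGAN means that the generated data must not only look realistic but also satisfy a given condition. As shown in the figure below, the inputs to the Generator and the Discriminator include not only the random noise but also a one-hot encoding of the target class. This lets us make the generator produce data of the class we specify.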

[Figure: CGAN]
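
To make the conditioning input concrete, here is a minimal sketch in plain Python of how a class label becomes a one-hot vector that accompanies the noise vector. The helper name one_hot and the 4-dimensional noise are illustrative assumptions, not part of the original code (which uses 64-dimensional noise and TensorFlow tensors).

```python
import random


def one_hot(label, num_classes=10):
    """Return a list of length num_classes with 1.0 at index `label`."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Hypothetical example: ask the generator for the digit 3.
noise = [random.gauss(0.0, 1.0) for _ in range(4)]  # tiny noise vector, for illustration only
condition = one_hot(3)

# The generator receives both pieces as a pair, mirroring (noise, one_hot_labels).
generator_inputs = (noise, condition)
print(condition)
```

Changing only the label while keeping the noise fixed is what lets a trained CGAN produce different digits on request.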

Implementation

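The network structure of the CGAN is essentially the same as that of the Unconditional GAN (UGAN). The main difference is that a one-hot encoding of the class is added to the input.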

Generator

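As shown below, the model input inputs is a tuple (noise, one_hot_labels). The function tfgan.features.condition_tensor_from_onehot combines the two before they are fed into the generator layers. Note that, as far as we can tell, TFGAN implements this by passing the one-hot labels through a learned embedding and a linear projection and adding the result to the noise tensor, rather than by literal concatenation.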

def conditional_generator(inputs, weight_decay=2.5e-5, is_training=True):
    """Simple generator to produce MNIST images.

    Args:
        inputs: A 2-tuple of Tensors (noise, one_hot_labels) used to build a
                conditional generator.
        weight_decay: The value of the l2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`, batch
            norm uses the exponential moving average collected from population 
            statistics.

    Returns:
        A generated image in the range [-1, 1].
    """
    noise, one_hot_labels = inputs
    noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)

    with tf.contrib.framework.arg_scope(
        [layers.fully_connected, layers.conv2d_transpose],
        activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
        weights_regularizer=layers.l2_regularizer(weight_decay)):
        with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training,
                        zero_debias_moving_mean=True):

            net = layers.fully_connected(noise, 1024)
            net = layers.fully_connected(net, 7 * 7 * 128)
            net = tf.reshape(net, [-1, 7, 7, 128])
            net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
            net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
            # Make sure that generator output is in the same range as `inputs`
            # ie [-1, 1].
            net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None, activation_fn=tf.tanh)

            return net
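
As a sanity check on the generator above, the following sketch traces the spatial size of the feature map through each layer. It assumes the layers use 'SAME' padding, which we believe is the default in tf.contrib.layers; the two helper functions are hypothetical and only do the size arithmetic.

```python
import math


def conv2d_transpose_same(size, stride):
    """Spatial size after a transposed convolution with 'SAME' padding."""
    return size * stride


def conv2d_same(size, stride=1):
    """Spatial size after a convolution with 'SAME' padding."""
    return math.ceil(size / stride)


size = 7                               # after reshaping to [-1, 7, 7, 128]
size = conv2d_transpose_same(size, 2)  # 14, with 64 channels
size = conv2d_transpose_same(size, 2)  # 28, with 32 channels
size = conv2d_same(size)               # 28, with 1 channel (tanh output)
print(size)
```

Under that assumption the output is 28 x 28 x 1, matching the MNIST image size.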

Discriminator

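As shown below, besides img the model also takes the corresponding one_hot_labels as input. The function tfgan.features.condition_tensor_from_onehot combines the flattened image features with the class encoding before the final real/fake decision. As in the generator, our understanding is that TFGAN adds a learned projection of the labels to the features rather than concatenating them.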

def conditional_discriminator(img, conditioning, weight_decay=2.5e-5,
                     is_training=True):
    """Discriminator network on MNIST digits.

    Args:
        img: Real or generated MNIST digits. Should be in the range [-1, 1].
        conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
        weight_decay: The L2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`, batch
            norm uses the exponential moving average collected from population 
            statistics.

    Returns:
        Logits for the probability that the image is real.
    """
    _, one_hot_labels = conditioning

    with tf.contrib.framework.arg_scope(
        [layers.conv2d, layers.fully_connected],
        activation_fn=tf.nn.relu, normalizer_fn=None,
        weights_regularizer=layers.l2_regularizer(weight_decay),
        biases_regularizer=layers.l2_regularizer(weight_decay)):

        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)

        net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels)

        with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training):
            net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)

        return layers.linear(net, 1)
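
The conditioning step used in both networks can be sketched in plain Python. This is a simplified illustration of what we understand tfgan.features.condition_tensor_from_onehot to do (embed the one-hot label, project it to the feature size, add it to the features). It is not the library's actual code, and all names and sizes below are assumptions chosen for readability.

```python
import random


def linear(vector, weights, bias):
    """Dense layer without activation; weights has shape [inputs][outputs]."""
    return [
        sum(v * w_row[j] for v, w_row in zip(vector, weights)) + bias[j]
        for j in range(len(bias))
    ]


def condition_features(features, one_hot_label, embedding, proj_w, proj_b):
    """Add a label-dependent offset to a feature vector (illustrative only)."""
    # Multiplying a one-hot vector by the embedding matrix selects one row.
    label_embedding = linear(one_hot_label, embedding, [0.0] * len(embedding[0]))
    # Project the embedding to the feature size, then add it element-wise.
    offset = linear(label_embedding, proj_w, proj_b)
    return [f + o for f, o in zip(features, offset)]


random.seed(0)
num_classes, embedding_size, feature_size = 10, 3, 5
# Randomly initialised stand-ins for weights that would normally be learned.
embedding = [[random.gauss(0, 1) for _ in range(embedding_size)]
             for _ in range(num_classes)]
proj_w = [[random.gauss(0, 1) for _ in range(feature_size)]
          for _ in range(embedding_size)]
proj_b = [0.0] * feature_size

features = [0.0] * feature_size
label = [1.0 if i == 7 else 0.0 for i in range(num_classes)]
conditioned = condition_features(features, label, embedding, proj_w, proj_b)
print(len(conditioned))
```

One consequence of this scheme is that the conditioned tensor keeps the shape of the input features, which would not be the case with concatenation.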

Experiments

The CGAN model graph is shown below:

[Figure: graph]

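The results are shown below. After 20,000 training steps, the model can largely generate data of the class we specify.

Loss:

[Figure: loss]

Generated samples: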


Epochs:6000
Epochs:10000
Epochs:20000

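We also tried a generator learning rate of 1e-4. With it, the model produces the correct classes after about 1,000 steps, but the generated samples are not sharp enough, and continued training leads to mode collapse. Using 1e-4 for the first 1,000 steps and then switching to 1e-5 might give better convergence.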
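
The two-stage learning rate idea above can be sketched as a piecewise-constant schedule. The function below is a hypothetical plain-Python helper, not part of the original training code, which passes a fixed rate to the optimizer. In TensorFlow 1.x, tf.train.piecewise_constant provides a similar schedule driven by the global step, although its exact boundary handling may differ from this sketch.

```python
def generator_learning_rate(step, switch_step=1000, early_lr=1e-4, late_lr=1e-5):
    """Piecewise-constant schedule: early_lr before switch_step, late_lr afterwards."""
    return early_lr if step < switch_step else late_lr


for step in (0, 999, 1000, 20000):
    print(step, generator_learning_rate(step))
```

The switch point of 1,000 steps comes from the observation above. It has not been validated here and would need tuning.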

The complete CGAN code is as follows:

import tensorflow as tf
import tensorflow.contrib.gan as tfgan
import tensorflow.contrib.layers as layers
from tensorflow.examples.tutorials.mnist import input_data


def float_image_to_uint8(image):
    """Convert float image in [-1, 1) to [0, 255] uint8.
    Note that `1` gets mapped to `0`, but `1 - epsilon` gets mapped to 255.
    Args:
        image: An image tensor. Values should be in [-1, 1).
    Returns:
        Input image cast to uint8 and with integer values in [0, 255].
    """
    image = (image * 128.0) + 128.0

    return tf.cast(image, tf.uint8)


def provide_data(batch_size, num_threads=1):
    file = "MNIST"
    # range 0~1
    mnist = input_data.read_data_sets(file, one_hot=True)

    train_data = mnist.train.images.reshape(-1, 28, 28, 1) * 255
    train_labels = mnist.train.labels

    # transfer to -1~1
    train_data = (tf.to_float(train_data) - 128.0) / 128.0

    # Creates a QueueRunner for the pre-fetching operation.
    input_queue = tf.train.slice_input_producer([train_data, train_labels], shuffle=True)
    images, labels = tf.train.batch(
            input_queue,
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=5 * batch_size)

    return images, labels


def conditional_generator(inputs, weight_decay=2.5e-5, is_training=True):
    """Simple generator to produce MNIST images.

    Args:
        inputs: A 2-tuple of Tensors (noise, one_hot_labels) used to build a
                conditional generator.
        weight_decay: The value of the l2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`, batch
            norm uses the exponential moving average collected from population 
            statistics.

    Returns:
        A generated image in the range [-1, 1].
    """
    noise, one_hot_labels = inputs
    noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)

    with tf.contrib.framework.arg_scope(
        [layers.fully_connected, layers.conv2d_transpose],
        activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
        weights_regularizer=layers.l2_regularizer(weight_decay)):
        with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training,
                        zero_debias_moving_mean=True):

            net = layers.fully_connected(noise, 1024)
            net = layers.fully_connected(net, 7 * 7 * 128)
            net = tf.reshape(net, [-1, 7, 7, 128])
            net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)
            net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)
            # Make sure that generator output is in the same range as `inputs`
            # ie [-1, 1].
            net = layers.conv2d(net, 1, [4, 4], normalizer_fn=None, activation_fn=tf.tanh)

            return net


def conditional_discriminator(img, conditioning, weight_decay=2.5e-5,
                     is_training=True):
    """Discriminator network on MNIST digits.

    Args:
        img: Real or generated MNIST digits. Should be in the range [-1, 1].
        conditioning: A 2-tuple of Tensors representing (noise, one_hot_labels).
        weight_decay: The L2 weight decay.
        is_training: If `True`, batch norm uses batch statistics. If `False`, batch
            norm uses the exponential moving average collected from population 
            statistics.

    Returns:
        Logits for the probability that the image is real.
    """
    _, one_hot_labels = conditioning

    with tf.contrib.framework.arg_scope(
        [layers.conv2d, layers.fully_connected],
        activation_fn=tf.nn.relu, normalizer_fn=None,
        weights_regularizer=layers.l2_regularizer(weight_decay),
        biases_regularizer=layers.l2_regularizer(weight_decay)):

        net = layers.conv2d(img, 64, [4, 4], stride=2)
        net = layers.conv2d(net, 128, [4, 4], stride=2)
        net = layers.flatten(net)

        net = tfgan.features.condition_tensor_from_onehot(net, one_hot_labels)

        with tf.contrib.framework.arg_scope([layers.batch_norm], is_training=is_training):
            net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)

        return layers.linear(net, 1)


def train(batch_size, max_steps, gen_lr, dis_lr, train_log_dir):
    tf.reset_default_graph()

    if not tf.gfile.Exists(train_log_dir):
        tf.gfile.MakeDirs(train_log_dir)

    # Set up the input.
    images, one_hot_labels = provide_data(batch_size)
    noise = tf.random_normal([batch_size, 64])

    with tf.name_scope('model'):
        # Build the generator and discriminator.
        gan_model = tfgan.gan_model(
            generator_fn=conditional_generator,  # you define 
            discriminator_fn=conditional_discriminator,  # you define
            real_data=images,
            generator_inputs=(noise, one_hot_labels))

    with tf.name_scope('loss'):
        # Build the GAN loss.
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
            discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
            gradient_penalty_weight=1.0,
            add_summaries=True)

    with tf.name_scope('train'):
        # Create the train ops, which calculate gradients and apply updates to weights.
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            check_for_unused_update_ops=False,
            summarize_gradients=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    # Run the train ops in the alternating training scheme.
    tfgan.gan_train(
        train_ops,
        hooks=[tf.train.StopAtStepHook(num_steps=max_steps)],
        logdir=train_log_dir,
        save_summaries_steps=10)


def test(eval_dir, checkpoint_dir):
    tf.reset_default_graph()

    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)

    noises = tf.random_normal([100, 64])
    c = [i for i in range(10) for j in range(10)]
    conditions = tf.one_hot(c, 10)
    random_inputs = (noises, conditions)

    with tf.variable_scope('Generator'):
        images = conditional_generator(random_inputs, is_training=False)

    reshaped_images = tfgan.eval.image_reshaper(images[:100, ...], num_cols=10)
    uint8_images = float_image_to_uint8(reshaped_images)

    image_write_ops = tf.write_file(
          '%s/%s' % (eval_dir, 'conditional_gan.png'),
          tf.image.encode_png(uint8_images[0]))

    tf.contrib.training.evaluate_repeatedly(
            checkpoint_dir,
            eval_ops=image_write_ops,
            hooks=[tf.contrib.training.StopAfterNEvalsHook(1)],
            max_number_of_evaluations=1)


if __name__ == '__main__':
    train(14, 10000, 1e-5, 1e-4, 'cg_logs/')
    test('cg_eval/', 'cg_logs/')
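
To see what the test function asks the generator for, the sketch below reproduces its label list in plain Python and arranges it into a grid with 10 columns. It assumes that tfgan.eval.image_reshaper fills the grid row by row; the variable names are illustrative.

```python
num_classes, samples_per_class = 10, 10

# Same construction as `c` in test(): ten 0s, then ten 1s, and so on.
labels = [i for i in range(num_classes) for _ in range(samples_per_class)]

# Assumed layout: 10 images per row, filled row by row.
num_cols = 10
grid = [labels[r * num_cols:(r + 1) * num_cols]
        for r in range(len(labels) // num_cols)]

for row in grid:
    print(row)
```

Under that assumption, each row of the saved image shows ten samples of a single digit, from 0 in the top row to 9 in the bottom row.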
