A Brief Look at DCGAN

To explain DCGAN clearly, we first have to talk about GAN, since DCGAN is an upgraded version of it.

GAN stands for Generative Adversarial Network. "Adversarial" means there are two networks: D (the discriminator) and G (the generator), and their objective functions differ. G's goal is to produce images that D judges as real as often as possible, while D's goal is to distinguish, as accurately as possible, whether an input image was generated by G or comes from the real data.
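
Formally, these two objectives combine into the minimax game from the original GAN paper (Goodfellow et al., 2014):

\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]

D tries to maximize V, while G tries to minimize it.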

DCGAN follows the same idea, but uses convolutional layers to implement it more effectively. It has four key characteristics:

1. In the D network, strided convolutions replace pooling layers for downsampling.

2. In the G network, transposed (fractionally-strided) convolutions are used for upsampling (both building blocks are shown in the short sketch after this list).

3. LeakyReLU is used as the activation function in the discriminator (the generator uses ReLU).

4. Batch Normalization is applied.
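
A minimal sketch of these two building blocks, assuming the TensorFlow 1.x tf.layers API used in the rest of this post (the layer widths are arbitrary illustration values):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28, 1])                        # e.g. a batch of MNIST images
down = tf.layers.conv2d(x, 64, 3, strides=2, padding='same')             # strided conv: 28x28 -> 14x14
up = tf.layers.conv2d_transpose(down, 32, 3, strides=2, padding='same')  # transposed conv: 14x14 -> 28x28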

Below is an implementation of DCGAN on the MNIST dataset.

First, import the data and define the inputs for the real images and the noise vector.

import numpy as np
import tensorflow as tf
import pickle
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')  # path to a local MNIST data directory; adjust as needed

def get_inputs(noise_dim, image_height, image_width, image_depth):
    inputs_real = tf.placeholder(tf.float32, [None, image_height, image_width, image_depth], name='inputs_real')
    inputs_noise = tf.placeholder(tf.float32, [None, noise_dim], name='inputs_noise')
    return inputs_real, inputs_noise

Next, define the generator network (the G network). Note that the Tanh activation is used only at the output layer; all the other transposed-convolution layers use ReLU.

def get_generator(noise_img, output_dim, is_train=True):

    with tf.variable_scope("generator", reuse=(not is_train)):
        layer1 = tf.layers.dense(noise_img, 4*4*512)
        layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
        #batch normalization
        layer1 = tf.layers.batch_normalization(layer1, training=is_train)
        layer1 = tf.nn.relu(layer1)
        #dropout
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
        # transposed conv ('valid' padding): 4x4x512 -> 7x7x256
        layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=(1, 1), padding='valid')
        layer2 = tf.layers.batch_normalization(layer2, training=is_train)
        layer2 = tf.nn.relu(layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)

        layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=(2, 2), padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=is_train)
        layer3 = tf.nn.relu(layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)

        logits = tf.layers.conv2d_transpose(layer3, output_dim, 3, strides=(2,2), padding='same')
        outputs = tf.tanh(logits)
        return outputs
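
A quick check of the generator's output sizes, using the standard output-size rules for tf.layers.conv2d_transpose: with 'valid' padding the output side is (in - 1) * stride + kernel, and with 'same' padding it is in * stride. So the feature maps grow as 4x4 -> (4 - 1) * 1 + 4 = 7x7 -> 7 * 2 = 14x14 -> 14 * 2 = 28x28, which matches the 28x28 MNIST images.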

Now define the discriminator network (the D network). Every convolutional layer uses the LeakyReLU activation.

When defining the objective functions, I use the discriminator's output on the real images as the label for the G network's loss. This borrows from the Improved GAN paper, which proposes five main techniques; there is a blog post that summarizes them very well.


def lrelu(layer, leak=0.2, name="lrelu"):
    return tf.maximum(layer, leak * layer)

def get_discriminator(inputs_img, reuse=False):

    with tf.variable_scope("discriminator", reuse=reuse):
        layer1 = tf.layers.conv2d(inputs_img, 128, 3, strides=2, padding='same')
        layer1 = lrelu(layer1)
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)

        layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=True)
        layer2 = lrelu(layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)

        layer3 = tf.layers.conv2d(layer2, 512, 3, strides=(2, 2), padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=True)
        layer3 = lrelu(layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)

        flatten = tf.reshape(layer3, (-1, 4*4*512))
        logits = tf.layers.dense(flatten, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs
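
A note on the discriminator's shapes: with 'same' padding, tf.layers.conv2d produces an output side of ceil(in / stride), so the 28x28 input shrinks as 28 -> 14 -> 7 -> ceil(7 / 2) = 4. That is why the final feature map is flattened to 4*4*512 before the dense layer.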
def get_loss(inputs_real, inputs_noise, image_depth, smooth=0.1):

    g_outputs = get_generator(inputs_noise, image_depth, is_train=True)
    d_logits_real, d_outputs_real = get_discriminator(inputs_real)
    d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, reuse=True)

    # as described above: use D's logits on the real images as the target for the generated images
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                    labels=d_logits_real))
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                         labels=tf.ones_like(d_outputs_real)*(1-smooth)))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                         labels=tf.zeros_like(d_outputs_fake)))
    d_loss = tf.add(d_loss_real, d_loss_fake)
    return g_loss, d_loss
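
For comparison, here is a minimal sketch of the more conventional generator loss, which simply asks the discriminator to label generated images as real (it reuses the names defined in get_loss above, with the same one-sided label smoothing):

g_loss_standard = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                            labels=tf.ones_like(d_outputs_fake) * (1 - smooth)))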

def get_optimizer(g_loss, d_loss, learning_rate):
    train_vars = tf.trainable_variables()
    g_vars = [var for var in train_vars if var.name.startswith("generator")]
    d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
    # update batch-norm statistics before each optimizer step
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        g_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
        d_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)

    return g_opt, d_opt

def plot_images(samples):
    fig, axes = plt.subplots(nrows=1, ncols=25, sharex=True, sharey=True, figsize=(50, 2))
    for img, ax in zip(samples, axes):
        ax.imshow(img.reshape(28, 28), cmap='Greys_r')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)
    plt.show()  # display the figure (needed when running outside a notebook)


def show_generator_output(sess, n_images, inputs_noise, output_dim):
    cmap = 'Greys_r'
    noise_shape = inputs_noise.get_shape().as_list()[-1]
    # sample random noise vectors for the generator
    examples_noise = np.random.uniform(-1, 1, size=[n_images, noise_shape])

    samples = sess.run(get_generator(inputs_noise, output_dim, False),
                       feed_dict={inputs_noise: examples_noise})

    result = np.squeeze(samples, -1)
    return result


batch_size = 64
noise_size = 100
epochs = 5
n_samples = 25
learning_rate = 0.001


def train(noise_size, data_shape, batch_size, n_samples):
    # record the losses
    losses = []
    steps = 0

    inputs_real, inputs_noise = get_inputs(noise_size, data_shape[1], data_shape[2], data_shape[3])
    g_loss, d_loss = get_loss(inputs_real, inputs_noise, data_shape[-1])
    g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, learning_rate)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # loop over epochs
        for e in range(epochs):
            for batch_i in range(mnist.train.num_examples // batch_size):
                steps += 1
                batch = mnist.train.next_batch(batch_size)

                batch_images = batch[0].reshape((batch_size, data_shape[1], data_shape[2], data_shape[3]))
                # scale to -1, 1
                batch_images = batch_images * 2 - 1

                # noise
                batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))

                # run optimizer
                _ = sess.run(g_train_opt, feed_dict={inputs_real: batch_images,
                                                     inputs_noise: batch_noise})
                _ = sess.run(d_train_opt, feed_dict={inputs_real: batch_images,
                                                     inputs_noise: batch_noise})

                if steps % 101 == 0:
                    train_loss_d = d_loss.eval({inputs_real: batch_images,
                                                inputs_noise: batch_noise})
                    train_loss_g = g_loss.eval({inputs_real: batch_images,
                                                inputs_noise: batch_noise})
                    losses.append((train_loss_d, train_loss_g))
                    # display generated samples
                    samples = show_generator_output(sess, n_samples, inputs_noise, data_shape[-1])
                    plot_images(samples)
                    print("Epoch {}/{}....".format(e + 1, epochs),
                          "Discriminator Loss: {:.4f}....".format(train_loss_d),
                          "Generator Loss: {:.4f}....".format(train_loss_g))

with tf.Graph().as_default():
    train(noise_size, [-1, 28, 28, 1], batch_size, n_samples)

GAN has been a very hot topic recently, with many related papers and many optimization and improvement techniques. There is a blog post that summarizes them very well.

Finally, I am still learning. If anything here is incorrect, or if you have suggestions, please point them out.

Best wishes!


