Conditional GAN (CGAN)

与普通 GAN 随机生成图像不同，Conditional GAN（条件 GAN）会根据我们指定的类别标签（label）生成对应的图像：在下面的实现中，生成器和判别器都会额外接收 one-hot 编码的数字标签作为输入。

代码（基于 TensorFlow 1.x）：

import numpy as np

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

import matplotlib.pyplot as plt

import os

import matplotlib.gridspec as gridspec

######################
# Data and hyperparameters.
#
# Labelling convention for the adversarial game: the discriminator is trained
# to output 1 for real images and 0 for generated (fake) images.

# MNIST with labels encoded as one-hot vectors (one_hot=True).
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

iterations = 100000   # intended number of training iterations
batch_size = 128      # samples per generator / discriminator update
noise_dim = 100       # length of each generator input noise vector


######################sample generation

def sample_Z(m, n):
    """Draw a batch of generator noise vectors.

    :param m: number of noise vectors (rows), typically the batch size.
    :param n: dimensionality of each noise vector (columns).
    :return: ``numpy.ndarray`` of shape ``(m, n)`` whose entries are drawn
        uniformly from the half-open interval [-1, 1).
    """
    shape = (m, n)
    return np.random.uniform(low=-1.0, high=1.0, size=shape)
###################
# Model sizes, named once so the label count and image size stay consistent
# across the one-hot table, the weight shapes and the placeholders below.
image_dim = 784    # flattened 28x28 MNIST image
num_classes = 10   # digit classes 0-9
hidden_dim = 128   # hidden-layer width of both networks

# Row k is the one-hot vector for class k, so onehot[labels] maps an integer
# label array to a (len(labels), num_classes) one-hot matrix.
onehot = np.eye(num_classes)

#################net parameters

# Generator weights: [noise, one-hot label] -> hidden_dim -> image_dim.
Gw1 = tf.get_variable(name='Gw1', shape=[noise_dim + num_classes, hidden_dim], dtype=tf.float32,
                      initializer=tf.contrib.layers.xavier_initializer())

Gw2 = tf.get_variable(name='Gw2', shape=[hidden_dim, image_dim], dtype=tf.float32,
                      initializer=tf.contrib.layers.xavier_initializer())

Gb1 = tf.get_variable(name='Gb1', shape=[hidden_dim], dtype=tf.float32, initializer=tf.constant_initializer(0))

Gb2 = tf.get_variable(name='Gb2', shape=[image_dim], dtype=tf.float32, initializer=tf.constant_initializer(0))

# Passed as var_list so the generator optimizer updates only these variables.
params_G = [Gw1, Gw2, Gb1, Gb2]

# Generator input noise, one row per sample.
Z = tf.placeholder(tf.float32, shape=[None, noise_dim], name='Z')

# Discriminator weights: [image, one-hot label] -> hidden_dim -> 1 output.
Dw1 = tf.get_variable(name='Dw1', shape=[image_dim + num_classes, hidden_dim], dtype=tf.float32,
                      initializer=tf.contrib.layers.xavier_initializer())

Dw2 = tf.get_variable(name='Dw2', shape=[hidden_dim, 1], dtype=tf.float32,
                      initializer=tf.contrib.layers.xavier_initializer())

Db1 = tf.get_variable(name='Db1', shape=[hidden_dim], dtype=tf.float32, initializer=tf.constant_initializer(0))

Db2 = tf.get_variable(name='Db2', shape=[1], dtype=tf.float32, initializer=tf.constant_initializer(0))

# Passed as var_list so the discriminator optimizer updates only these variables.
params_D = [Dw1, Dw2, Db1, Db2]

# Flattened images and their one-hot labels; Y also conditions the generator.
X = tf.placeholder(tf.float32, shape=[None, image_dim], name='X')

Y = tf.placeholder(tf.float32, shape=[None, num_classes], name='Y')
#################generation net


def generator(z1, y1):
    """Map noise plus a one-hot condition to generated images.

    :param z1: noise tensor (the script feeds it from the ``Z`` placeholder).
    :param y1: one-hot label tensor that conditions the output.
    :return: tensor of per-pixel values in (0, 1) from a sigmoid; each row
        has as many entries as ``Gb2`` (784 in this script).
    """
    conditioned_input = tf.concat([z1, y1], axis=1)  # append the label to the noise
    hidden = tf.nn.relu(tf.add(tf.matmul(conditioned_input, Gw1), Gb1))
    logits = tf.add(tf.matmul(hidden, Gw2), Gb2)
    return tf.nn.sigmoid(logits)


#################discrimination net


def discriminator(x1, y1):
    """Score how likely an (image, label) pair is to be real.

    :param x1: image tensor (real images from ``X`` or generator output).
    :param y1: one-hot label tensor the image is conditioned on.
    :return: tensor with one column of sigmoid probabilities in (0, 1);
        values near 1 mean "real", near 0 mean "fake".
    """
    joint = tf.concat([x1, y1], axis=1)  # condition the critic on the label
    hidden = tf.nn.relu(tf.add(tf.matmul(joint, Dw1), Db1))
    score = tf.add(tf.matmul(hidden, Dw2), Db2)
    return tf.nn.sigmoid(score)


# Generated images for the labels fed into Y (also used to plot samples).
G_sample = generator(Z, Y)

# Discriminator outputs (probability of "real") for real and generated images.
# Reusing G_sample avoids building extra copies of the generator graph; the
# copies would share weights and Z, so the losses are numerically the same.
D_real = discriminator(X, Y)
D_fake = discriminator(G_sample, Y)

# A saturated float32 sigmoid can return exactly 0.0 or 1.0, which would make
# tf.log produce -inf and the gradients NaN; this offset keeps the log finite.
log_eps = 1e-8

# Non-saturating generator loss: maximise log D(G(z|y)|y).
G_loss = -tf.reduce_mean(tf.log(D_fake + log_eps))

# Discriminator loss: maximise log(1 - D(G(z|y)|y)) + log D(x|y).
D_loss = -tf.reduce_mean(tf.log(1. - D_fake + log_eps) + tf.log(D_real + log_eps))

# Each optimizer updates only its own network's variables.
G_optimizer = tf.train.AdamOptimizer(0.001).minimize(G_loss, var_list=params_G)

D_optimizer = tf.train.AdamOptimizer(0.001).minimize(D_loss, var_list=params_D)


#############################################

def plot(samples):
    """Draw flattened 28x28 images on a square grid and return the figure.

    :param samples: sequence/array of images, each with 784 values that are
        reshaped to 28x28 for display. 16 samples give a 4x4 grid; any other
        count gets the smallest square grid that holds it.
    :return: the ``matplotlib`` figure; the caller is responsible for saving
        and closing it.
    """
    samples = list(samples)
    # Smallest square grid holding every sample (4x4 for 16). Keep it at least
    # 1x1 so GridSpec stays valid for an empty input.
    side = max(1, int(np.ceil(np.sqrt(len(samples)))))

    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(side, side)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        ax.set_axis_off()
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig


if not os.path.exists('out/'):
    os.makedirs('out/')

j = 0  # index of the next sample-grid image written to out/

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; this is its replacement.
    sess.run(tf.global_variables_initializer())

    # Use the configured iteration count rather than a separate hard-coded one.
    for i in range(iterations):

        # ---- Generator step ----
        # Random conditioning labels. np.random.randint's upper bound is
        # exclusive, so it must be the class count (len(onehot) == 10), not 9,
        # or digit 9 is never requested from the generator.
        ysamples = onehot[np.random.randint(0, len(onehot), batch_size)]
        lossG, _ = sess.run([G_loss, G_optimizer],
                            feed_dict={Z: sample_Z(batch_size, noise_dim), Y: ysamples})

        # ---- Discriminator step ----
        # Real images with their labels; the same labels condition the fakes.
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        lossD, _ = sess.run([D_loss, D_optimizer],
                            feed_dict={Z: sample_Z(batch_size, noise_dim), X: batch_x, Y: batch_y})

        if (i + 1) % 200 == 0:
            # Monitoring grid: four samples each of digits 0, 1, 2 and 3.
            ysamples = onehot[np.repeat(np.arange(4), 4)]
            zsamples = sample_Z(16, noise_dim)
            samples = sess.run(G_sample, feed_dict={Z: zsamples, Y: ysamples})  # 16 x 784

            fig = plot(samples)
            fig.savefig('out/{}.png'.format(str(j).zfill(3)), bbox_inches='tight')
            j += 1
            plt.close(fig)

            print('iters is %d, lossG is %.4f, lossD is %.4f' % (i, lossG, lossD))














 

你可能感兴趣的:(Conditional GAN (CGAN))