import tensorflow as tf
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/')
img_size = mnist.train.images[0].shape[0]  # flattened MNIST image size: 28*28 = 784
noise_size = 100        # dimensionality of the generator's input noise vector
g_units = 128           # hidden units in the generator
d_units = 128           # hidden units in the discriminator
alpha = 0.01            # leaky ReLU slope
learning_rate = 0.001   # Adam learning rate for both networks
smooth = 0.1            # label smoothing factor
batch_size = 64
epochs = 50
n_sample = 25           # number of images sampled for visualization after each epoch
class GAN():
    def __init__(self, img_size, noise_size):
        # Placeholders for flattened real images (batch, img_size) and generator input noise (batch, noise_size).
        self.real_img = tf.placeholder(tf.float32, [None, img_size], name='real_img')
        self.noise_img = tf.placeholder(tf.float32, [None, noise_size], name='noise_img')

    @staticmethod
    def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
        """
        Generator.
        noise_img: input noise tensor for the generator
        n_units: number of hidden units
        out_dim: size of the generator's output tensor; for MNIST this is 28*28 = 784
        alpha: leaky ReLU slope coefficient
        """
        with tf.variable_scope("generator", reuse=reuse):
            hidden1 = tf.layers.dense(noise_img, n_units)
            hidden1 = tf.maximum(alpha * hidden1, hidden1)
            hidden1 = tf.layers.dropout(hidden1, rate=0.2)
            logits = tf.layers.dense(hidden1, out_dim)
            outputs = tf.tanh(logits)
            return logits, outputs
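    # Note on the generator above: tf.maximum(alpha * hidden1, hidden1) is a hand-rolled
    # leaky ReLU; in recent TF 1.x versions it could equivalently be written as
    # tf.nn.leaky_relu(hidden1, alpha=alpha). tf.layers.dropout defaults to training=False,
    # so the dropout here is a no-op unless training=True is passed. The tanh output keeps
    # generated pixels in [-1, 1], which is why train() rescales the real images to the
    # same range.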

    @staticmethod
    def get_discriminator(img, n_units, reuse=False, alpha=0.01):
        """
        Discriminator.
        img: input image tensor (real or generated)
        n_units: number of hidden units
        alpha: leaky ReLU slope coefficient
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            hidden1 = tf.layers.dense(img, n_units)
            hidden1 = tf.maximum(alpha * hidden1, hidden1)
            logits = tf.layers.dense(hidden1, 1)
            outputs = tf.sigmoid(logits)
            return logits, outputs
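    # Note on the discriminator above: inference() builds it twice, once on real images and
    # once on generator output; reuse=True on the second call shares the weights inside the
    # "discriminator" variable scope. The sigmoid output is the estimated probability that
    # the input image is real.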

    def draw_loss(self, losses):
        fig, ax = plt.subplots(figsize=(20, 7))
        losses = np.array(losses)
        plt.plot(losses.T[0], label='Discriminator Total Loss')
        plt.plot(losses.T[1], label='Discriminator Real Loss')
        plt.plot(losses.T[2], label='Discriminator Fake Loss')
        plt.plot(losses.T[3], label='Generator Loss')
        plt.title("Training Losses")
        plt.legend()

    def inference(self):
        g_logits, g_outputs = GAN.get_generator(self.noise_img, g_units, img_size)
        d_logits_real, d_outputs_real = GAN.get_discriminator(self.real_img, d_units)
        d_logits_fake, d_outputs_fake = GAN.get_discriminator(g_outputs, d_units, reuse=True)
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                    labels=tf.ones_like(d_logits_real)) * (1 - smooth))
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                    labels=tf.zeros_like(d_logits_fake)))
        self.d_loss = tf.add(self.d_loss_real, self.d_loss_fake)
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                    labels=tf.ones_like(d_logits_fake)) * (1 - smooth))
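        # Label smoothing here scales the whole real/generator cross-entropy term by
        # (1 - smooth). A common alternative (a hedged suggestion, not what this code does)
        # is to smooth the targets instead, e.g. labels=tf.ones_like(d_logits_real) * (1 - smooth),
        # which moves the optimum of the loss rather than just rescaling it.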
        # Split trainable variables by scope so each optimizer only updates its own network.
        train_vars = tf.trainable_variables()
        self.g_vars = [var for var in train_vars if var.name.startswith("generator")]
        self.d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
        d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(self.d_loss, var_list=self.d_vars)
        g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(self.g_loss, var_list=self.g_vars)
        # Only generator variables are saved; test() restores the generator from this checkpoint.
        self.saver = tf.train.Saver(var_list=self.g_vars)
        return d_train_opt, g_train_opt

    def train(self, d_train_opt, g_train_opt):
        samples = []
        losses = []
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for e in range(epochs):
                for batch_i in range(mnist.train.num_examples // batch_size):
                    batch = mnist.train.next_batch(batch_size)
                    batch_images = batch[0].reshape((batch_size, 784))
                    # Rescale images from [0, 1] to [-1, 1] to match the generator's tanh output.
                    batch_images = batch_images * 2 - 1
                    # Sample generator input noise uniformly from [-1, 1].
                    batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
                    # One discriminator update followed by one generator update per batch.
                    _ = sess.run(d_train_opt, feed_dict={self.real_img: batch_images, self.noise_img: batch_noise})
                    _ = sess.run(g_train_opt, feed_dict={self.noise_img: batch_noise})
                # Evaluate the losses on the last batch of the epoch for logging.
                train_loss_d = sess.run(self.d_loss,
                                        feed_dict={self.real_img: batch_images,
                                                   self.noise_img: batch_noise})
                train_loss_d_real = sess.run(self.d_loss_real,
                                             feed_dict={self.real_img: batch_images,
                                                        self.noise_img: batch_noise})
                train_loss_d_fake = sess.run(self.d_loss_fake,
                                             feed_dict={self.real_img: batch_images,
                                                        self.noise_img: batch_noise})
                train_loss_g = sess.run(self.g_loss,
                                        feed_dict={self.noise_img: batch_noise})
                print("Epoch {}/{}...".format(e + 1, epochs),
                      "Discriminator Loss: {:.4f}(Real: {:.4f} + Fake: {:.4f})...".format(train_loss_d, train_loss_d_real, train_loss_d_fake),
                      "Generator Loss: {:.4f}".format(train_loss_g))
                losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))
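                # After each epoch, draw fresh noise and run the generator (reuse=True shares
                # the trained weights) to record sample images for later visualization; note
                # this re-invokes get_generator inside the loop, so new graph nodes are added
                # each epoch.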
                sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
                gen_samples = sess.run(GAN.get_generator(self.noise_img, g_units, img_size, reuse=True),
                                       feed_dict={self.noise_img: sample_noise})
                samples.append(gen_samples)
            # Make sure the checkpoint directory exists before saving the generator variables.
            if not os.path.exists('checkpoints'):
                os.makedirs('checkpoints')
            self.saver.save(sess, './checkpoints/generator_1.ckpt')
        with open('train_samples_1.pkl', 'wb') as f:
            pickle.dump(samples, f)
        return losses

    @staticmethod
    def view_samples(samples):
        fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
        for ax, img in zip(axes.flatten(), samples):
            ax.axis('off')
            ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        return fig, axes

    def draw_samples(self):
        # Epochs whose saved samples will be shown, one row per selected epoch.
        epoch_idx = [0, 5, 10, 20, 40]
        show_imgs = []
        with open('train_samples_1.pkl', 'rb') as f:
            samples = pickle.load(f)
        for i in epoch_idx:
            # samples[i] is a (logits, outputs) pair; keep only the generated images.
            show_imgs.append(samples[i][1])
        rows, cols = len(epoch_idx), n_sample
        fig, axes = plt.subplots(figsize=(30, 12), nrows=rows, ncols=cols, sharex=True, sharey=True)
        for sample, ax_row in zip(show_imgs, axes):
            for img, ax in zip(sample, ax_row):
                ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
                ax.axis('off')

    def test(self):
        with tf.Session() as sess:
            self.saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
            sample_noise = np.random.uniform(-1, 1, size=(25, 100))
            logits, gen_samples = sess.run(GAN.get_generator(self.noise_img, g_units, img_size, reuse=True),
                                           feed_dict={self.noise_img: sample_noise})
            GAN.view_samples(gen_samples)
            return gen_samples


gan = GAN(img_size, noise_size)
d_train_opt, g_train_opt = gan.inference()
losses = gan.train(d_train_opt, g_train_opt)
gen_samples = gan.test()
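# A minimal visualization sketch (assumes the training run above finished and wrote
# train_samples_1.pkl plus the generator checkpoint): plot the recorded losses and show
# the samples saved at the selected epochs.
gan.draw_loss(losses)
gan.draw_samples()
plt.show()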