# Generative Adversarial Networks (GANs)
from datetime import datetime
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tensorflow.examples.tutorials.mnist import input_data
# Hyper-parameters and shared graph inputs.
BATCH_SIZE = 128  # number of samples per mini-batch
LEARNING_RATE = 1e-4  # learning rate handed to the Adam optimizer
Z_DIM = 100  # length of the generator's input noise vector
# Image width/height in pixels (used to size the layers and to reshape samples for plotting).
IMAGE_W = 28
IMAGE_H = 28
model_dir = 'model_gan'  # directory checkpoints are restored from and saved to
# Placeholder for a batch of flattened real images (784 = 28 * 28 values per row).
x_in = tf.placeholder(tf.float32, shape=[None, 784])
def load_mnist():
    """Read the MNIST data sets from ``./MNIST_data`` with one-hot labels."""
    data_sets = input_data.read_data_sets("./MNIST_data", one_hot=True)
    return data_sets


# Loaded once when the module is imported; train() draws its batches from it.
mnist = load_mnist()
def get_W_b(input_dim, output_dim, name):
    """Create the weight matrix and bias vector of one dense layer.

    ``name`` is a combined label such as ``'gen_W_b_1'``: stripping ``'_b'``
    gives the weight variable's name and stripping ``'_W'`` gives the bias
    variable's name.

    Returns:
        A ``(W, b)`` tuple. ``W`` has shape ``[input_dim, output_dim]`` and is
        initialised from a normal distribution with stddev 0.02; ``b`` has
        shape ``[output_dim]`` and is initialised to zeros.
    """
    weight_name = name.replace('_b', '')
    bias_name = name.replace('_W', '')

    weight_init = tf.random_normal([input_dim, output_dim], stddev=0.02)
    W = tf.Variable(weight_init, name=weight_name)

    bias_init = tf.zeros([output_dim], tf.float32)
    b = tf.Variable(bias_init, name=bias_name)

    return W, b
# Hidden-layer width used by both the generator and the discriminator in GAN.
tmp = 256
class GAN(object):
    """Fully connected GAN with one hidden layer per network.

    The generator maps a ``z_dim`` noise vector to ``IMAGE_H * IMAGE_W``
    values in (0, 1); the discriminator maps such a vector to a single
    probability that the input is a real image.
    """

    def __init__(self, lr=LEARNING_RATE, batch_size=BATCH_SIZE, z_dim=Z_DIM):
        """Create the trainable variables of both networks.

        Args:
            lr: Learning rate stored on the instance for the training code.
            batch_size: Number of noise vectors drawn per training step.
            z_dim: Length of each noise vector.
        """
        self.lr = lr
        self.batch_size = batch_size
        self.z_dim = z_dim
        # Generator weights
        self.gen_W1, self.gen_b1 = get_W_b(z_dim, tmp, 'gen_W_b_1')
        self.gen_W2, self.gen_b2 = get_W_b(tmp, IMAGE_H * IMAGE_W, 'gen_W_b_2')
        # Discriminator weights
        self.discrim_W1, self.discrim_b1 = get_W_b(IMAGE_H * IMAGE_W, tmp, 'discrim_W_b_1')
        self.discrim_W2, self.discrim_b2 = get_W_b(tmp, 1, 'discrim_W_b_2')

    def discriminator(self, x):
        """Return the probability, one value per row of ``x``, that the row is real."""
        d_h1 = tf.nn.relu(tf.add(tf.matmul(x, self.discrim_W1), self.discrim_b1))
        d_h2 = tf.add(tf.matmul(d_h1, self.discrim_W2), self.discrim_b2)
        return tf.nn.sigmoid(d_h2)

    def generator(self, z):
        """Map a batch of noise vectors ``z`` to flattened images in (0, 1)."""
        g_h1 = tf.nn.relu(tf.add(tf.matmul(z, self.gen_W1), self.gen_b1))
        g_h2 = tf.add(tf.matmul(g_h1, self.gen_W2), self.gen_b2)
        return tf.nn.sigmoid(g_h2)

    def build_model(self):
        """Build the loss tensors for both networks.

        Returns:
            A tuple ``(d_cost, g_cost, mean_d_real, mean_d_fake)``: the
            discriminator loss, the generator loss, and the mean
            discriminator output on real and on generated images.
        """
        # Sample the noise inside the graph so that every session run draws a
        # fresh batch. A NumPy array created here would be frozen into the
        # graph as a constant and the generator would only ever see one
        # fixed batch of noise.
        z_sample = tf.random_uniform(
            [self.batch_size, self.z_dim], minval=-1., maxval=1., dtype=tf.float32)
        g_image = self.generator(z_sample)
        d_real = self.discriminator(x_in)
        d_fake = self.discriminator(g_image)

        # Keeps log() finite when a sigmoid output saturates at exactly 0 or 1.
        eps = 1e-8
        # The two means are taken separately so the number of real images fed
        # through x_in does not have to equal batch_size.
        d_cost = -(tf.reduce_mean(tf.log(d_real + eps))
                   + tf.reduce_mean(tf.log(1. - d_fake + eps)))
        g_cost = -tf.reduce_mean(tf.log(d_fake + eps))
        return d_cost, g_cost, tf.reduce_mean(d_real), tf.reduce_mean(d_fake)
# Plotting
def plot_grid(samples):
    """Draw flattened images as a square grid of grayscale tiles.

    Args:
        samples: Iterable of arrays, each reshapeable to
            ``(IMAGE_H, IMAGE_W)``.

    Returns:
        The matplotlib figure containing the grid. The figure is neither
        saved nor closed here.
    """
    samples = list(samples)
    # Keep the 4x4 layout for up to 16 samples and grow to the smallest
    # square that fits otherwise, instead of indexing past the grid's end.
    side = max(4, int(np.ceil(np.sqrt(len(samples)))))
    fig = plt.figure(figsize=(side, side))
    gs = gridspec.GridSpec(side, side)
    gs.update(wspace=0.05, hspace=0.05)
    for i, sample in enumerate(samples):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(IMAGE_H, IMAGE_W), cmap='Greys_r')
    return fig
# Training
def train(num_steps=100000, sample_every=1000,
          img_path_format=r'D:\project\生成对抗网络\img\{}.png'):
    """Train the GAN on MNIST, periodically logging, sampling and checkpointing.

    Args:
        num_steps: Number of training iterations to run.
        sample_every: Positive interval, in steps, at which losses are
            printed, a grid of 16 generated images is written and a
            checkpoint is saved.
        img_path_format: ``str.format`` template for the sample image path;
            it receives the zero-padded index of the sample grid.
    """
    import os

    # Make sure the output directories exist before the first write.
    img_dir = os.path.dirname(img_path_format)
    if img_dir:
        os.makedirs(img_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    with tf.Session() as sess:
        gan = GAN()
        # Variables are assigned to a network by the prefix of their name.
        trainable = tf.trainable_variables()
        discrim_vars = [v for v in trainable if v.name.startswith('discrim')]
        gen_vars = [v for v in trainable if v.name.startswith('gen')]

        d_cost, g_cost, d_real, d_fake = gan.build_model()
        optimizer = tf.train.AdamOptimizer(learning_rate=gan.lr)
        d_opt = optimizer.minimize(d_cost, var_list=discrim_vars)
        g_opt = optimizer.minimize(g_cost, var_list=gen_vars)
        saver = tf.train.Saver()

        # Build the sampling op once. Calling gan.generator() inside the loop
        # would add new nodes to the graph on every sampling step.
        z_in = tf.placeholder(tf.float32, shape=[None, gan.z_dim])
        sample_op = gan.generator(z_in)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint:
            # Resume from the most recent checkpoint.
            saver.restore(sess, checkpoint)
            print('checkpoint: {}'.format(checkpoint))
        else:
            # No checkpoint found: initialise all variables.
            sess.run(tf.global_variables_initializer())

        print("Started training {}".format(datetime.now().isoformat()[11:]))
        plot_index = 0
        for step in range(num_steps):
            batch_x, _ = mnist.train.next_batch(BATCH_SIZE)
            sess.run(d_opt, feed_dict={x_in: batch_x})
            sess.run(g_opt, feed_dict={x_in: batch_x})

            # Every `sample_every` steps: report, save a sample grid, checkpoint.
            if step % sample_every == 0:
                batch_x, _ = mnist.train.next_batch(BATCH_SIZE)
                # A single run evaluates all four metrics on the same batch.
                d_cost_, g_cost_, d_real_, d_fake_ = sess.run(
                    [d_cost, g_cost, d_real, d_fake], feed_dict={x_in: batch_x})
                print("step:{} Discriminator Loss {} Generator loss {} d_real:{} d_feak:{}".format(
                    step, d_cost_, g_cost_, d_real_, d_fake_))

                z_sample = np.random.uniform(-1., 1., size=[16, gan.z_dim]).astype('float32')
                g_image = sess.run(sample_op, feed_dict={z_in: z_sample})
                fig = plot_grid(g_image)
                fig.savefig(img_path_format.format(str(plot_index).zfill(4)), bbox_inches='tight')
                plot_index += 1
                plt.close(fig)

                # Save the model.
                saver.save(sess, "{}/model_gan.model".format(model_dir), global_step=step)
        print("Ended training {}".format(datetime.now().isoformat()[11:]))
# Start training only when this file is executed as a script.
if __name__ == "__main__":
    train()