生成对抗网络(GAN)及其变种已经成为最近十年以来机器学习领域最为重要的思想。--2018图灵奖得主 Yann LeCun
GAN的基础知识复习:click here
GAN模型的挑战及训练优化:click here
本程序采用原始GAN的基本原理,使用简单的两层全连接神经网络和MNIST数据集,基于TensorFlow实现,以便更深入地理解GAN的原理及其实现过程。
程序框架及注释:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# TensorFlow session used by every sess.run(...) call further down.
sess = tf.InteractiveSession()
# Width of the noise vector fed to the generator (second argument of sample_z
# in the training loop).
z_dim = 100
# Mini-batch size used for both the MNIST batches and the noise batches.
batchs= 128
# MNIST data set; labels are requested one-hot but only the images are used by
# the training loop.
# NOTE(review): the path is an absolute, machine-specific location -- adjust it
# before running elsewhere.
mnist = input_data.read_data_sets("/home/zhaocq/桌面/tensorflow/mnist/raw/",one_hot=True)
def weight_variable(shape, name):
    """Create a weight variable initialised from a narrow Gaussian.

    Args:
        shape: Shape of the weight tensor.
        name: Name given to the TensorFlow variable.

    Returns:
        A ``tf.Variable`` whose initial value is drawn by
        ``tf.random_normal`` with standard deviation 0.01.
    """
    return tf.Variable(tf.random_normal(shape, stddev=0.01), name=name)
def bias_variable(shape, name):
    """Create a bias variable initialised to all zeros.

    Args:
        shape: Shape of the bias tensor.
        name: Name given to the TensorFlow variable.

    Returns:
        A ``tf.Variable`` whose initial value is ``tf.zeros(shape)``.
    """
    # The initial value is exactly zero; no positive offset is applied.
    return tf.Variable(tf.zeros(shape), name=name)
# Generator input: random noise of width z_dim (kept in sync with the
# module-level constant instead of repeating the literal 100).
z = tf.placeholder(tf.float32, shape=[None, z_dim], name='z')

# Discriminator input: flattened MNIST images (784 values per image).
x = tf.placeholder(tf.float32, shape=[None, 784], name='x')

# Generator parameters: z_dim -> 128 hidden units -> 784 outputs.
g_w1 = weight_variable([z_dim, 128], 'g_w1')
g_b1 = bias_variable([128], 'g_b1')
g_w2 = weight_variable([128, 784], 'g_w2')
g_b2 = bias_variable([784], 'g_b2')
# Despite the name this is a list; it is passed as var_list to the optimizer.
generator_dict = [g_w1, g_b1, g_w2, g_b2]

# Discriminator parameters: 784 inputs -> 128 hidden units -> 1 logit.
d_w1 = weight_variable([784, 128], 'd_w1')
d_b1 = bias_variable([128], 'd_b1')
d_w2 = weight_variable([128, 1], 'd_w2')
d_b2 = bias_variable([1], 'd_b2')
# Despite the name this is a list; it is passed as var_list to the optimizer.
discriminator_dict = [d_w1, d_b1, d_w2, d_b2]
# Generator network definition
def generator(z):
    """Map a batch of noise vectors to generated samples.

    Args:
        z: Noise tensor that is matrix-multiplied with ``g_w1``.

    Returns:
        The sigmoid of the second layer's output, so every element lies
        in (0, 1); the last dimension matches ``g_w2``'s output width.
    """
    hidden = tf.nn.relu(tf.matmul(z, g_w1) + g_b1)
    pre_activation = tf.matmul(hidden, g_w2) + g_b2
    return tf.nn.sigmoid(pre_activation)
# Discriminator network definition
def discrimnator(x):
    """Score a batch of samples as real or generated.

    Args:
        x: Input tensor that is matrix-multiplied with ``d_w1``.

    Returns:
        A ``(probability, logit)`` pair: the raw output of the second
        layer and its sigmoid.
    """
    hidden = tf.nn.relu(tf.matmul(x, d_w1) + d_b1)
    logit = tf.matmul(hidden, d_w2) + d_b2
    return tf.nn.sigmoid(logit), logit
# Wire the networks together: generated samples, and discriminator outputs on
# real inputs and on generated samples.
g_sample = generator(z)
d_real, d_logit_real = discrimnator(x)
d_fake, d_logit_fake = discrimnator(g_sample)

# Losses.
# These are the same objectives as
#     d_loss = -mean(log(D(x)) + log(1 - D(G(z))))
#     g_loss = -mean(log(D(G(z))))
# but computed from the logits.  Taking tf.log of an already-saturated sigmoid
# yields log(0) = -inf (and NaN gradients); the logits-based cross entropy
# evaluates the identical quantity in a numerically stable way.
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logit_real, labels=tf.ones_like(d_logit_real)))
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logit_fake, labels=tf.zeros_like(d_logit_fake)))
d_loss = d_loss_real + d_loss_fake
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logit_fake, labels=tf.ones_like(d_logit_fake)))

# Optimizers: each step only updates the parameters of its own network.
# (The names d_slover / g_slover are kept because the training loop uses them.)
d_slover = tf.train.AdamOptimizer().minimize(d_loss, var_list=discriminator_dict)
g_slover = tf.train.AdamOptimizer().minimize(g_loss, var_list=generator_dict)
def sample_z(m, n):
    """Draw the generator's noise input from a uniform prior.

    Args:
        m: Number of noise vectors (rows).
        n: Width of each noise vector (columns).

    Returns:
        An ``m`` x ``n`` NumPy array of samples from ``np.random.uniform``
        with lower bound -1 and upper bound 1.
    """
    noise_shape = [m, n]
    return np.random.uniform(low=-1., high=1., size=noise_shape)
def plot(samples):
    """Draw a batch of flattened 28x28 images on a square grid.

    Args:
        samples: Iterable of arrays, each reshapeable to (28, 28).

    Returns:
        The matplotlib figure holding one axes per sample.
    """
    samples = list(samples)
    # Up to 16 samples use the 4x4 layout; larger batches get a bigger square
    # grid instead of failing with an IndexError on gs[i].
    side = max(4, int(np.ceil(np.sqrt(len(samples)))))

    fig = plt.figure(figsize=(side, side))
    gs = gridspec.GridSpec(side, side)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        # Hide the axis frame and tick labels so only the image is visible.
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig
# Training procedure
sess.run(tf.global_variables_initializer())

# plt.savefig does not create missing directories, so make sure the snapshot
# folder exists before the first image is written.
if not os.path.isdir('out'):
    os.makedirs('out')

i = 0  # running index used to number the saved snapshot images
for it in range(500000):
    # Every 10000 iterations, save a 16-sample snapshot of the generator.
    if it % 10000 == 0:
        samples = sess.run(g_sample, feed_dict={z: sample_z(16, z_dim)})
        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    # One discriminator step on a real batch plus fresh noise, followed by one
    # generator step on another fresh noise batch.
    x_md, _ = mnist.train.next_batch(batchs)
    _, d_loss_curr = sess.run(
        [d_slover, d_loss],
        feed_dict={x: x_md, z: sample_z(batchs, z_dim)})
    _, g_loss_curr = sess.run(
        [g_slover, g_loss],
        feed_dict={z: sample_z(batchs, z_dim)})

    # Report the most recent losses at the same cadence as the snapshots.
    if it % 10000 == 0:
        print('iter:{}'.format(it))
        print('d loss : {:.4}'.format(d_loss_curr))
        print('g loss : {:.4}'.format(g_loss_curr))
        print()

# Test: generate five samples after training and display the second one.
sampl = sess.run(g_sample, feed_dict={z: sample_z(5, z_dim)})
I = np.reshape(sampl[1], (28, 28))
plt.imshow(I)
第100000次迭代结果: