GAN的简单实现--MNIST数据集+tensorflow

生成对抗网络(GAN)及其变种已经成为最近十年以来机器学习领域最为重要的思想。--2018图灵奖得主 Yann LeCun

GAN的基础知识复习:click here

GAN模型的挑战即训练优化:click here

1、模型简介及代码

本程序主要是采用最初GAN的基本原理,选择简单的二层神经网络以及MNIST数据集并基于tensorflow平台来实现,以求得对GAN的原理以及实现过程有一个更深入的理解。

程序框架及注释:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


sess = tf.InteractiveSession()
z_dim = 100
batchs= 128
mnist = input_data.read_data_sets("/home/zhaocq/桌面/tensorflow/mnist/raw/",one_hot=True)



def weight_variable(shape,name):
    initial = tf.random_normal(shape, stddev=0.01)
    return tf.Variable(initial,name = name)
def bias_variable(shape,name):
    initial = tf.zeros(shape)  #给偏置增加小的正值用来避免死亡节点;
    return tf.Variable(initial,name = name)
#生成器随机噪声100维
z = tf.placeholder(tf.float32,shape=[None,100],name = 'z')
#鉴别器准备MNIST图像输入设置
x = tf.placeholder(tf.float32,shape=[None,784],name = 'x')
#生成器参数定义
g_w1 = weight_variable([100,128],'g_w1')
g_b1 = bias_variable([128],'g_b1')
g_w2 = weight_variable([128,784],'g_w2')
g_b2 = bias_variable([784],'g_b2')
generator_dict = [g_w1,g_b1,g_w2,g_b2]
#鉴别器参数定义
d_w1 = weight_variable([784,128],'d_w1')
d_b1 = bias_variable([128],'d_b1')
d_w2 = weight_variable([128,1],'d_w2')
d_b2 = bias_variable([1],'d_b2')
discriminator_dict = [d_w1,d_b1,d_w2,d_b2]



#生成器网络定义
def generator(z):
    g_h1 = tf.nn.relu(tf.matmul(z,g_w1) + g_b1)
    g_h2 = tf.nn.sigmoid(tf.matmul(g_h1,g_w2) + g_b2)
    return g_h2
#定义鉴别器
def discrimnator(x):
    d_h1 = tf.nn.relu(tf.matmul(x,d_w1)+d_b1)
    d_logit = tf.matmul(d_h1,d_w2)+d_b2
    d_prob = tf.nn.sigmoid(d_logit)
    return d_prob,d_logit


g_sample = generator(z)
d_real,d_logit_real = discrimnator(x)
d_fake,d_logit_fake = discrimnator(g_sample)
#定义损失
d_loss = - tf.reduce_mean(tf.log(d_real) + tf.log(1.- d_fake))
g_loss = - tf.reduce_mean(tf.log(d_fake))
#定义优化器,仅优化相关参数
d_slover = tf.train.AdamOptimizer().minimize(d_loss,var_list = discriminator_dict)
g_slover = tf.train.AdamOptimizer().minimize(g_loss,var_list = generator_dict)



def sample_z(m,n):
    '''Uniform prior for G(z)'''
    return np.random.uniform(-1.,1.,size=[m,n])
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

#训练过程
sess.run(tf.global_variables_initializer())
i=0
for it in range(500000):
    #输出image
    if it % 10000 == 0:
        samples = sess.run(g_sample,feed_dict={z: sample_z(16,100)})
        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)),bbox_inches = 'tight')
        i += 1
        plt.close(fig)
    x_md,_ =mnist.train.next_batch(batchs)
    _,d_loss_curr = sess.run([d_slover,d_loss],feed_dict={x: x_md, z: sample_z(batchs,z_dim)})
    _,g_loss_curr = sess.run([g_slover,g_loss],feed_dict={z: sample_z(batchs,z_dim)})
    
    if it % 10000 ==0:
        print('iter:{}'.format(it))
        print('d loss : {:.4}'.format(d_loss_curr))
        print('g loss : {:.4}'.format(g_loss_curr))
        print()


#测试test
sampl = sess.run(g_sample,feed_dict={z: sample_z(5,100)})
I=np.reshape(sampl[1],(28,28))
#plt.imshow(np.reshape(sampl[1],(28,28)))
plt.imshow(I)

第100000次迭代结果:

GAN的简单实现--MNIST数据集+tensorflow_第1张图片

你可能感兴趣的:(深度学习,tensorflow)