基于MNIST数据集实现WGAN


前置知识：scipy.optimize.minimize 用法讲解

# coding=utf-8
from scipy.optimize import minimize
import numpy as np
 
# Demo 1: minimize f(x) = a/x + x with scipy.optimize.minimize.
# For a = 1 the minimum on x > 0 is at x = 1 with f(1) = 2. On x < 0 the
# function is unbounded below as x -> 0-, so the demo relies on starting
# from a positive initial guess.
def fun(args):
    """Build the objective ``x -> a / x[0] + x[0]`` for ``minimize``.

    Args:
        args: the scalar coefficient ``a``.

    Returns:
        A callable taking an array-like ``x`` (only ``x[0]`` is used) and
        returning ``a / x[0] + x[0]``.
    """
    a = args
    return lambda x: a / x[0] + x[0]


if __name__ == "__main__":
    args = 1  # the coefficient a
    # minimize works on 1-D float vectors; start on the positive side.
    x0 = np.array([2.0])
    res = minimize(fun(args), x0, method='SLSQP')
    print(res.fun)
    print(res.success)
    print(res.x)

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
from tensorflow.examples.tutorials.mnist import input_data
# Read MNIST through read_data_sets using the 'MNIST_data' directory,
# with labels one-hot encoded.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Graph inputs: flattened 28x28 real images and 100-dimensional generator
# noise. The leading None dimension lets any batch size be fed.
real_img = tf.placeholder(tf.float32, shape=(None, 784), name='real_img')
noise_img = tf.placeholder(tf.float32, shape=(None, 100), name='noise_img')
def generator(noise_img, hidden_units, out_dim, reuse=False):
    """Map noise vectors to images through one ReLU hidden layer.

    Args:
        noise_img: batch of noise vectors.
        hidden_units: width of the hidden dense layer.
        out_dim: size of the output layer (784 for flattened MNIST in this file).
        reuse: if True, reuse the variables of the existing "generator" scope.

    Returns:
        ``(logits, outputs)`` where ``outputs = sigmoid(logits)`` lies in (0, 1).
    """
    with tf.variable_scope("generator", reuse=reuse):
        hidden = tf.layers.dense(noise_img, hidden_units, activation=tf.nn.relu)
        logits = tf.layers.dense(hidden, out_dim)
        return logits, tf.nn.sigmoid(logits)
def discriminator(img, hidden_units, reuse=False):
    """Score images with a one-hidden-layer WGAN critic.

    The output layer has no activation: the critic returns an unbounded
    real-valued score rather than a probability.

    Args:
        img: batch of flattened images (real or generated).
        hidden_units: width of the hidden dense layer.
        reuse: if True, reuse the variables of the existing "discriminator"
            scope, so real and generated images share the same weights.

    Returns:
        Tensor of critic scores with a last dimension of size 1.
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        features = tf.layers.dense(img, hidden_units, activation=tf.nn.relu)
        scores = tf.layers.dense(features, 1)
    return scores
def plot_images(samples):
    """Draw generated samples side by side in one row of 25 panels.

    Args:
        samples: sequence of flattened 28x28 images (each reshaped to 28x28).
            Only the first 25 are drawn.
    """
    fig, axes = plt.subplots(nrows=1, ncols=25, sharex=True, sharey=True, figsize=(50, 2))
    for img, ax in zip(samples, axes):
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)
    # This is called from inside a long training loop. Show the figure now
    # instead of leaving it open until the notebook cell finishes.
    plt.show()
# ---- Hyper-parameters ----
img_size = 784      # flattened 28x28 MNIST image
noise_size = 100    # generator input size (matches the noise_img placeholder)
hidden_units = 128  # hidden-layer width of both networks
learning_rate = 0.0001

# ---- Graph ----
g_logits, g_outputs = generator(noise_img, hidden_units, img_size)

# The critic scores real and generated images with shared weights.
d_real = discriminator(real_img, hidden_units)
d_fake = discriminator(g_outputs, hidden_units, reuse=True)

# Critic objective: make d_real large and d_fake small, i.e. maximize d_loss.
# The optimizer minimizes, so the critic is trained on -d_loss.
d_loss = tf.reduce_mean(d_real) - tf.reduce_mean(d_fake)
# Generator objective: make d_fake large, i.e. minimize -mean(d_fake).
g_loss = -tf.reduce_mean(d_fake)

# Split trainable variables by scope so each optimizer only updates its own
# network: g_vars for the generator, d_vars for the critic.
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

d_train_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(-d_loss, var_list=d_vars)
g_train_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

# Clip critic weights to [-0.01, 0.01] after each critic update.
# - Ordering: clip_d is run together with d_train_opt in one sess.run. Without
#   an explicit dependency their order is undefined, so the RMSProp update
#   could land after the clip and leave the weights unclipped.
# - read_value() re-reads each variable inside the dependency scope, so the
#   value being clipped is the post-update one.
with tf.control_dependencies([d_train_opt]):
    clip_d = [p.assign(tf.clip_by_value(p.read_value(), -0.01, 0.01)) for p in d_vars]

# ---- Training bookkeeping ----
batch_size = 32
n_sample = 25   # images generated per preview (plot_images shows 25 panels)
samples = []
losses = []     # (d_loss, g_loss) pairs recorded during training

# Only the generator's variables are checkpointed.
saver = tf.train.Saver(var_list=g_vars)

import os

n_iters = 1000000   # number of generator updates
n_critic = 5        # critic updates per generator update
log_every = 10000   # iterations between previews, loss logs and checkpoints
ckpt_path = './checkpoints/generator.ckpt'
# Create the checkpoint directory up front so saving into it cannot fail.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for it in range(n_iters):
        # Several critic steps; clip_d clips the critic weights each time.
        for _ in range(n_critic):
            batch_images, _ = mnist.train.next_batch(batch_size)
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
            sess.run([d_train_opt, clip_d],
                     feed_dict={real_img: batch_images, noise_img: batch_noise})

        # One generator step on freshly drawn noise rather than the critic's
        # last batch.
        g_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
        sess.run(g_train_opt, feed_dict={noise_img: g_noise})

        if it % log_every == 0:
            # Sample through the existing generator output. Calling
            # generator(...) here would add a new copy of the network to the
            # graph on every preview.
            sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
            samples = sess.run(g_outputs, feed_dict={noise_img: sample_noise})
            plot_images(samples)

            # Losses on the last critic batch (both from the same forward pass).
            train_loss_d, train_loss_g = sess.run(
                [d_loss, g_loss],
                feed_dict={real_img: batch_images, noise_img: batch_noise})

            print("Iteration {}/{}...".format(it, n_iters),
                  "Discriminator Loss: {:.4f}...".format(train_loss_d),
                  "Generator Loss: {:.4f}".format(train_loss_g))
            losses.append((train_loss_d, train_loss_g))

            # Checkpoint at the logging interval, not on every iteration.
            saver.save(sess, ckpt_path)

    # Keep the final generator weights as well.
    saver.save(sess, ckpt_path)

基于MNIST数据集实现WGAN_第1张图片

你可能感兴趣的:(基于MNIST数据集实现WGAN)