import tensorflow as tf
import numpy as np
class parameters():
    """Hyper-parameter container for the VAE model.

    Bundles the encoder/decoder MLP layer widths, the flattened MNIST
    image size, and the basic training settings used by the scripts.
    """

    def __init__(self):
        # Encoder and decoder hidden-layer widths (two layers each).
        self.encoder1 = self.encoder2 = 500
        self.decoder1 = self.decoder2 = 500
        # Flattened 28x28 MNIST image.
        self.image_size = 784
        # Training settings.
        self.batch_size = 40
        self.max_step = 30000
        # Dimensionality of the latent code z.
        self.z_nodes = 20
def xavier_init(fan_in, fan_out, constant=1):
    """Xavier (Glorot) uniform weight initialization.

    Returns a (fan_in, fan_out) float32 tensor drawn uniformly from
    [-limit, limit] where limit = constant * sqrt(6 / (fan_in + fan_out)).
    """
    limit = constant * np.sqrt(6.0 / (fan_in + fan_out))
    return tf.random_uniform(
        (fan_in, fan_out), minval=-limit, maxval=limit, dtype=tf.float32)
def encoder(parameters, x_ph, reuse, name):
    """Encoder: maps a batch of images to the latent Gaussian parameters.

    Returns (z_mean, z_log_sigma_sq) — the mean and log-variance of the
    approximate posterior q(z|x), each of shape (batch, z_nodes).
    """
    with tf.variable_scope(name, reuse=reuse):
        def _dense(inputs, n_in, n_out, idx, activation=None):
            # One fully connected layer; variables are named w<idx>/b<idx>
            # and created in the original order so checkpoints stay
            # compatible with the previous layout.
            w = tf.get_variable('w%d' % idx, initializer=xavier_init(n_in, n_out))
            b = tf.get_variable('b%d' % idx, [n_out], initializer=tf.zeros_initializer())
            pre_activation = tf.matmul(inputs, w) + b
            return activation(pre_activation) if activation is not None else pre_activation

        hidden1 = _dense(x_ph, parameters.image_size, parameters.encoder1, 1, tf.nn.relu)
        hidden2 = _dense(hidden1, parameters.encoder1, parameters.encoder2, 2, tf.nn.relu)
        z_mean = _dense(hidden2, parameters.encoder2, parameters.z_nodes, 3)
        z_log_sigma_sq = _dense(hidden2, parameters.encoder2, parameters.z_nodes, 4)
        return z_mean, z_log_sigma_sq
def generator(parameters, z, reuse, name):
    """Decoder: reconstructs an image batch from latent codes.

    Returns (logits, x_reconst_mean) where x_reconst_mean is
    sigmoid(logits), i.e. the per-pixel Bernoulli mean.
    """
    with tf.variable_scope(name, reuse=reuse):
        def _dense(inputs, n_in, n_out, idx, activation=None):
            # One fully connected layer; variable names w<idx>/b<idx> and
            # creation order match the original so checkpoints still load.
            w = tf.get_variable('w%d' % idx, initializer=xavier_init(n_in, n_out))
            b = tf.get_variable('b%d' % idx, [n_out], initializer=tf.zeros_initializer())
            pre_activation = tf.matmul(inputs, w) + b
            return activation(pre_activation) if activation is not None else pre_activation

        hidden1 = _dense(z, parameters.z_nodes, parameters.decoder1, 1, tf.nn.relu)
        hidden2 = _dense(hidden1, parameters.decoder1, parameters.decoder2, 2, tf.nn.relu)
        logits = _dense(hidden2, parameters.decoder2, parameters.image_size, 3)
        x_reconst_mean = tf.nn.sigmoid(logits)
        return logits, x_reconst_mean
def get_loss(x, logits, z_mean, z_log_sigma_sq):
    """Total VAE loss: per-pixel Bernoulli reconstruction loss plus the
    KL divergence of q(z|x) from the standard normal prior, averaged
    over the batch.
    """
    reconstr_losses = tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=x),
        axis=1)
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over latent dims.
    latent_losses = -0.5 * tf.reduce_sum(
        1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq),
        axis=1)
    return tf.reduce_mean(reconstr_losses + latent_losses, name='total_loss')
import tensorflow as tf
import model_def
from tensorflow.examples.tutorials.mnist import input_data
def main():
    """Build the VAE graph, train it on MNIST, and checkpoint the model."""
    file_path = 'E:\\Vae_test\\mnist_data'
    mnist = input_data.read_data_sets(file_path, one_hot=True)
    parameters = model_def.parameters()
    # Placeholder for a batch of flattened 28x28 images.
    x_ph = tf.placeholder(tf.float32, shape=[parameters.batch_size, parameters.image_size], name='x')
    z_mean, z_log_sigma_sq = model_def.encoder(parameters, x_ph, reuse=False, name='encoder')
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1),
    # keeping the sampling step differentiable w.r.t. mu and sigma.
    eps = tf.random_normal([parameters.batch_size, parameters.z_nodes] ,0 , 1, dtype=tf.float32)
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    z = z_mean + eps * z_sigma
    logits, x_reconstr_mean = model_def.generator(parameters, z, reuse=False, name='generator')
    total_loss = model_def.get_loss(x_ph, logits, z_mean, z_log_sigma_sq)
    learning_rate = 0.001
    # Non-trainable step counter, also used as the checkpoint global_step.
    training_step = tf.Variable(0, trainable=False)
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss, training_step)
    # A second generator with reuse=True, fed from z_ph, so latent codes can
    # be decoded directly (sampling path); its output is unused during training.
    z_ph = tf.placeholder(dtype=tf.float32, shape=[1,parameters.z_nodes])
    z_example, _ = model_def.generator(parameters, z_ph, reuse=True, name='generator')
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver()
        save_path = 'E:\\Vae_test\\model\\VaeModel.ckpt'
        for i in range(parameters.max_step):
            x_batch, _ = mnist.train.next_batch(parameters.batch_size)
            _, loss = sess.run([train_step, total_loss], feed_dict={x_ph:x_batch})
            if i % 1000 == 0:
                print("After %d training steps, the loss is %.9f" % (i, loss))
                # NOTE(review): indentation was lost in the source; the save is
                # assumed to share the 1000-step cadence of the print above
                # (saving on every one of the 30000 steps would be pathological)
                # — confirm against the original file.
                saver.save(sess, save_path, training_step)
if __name__ == '__main__':
    main()
import tensorflow as tf
import numpy as np
import model_def
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
def main():
    """Restore a trained VAE and visualize reconstructions of test images.

    Builds the same graph as training, loads the latest checkpoint, then
    draws a two-row canvas: originals on top, reconstructions below.
    """
    file_path = 'E:\\Vae_test\\mnist_data'
    mnist = input_data.read_data_sets(file_path, one_hot=True)
    parameters = model_def.parameters()
    x_ph = tf.placeholder(tf.float32, shape=[parameters.batch_size, parameters.image_size], name='x')
    z_mean, z_log_sigma_sq = model_def.encoder(parameters, x_ph, reuse=False, name='encoder')
    # Reparameterization trick, matching the training graph.
    eps = tf.random_normal([parameters.batch_size, parameters.z_nodes], 0, 1, dtype=tf.float32)
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    z = z_mean + eps * z_sigma
    logits, x_reconstr_mean = model_def.generator(parameters, z, reuse=False, name='generator')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state('E:\\Vae_test\\model')
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Checkpoint paths look like ...VaeModel.ckpt-<step>.
        steps = ckpt.model_checkpoint_path.split('-')[1]
        print('Using ' + ckpt.model_checkpoint_path + ' with ' + steps)
        # Use the configured batch size instead of a duplicated literal 40
        # so the script stays consistent with the placeholder shape.
        test_data, _ = mnist.test.next_batch(batch_size=parameters.batch_size)
        # Bug fix: visualize the sigmoid reconstruction mean, not the raw
        # logits — logits are unbounded and render incorrectly as grayscale.
        x_mean = sess.run(x_reconstr_mean, feed_dict={x_ph:test_data})
        row = 2
        line = parameters.batch_size
        canvas = np.empty((28*row, 28*line))
        for i in range(row):
            for j in range(line):
                if i == 0:
                    # Top row: original test images.
                    canvas[i*28:(i+1)*28, j*28:(j+1)*28] = test_data[j].reshape(28,28)
                else:
                    # Bottom row: their reconstructions.
                    canvas[i*28:(i+1)*28, j*28:(j+1)*28] = x_mean[j].reshape(28,28)
        plt.figure(figsize=(100, 200))
        plt.imshow(canvas, cmap='Greys')
        plt.show()
if __name__ == '__main__':
    main()
import tensorflow as tf
import numpy as np
import model_def
import matplotlib.pyplot as plt
# Sample random latent codes and decode them with the trained generator,
# then display the resulting images on a single-row canvas.
parameters = model_def.parameters()
# Variable batch dimension so any number of latent codes can be decoded.
z = tf.placeholder(dtype=tf.float32, shape=[None, parameters.z_nodes], name='z')
logits, x_reconstr_mean = model_def.generator(parameters, z, reuse=False, name='generator')
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('E:\\Vae_test\\model')
    saver.restore(sess, ckpt.model_checkpoint_path)
    # Checkpoint paths look like ...\VaeModel.ckpt-<step>.
    model_name = ckpt.model_checkpoint_path.split('\\')[-1].split('.')[0]
    steps = ckpt.model_checkpoint_path.split('-')[1]
    print('Using '+ model_name + ' with ' + steps + ' training steps')
    n_samples = 20
    # Bug fix: the latent dimension comes from the config (z_nodes) rather
    # than a hard-coded 20, so the script keeps working if z_nodes changes.
    test_data = np.random.randn(n_samples, parameters.z_nodes)
    # Bug fix: decode to the sigmoid mean, not raw logits, for display —
    # logits are unbounded and render incorrectly as grayscale pixels.
    x_mean = sess.run(x_reconstr_mean, feed_dict={z:test_data})
    row = 1
    line = n_samples
    canvas = np.empty((28*row, 28*line))
    for i in range(row):
        for j in range(line):
            canvas[i*28:(i+1)*28, j*28:(j+1)*28] = x_mean[j].reshape(28,28)
    plt.figure(figsize=(100, 200))
    plt.imshow(canvas, cmap='Greys')
    plt.show()
# Test the reconstruction quality on original MNIST images.
# Test the images generated (decoded) from random latent noise.