VAE 手写数字（MNIST）重构项目实现

# Vae模型的定义
# 2019.11.6

import tensorflow as tf
import numpy as np

class parameters():
    """Container for the hyper-parameters shared by the VAE model and scripts."""

    def __init__(self):
        # Widths of the two hidden layers in the encoder and in the decoder.
        self.encoder1, self.encoder2 = 500, 500
        self.decoder1, self.decoder2 = 500, 500
        # Flattened MNIST image: 28 * 28 = 784 pixels.
        self.image_size = 28 * 28
        # Training schedule.
        self.batch_size = 40      # examples per mini-batch
        self.max_step = 30000     # total number of optimisation steps
        # Dimensionality of the latent variable z.
        self.z_nodes = 20


def xavier_init(fan_in, fan_out, constant=1):
    """Return a (fan_in, fan_out) float32 tensor of Xavier/Glorot-uniform samples.

    Values are drawn uniformly from [-limit, limit] with
    limit = constant * sqrt(6 / (fan_in + fan_out)).
    See https://stackoverflow.com/questions/33640581/how-to-do-xavier-initialization-on-tensorflow
    """
    limit = constant * np.sqrt(6.0 / (fan_in + fan_out))
    shape = (fan_in, fan_out)
    return tf.random_uniform(shape, minval=-limit, maxval=limit, dtype=tf.float32)


def encoder(parameters, x_ph, reuse, name):
    """Encode inputs into the mean and log-variance of the latent distribution.

    Args:
        parameters: a ``parameters`` instance supplying the layer widths.
        x_ph: input tensor whose last dimension is ``parameters.image_size``.
        reuse: passed to ``tf.variable_scope`` (True to share existing weights).
        name: name of the variable scope holding w1..w4 / b1..b4.

    Returns:
        (z_mean, z_log_sigma_sq), each with last dimension ``parameters.z_nodes``.
    """

    def affine(inputs, w_name, b_name, fan_in, fan_out):
        # Create (or reuse) a Xavier-initialised weight and a zero bias in the
        # current scope and return inputs @ w + b.
        weight = tf.get_variable(w_name, initializer=xavier_init(fan_in, fan_out))
        bias = tf.get_variable(b_name, [fan_out], initializer=tf.zeros_initializer())
        return tf.matmul(inputs, weight) + bias

    with tf.variable_scope(name, reuse=reuse):
        hidden1 = tf.nn.relu(affine(x_ph, 'w1', 'b1', parameters.image_size, parameters.encoder1))
        hidden2 = tf.nn.relu(affine(hidden1, 'w2', 'b2', parameters.encoder1, parameters.encoder2))
        # Two linear heads on the shared hidden representation.
        z_mean = affine(hidden2, 'w3', 'b3', parameters.encoder2, parameters.z_nodes)
        z_log_sigma_sq = affine(hidden2, 'w4', 'b4', parameters.encoder2, parameters.z_nodes)

    return z_mean, z_log_sigma_sq


def generator(parameters, z, reuse, name):
    """Decode latent codes back into image space.

    Args:
        parameters: a ``parameters`` instance supplying the layer widths.
        z: latent tensor whose last dimension is ``parameters.z_nodes``.
        reuse: passed to ``tf.variable_scope`` (True to share existing weights).
        name: name of the variable scope holding w1..w3 / b1..b3.

    Returns:
        (logits, x_reconst_mean): pre-sigmoid outputs and their sigmoid,
        each with last dimension ``parameters.image_size``.
    """

    def affine(inputs, w_name, b_name, fan_in, fan_out):
        # Create (or reuse) a Xavier-initialised weight and a zero bias in the
        # current scope and return inputs @ w + b.
        weight = tf.get_variable(w_name, initializer=xavier_init(fan_in, fan_out))
        bias = tf.get_variable(b_name, [fan_out], initializer=tf.zeros_initializer())
        return tf.matmul(inputs, weight) + bias

    with tf.variable_scope(name, reuse=reuse):
        hidden1 = tf.nn.relu(affine(z, 'w1', 'b1', parameters.z_nodes, parameters.decoder1))
        hidden2 = tf.nn.relu(affine(hidden1, 'w2', 'b2', parameters.decoder1, parameters.decoder2))
        logits = affine(hidden2, 'w3', 'b3', parameters.decoder2, parameters.image_size)
        # Keep the sigmoid inside the scope so its op name stays under `name`.
        x_reconst_mean = tf.nn.sigmoid(logits)
        return logits, x_reconst_mean

def get_loss(x, logits, z_mean, z_log_sigma_sq):
    """Return the batch-averaged VAE loss (reconstruction + KL), named 'total_loss'.

    Per example:
      * reconstruction loss: sigmoid cross-entropy between ``logits`` and the
        targets ``x``, summed over axis 1;
      * latent loss: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)) over
        axis 1, the closed-form KL divergence from N(mean, exp(log_var)) to N(0, 1).
    The two are added and averaged over the batch.
    """
    pixel_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=x)
    reconstr_losses = tf.reduce_sum(pixel_xent, 1)

    kl_terms = 1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq)
    latent_losses = -0.5 * tf.reduce_sum(kl_terms, 1)

    per_example = reconstr_losses + latent_losses
    return tf.reduce_mean(per_example, name='total_loss')
# Vae模型的主函数main
# 2019.11.7

import tensorflow as tf
import model_def
from tensorflow.examples.tutorials.mnist import input_data

def main():
    """Train the VAE on MNIST and checkpoint it.

    Builds encoder -> reparameterised sample -> generator from ``model_def``,
    minimises ``model_def.get_loss`` with Adam for ``parameters.max_step``
    mini-batches, prints the loss and saves a checkpoint every 1000 steps,
    and saves one final checkpoint after the last step.
    """
    import os  # only needed to make sure the checkpoint directory exists

    file_path = 'E:\\Vae_test\\mnist_data'
    save_path = 'E:\\Vae_test\\model\\VaeModel.ckpt'
    mnist = input_data.read_data_sets(file_path, one_hot=True)

    parameters = model_def.parameters()
    x_ph = tf.placeholder(tf.float32, shape=[parameters.batch_size, parameters.image_size], name='x')

    # Encoder: mean and log-variance of the latent Gaussian.
    z_mean, z_log_sigma_sq = model_def.encoder(parameters, x_ph, reuse=False, name='encoder')

    # Reparameterisation trick: z = mean + eps * sigma with eps ~ N(0, 1), so the
    # sample remains differentiable with respect to the encoder outputs.
    eps = tf.random_normal([parameters.batch_size, parameters.z_nodes], 0, 1, dtype=tf.float32)
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    z = z_mean + eps * z_sigma

    # Decoder: map z back to pixel logits.
    logits, x_reconstr_mean = model_def.generator(parameters, z, reuse=False, name='generator')

    total_loss = model_def.get_loss(x_ph, logits, z_mean, z_log_sigma_sq)

    learning_rate = 0.001

    # Step counter: minimize() increments it and saver.save() appends its value
    # to the checkpoint file name.
    training_step = tf.Variable(0, trainable=False)

    train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss, training_step)

    with tf.Session() as sess:

        tf.global_variables_initializer().run()
        saver = tf.train.Saver()

        # tf.train.Saver.save refuses to write into a non-existent directory.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        for i in range(parameters.max_step):
            x_batch, _ = mnist.train.next_batch(parameters.batch_size)
            _, loss = sess.run([train_step, total_loss], feed_dict={x_ph: x_batch})

            if i % 1000 == 0:
                print("After %d training steps, the loss is %.9f" % (i, loss))
                saver.save(sess, save_path, training_step)

        # In-loop saves only happen at multiples of 1000, so the steps after the
        # last one would otherwise be lost; persist the final weights explicitly.
        saver.save(sess, save_path, training_step)


if __name__ == '__main__':
    main()
# 用于测试对原始图像进行重构
# 2019.11.8

import tensorflow as tf
import numpy as np
import model_def
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

def main():
    """Reconstruct one batch of MNIST test images with the latest checkpoint.

    Restores the newest checkpoint from the model directory, passes
    ``parameters.batch_size`` test images through encoder -> sample ->
    generator, and plots a two-row canvas: originals on top, reconstructions
    (sigmoid means in [0, 1]) underneath.

    Raises:
        FileNotFoundError: if no checkpoint exists in the model directory.
    """

    file_path = 'E:\\Vae_test\\mnist_data'
    model_dir = 'E:\\Vae_test\\model'
    mnist = input_data.read_data_sets(file_path, one_hot=True)

    parameters = model_def.parameters()
    x_ph = tf.placeholder(tf.float32, shape=[parameters.batch_size, parameters.image_size], name='x')

    z_mean, z_log_sigma_sq = model_def.encoder(parameters, x_ph, reuse=False, name='encoder')

    # Sample z with the same reparameterisation used during training.
    eps = tf.random_normal([parameters.batch_size, parameters.z_nodes], 0, 1, dtype=tf.float32)
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    z = z_mean + eps * z_sigma

    logits, x_reconstr_mean = model_def.generator(parameters, z, reuse=False, name='generator')

    saver = tf.train.Saver()

    with tf.Session() as sess:

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError('No checkpoint found in ' + model_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)

        # Checkpoints are written as '<prefix>-<global_step>'; take the text after
        # the LAST '-' so a '-' elsewhere in the path cannot break the parse.
        steps = ckpt.model_checkpoint_path.rsplit('-', 1)[-1]
        print('Using ' + ckpt.model_checkpoint_path + ' with ' + steps)

        # x_ph has a fixed batch dimension, so feed exactly batch_size images.
        test_data, _ = mnist.test.next_batch(batch_size=parameters.batch_size)
        # Display sigmoid(logits) -- the reconstructed pixel means -- not raw logits.
        x_mean = sess.run(x_reconstr_mean, feed_dict={x_ph: test_data})

        row = 2
        line = parameters.batch_size
        side = 28  # MNIST images are 28 x 28

        canvas = np.empty((side * row, side * line))
        for j in range(line):
            cols = slice(j * side, (j + 1) * side)
            canvas[0:side, cols] = test_data[j].reshape(side, side)        # original
            canvas[side:2 * side, cols] = x_mean[j].reshape(side, side)    # reconstruction

        # About half an inch per digit; a 100x200-inch figure needs a huge buffer.
        plt.figure(figsize=(0.5 * line, 0.5 * row))
        # Both rows lie in [0, 1]; pin the colour range so they share one scale.
        plt.imshow(canvas, cmap='Greys', vmin=0, vmax=1)
        plt.show()

if __name__ == '__main__':
    main()
# 用于对随机产生的噪声进行重构图像
# 2019.11.8

import tensorflow as tf
import numpy as np
import model_def
import matplotlib.pyplot as plt


parameters = model_def.parameters()

# Only the decoder is built here: z is fed directly instead of being encoded.
z = tf.placeholder(dtype=tf.float32, shape=[None, parameters.z_nodes], name='z')

logits, x_reconstr_mean = model_def.generator(parameters, z, reuse=False, name='generator')

# With no var_list, the Saver covers the variables in this graph, which here
# are only the generator's.
saver = tf.train.Saver()

model_dir = 'E:\\Vae_test\\model'
n_samples = 20

with tf.Session() as sess:

    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in ' + model_dir)
    saver.restore(sess, ckpt.model_checkpoint_path)

    model_name = ckpt.model_checkpoint_path.split('\\')[-1].split('.')[0]
    # Checkpoints are written as '<prefix>-<global_step>'; parse after the last '-'.
    steps = ckpt.model_checkpoint_path.rsplit('-', 1)[-1]
    print('Using '+ model_name + ' with ' + steps + ' training steps')

    # Draw latent codes from the N(0, I) prior; width must match the generator input.
    test_data = np.random.randn(n_samples, parameters.z_nodes)
    # Display sigmoid(logits) -- the generated pixel means -- not raw logits.
    x_mean = sess.run(x_reconstr_mean, feed_dict={z: test_data})

    row = 1
    line = n_samples
    canvas = np.empty((28*row, 28*line))

    for i in range(row):
        for j in range(line):
            canvas[i*28:(i+1)*28, j*28:(j+1)*28] = x_mean[i * line + j].reshape(28, 28)

    # About half an inch per digit; a 100x200-inch figure needs a huge buffer.
    plt.figure(figsize=(0.5 * line, 0.5 * row))
    plt.imshow(canvas, cmap='Greys', vmin=0, vmax=1)
    plt.show()

测试原始图像的重构效果（第一行为原始测试图像，第二行为对应的重构图像）：
[图片：原始测试图像与重构图像对比]
测试随机噪声的重构结果（将从标准正态分布采样的随机向量 z 输入译码网络所生成的图像）：
[图片：由随机噪声生成的图像]

相关标签：TensorFlow 1.x 学习笔记