# MNIST unsupervised learning: autoencoder

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt

# Load MNIST with one-hot labels (labels are unused by the autoencoder itself).
mnist = input_data.read_data_sets("mnist", one_hot=True)

# Training hyperparameters.
# NOTE(review): "learninig_rate" is a typo for "learning_rate"; the name is
# kept because later code references it verbatim.
learninig_rate = 0.1
training_epochs = 20
batch_size = 100
display_step = 1          # print the loss every `display_step` epochs

# Network dimensions.
n_input = 784             # 28x28 MNIST image, flattened
n_hidden_1 = 256          # first encoder layer width
n_hidden_2 = 128          # bottleneck (code) width

# How many test images to visualize after training.
examples_to_show = 10

# Input placeholder: a batch of flattened images.
X = tf.placeholder("float", [None, n_input])

# Encoder/decoder weight matrices, randomly initialized.
# The decoder mirrors the encoder's shapes in reverse.
weights = {
    "encoder_h1": tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    "encoder_h2": tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    "decoder_h1": tf.Variable(tf.random_normal([n_hidden_2, n_hidden_1])),
    "decoder_h2": tf.Variable(tf.random_normal([n_hidden_1, n_input])),
}

# Per-layer bias vectors.
biases = {
    "encoder_b1": tf.Variable(tf.random_normal([n_hidden_1])),
    "encoder_b2": tf.Variable(tf.random_normal([n_hidden_2])),
    "decoder_b1": tf.Variable(tf.random_normal([n_hidden_1])),
    "decoder_b2": tf.Variable(tf.random_normal([n_input])),
}

def encoder(x):
    """Map a batch of inputs (shape [None, n_input]) to the bottleneck code.

    Two sigmoid-activated affine layers: n_input -> n_hidden_1 -> n_hidden_2.

    Bug fix: the original computed sigmoid(matmul(x, W + b)) — the bias was
    broadcast-added to the weight matrix *before* the product, which yields
    x@W + sum(x)*b instead of the intended affine transform x@W + b.
    """
    layer_1 = tf.nn.sigmoid(
        tf.add(tf.matmul(x, weights["encoder_h1"]), biases["encoder_b1"]))
    layer_2 = tf.nn.sigmoid(
        tf.add(tf.matmul(layer_1, weights["encoder_h2"]), biases["encoder_b2"]))
    return layer_2

def decoder(x):
    """Reconstruct a batch of images from bottleneck codes (shape [None, n_hidden_2]).

    Two sigmoid-activated affine layers: n_hidden_2 -> n_hidden_1 -> n_input.

    Bug fix: same as encoder() — the original added the bias to the weight
    matrix inside tf.matmul; the correct affine layer is x@W + b.
    """
    layer_1 = tf.nn.sigmoid(
        tf.add(tf.matmul(x, weights["decoder_h1"]), biases["decoder_b1"]))
    layer_2 = tf.nn.sigmoid(
        tf.add(tf.matmul(layer_1, weights["decoder_h2"]), biases["decoder_b2"]))
    return layer_2

# Wire the graph: compress the input, then reconstruct it.
encoder_op = encoder(X)
decoder_op = decoder(encoder_op)

# Unsupervised target: the reconstruction is compared against the raw input.
y_pred = decoder_op
y_true = X

# Mean squared reconstruction error, minimized with plain SGD.
cost = tf.reduce_mean(tf.square(y_true - y_pred))
optimizer = tf.train.GradientDescentOptimizer(learninig_rate).minimize(cost)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    total_batch = int(mnist.train.num_examples // batch_size)
    for epoch in range(training_epochs):
        # One full pass over the training set; labels are ignored.
        for _ in range(total_batch):
            batch_xs, _ = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
        if epoch % display_step == 0:
            # `c` is the loss of the last mini-batch of this epoch.
            print("Epoch:", "%04d" % (epoch + 1), "cost=", "{:.9f}".format(c))

    print("Optimizer Finished!!!")

    # Run the trained autoencoder on a few held-out test images.
    encode_decode = sess.run(
        y_pred, feed_dict={X: mnist.test.images[:examples_to_show]})

    # Top row: original images; bottom row: their reconstructions.
    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(examples_to_show):
        a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
        a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
    f.show()
    plt.draw()

# (Blog footer) You may also be interested in: general deep-learning code