Python实现自编码器（autoencoder）

# -*- coding: utf-8 -*-
"""
Created on Sun Sep  3 13:48:19 2017

@author: piaodexin
"""
from __future__ import division, print_function, absolute_import
import tensorflow as tf
from  tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np
# Load the MNIST data set from a local directory (labels one-hot encoded).
mnist = input_data.read_data_sets('E:\\mnist', one_hot=True)

# Layer sizes of the network. An image is 28 x 28 = 784 pixels:
#   input layer     784 units
#   hidden layer 1  500 units
#   hidden layer 2  100 units
#   hidden layer 3  500 units
#   output layer    784 units
# An autoencoder is expected to learn the features of an image by itself and
# then rebuild the original image from those features, which is why the
# output layer is again 28 x 28 = 784 units wide.
input_n, hidden1_n, hidden2_n, hidden3_n, output_n = 784, 500, 100, 500, 784

# Step size handed to the optimizer.
learn_rate = 0.01
# Number of images drawn per training step.
batch_size = 100
# Number of training steps; the training loop draws one batch per step.
train_epoch = 30000

def _layer_params(n_in, n_out):
    """Create the trainable parameters of one fully connected layer.

    The weight matrix has shape [n_in, n_out] and is initialised from a
    truncated normal distribution with stddev 0.1; the bias vector has
    n_out entries, all starting at 0.1.

    Returns:
        A (weights, bias) pair of tf.Variable objects.
    """
    layer_weights = tf.Variable(tf.truncated_normal([n_in, n_out], stddev=0.1))
    layer_bias = tf.Variable(tf.constant(0.1, shape=[n_out]))
    return layer_weights, layer_bias


# Network input and reconstruction target, both [batch, input_n] float32.
x = tf.placeholder(tf.float32, [None, input_n])
y = tf.placeholder(tf.float32, [None, input_n])

# Parameters of the four layers: input -> hidden1 -> hidden2 -> hidden3 -> output.
weights1, bias1 = _layer_params(input_n, hidden1_n)
weights2, bias2 = _layer_params(hidden1_n, hidden2_n)
weights3, bias3 = _layer_params(hidden2_n, hidden3_n)
weights4, bias4 = _layer_params(hidden3_n, output_n)

def get_result(x, weights1, bias1, weights2, bias2, weights3, bias3, weights4, bias4):
    """Build the forward pass of the four-layer network.

    Each layer computes sigmoid(input @ weights + bias) and feeds its
    activation into the next layer.

    Args:
        x: Tensor fed into the first layer.
        weights1, bias1: Parameters of the first layer.
        weights2, bias2: Parameters of the second layer.
        weights3, bias3: Parameters of the third layer.
        weights4, bias4: Parameters of the output layer.

    Returns:
        The sigmoid activation of the last layer.
    """
    layers = (
        (weights1, bias1),
        (weights2, bias2),
        (weights3, bias3),
        (weights4, bias4),
    )
    activation = x
    for layer_weights, layer_bias in layers:
        pre_activation = tf.matmul(activation, layer_weights) + layer_bias
        activation = tf.nn.sigmoid(pre_activation)
    return activation
# Author's note: computing y_ step by step at module level raised an error
# whose cause was never identified, so the forward pass is built through
# get_result() instead.
y_ = get_result(
    x,
    weights1, bias1,
    weights2, bias2,
    weights3, bias3,
    weights4, bias4,
)

# Mean squared error between the network output and the target.
squared_error = tf.pow(y_ - y, 2)
loss = tf.reduce_mean(squared_error)

# One RMSProp update of all trainable variables per evaluation of train_op.
optimizer = tf.train.RMSPropOptimizer(learn_rate)
train_op = optimizer.minimize(loss)

# Train the autoencoder, then reconstruct the first five test images.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(train_epoch):
        xs, ys = mnist.train.next_batch(batch_size)
        # The images are both the input and the target; the labels in ys
        # are not used.
        feed = {x: xs, y: xs}
        if i % 1000 == 0:
            # Report the loss on the current batch before it is trained on.
            print('epoch:', i)
            print('loss:', sess.run(loss, feed_dict=feed))
        sess.run(train_op, feed_dict=feed)
    xt = mnist.test.images[:5]
    yt = xt
    encode_decode = sess.run(y_, feed_dict={x: xt, y: yt})

# Plot after the session is closed: the reconstructions have already been
# fetched, so the session does not need to stay open while the window is up.
# Result: top row shows the original images, bottom row shows what the
# autoencoder reconstructed from them.
f, a = plt.subplots(2, 5, figsize=(10, 2))
for i in range(5):
    a[0][i].imshow(np.reshape(xt[i], (28, 28)))
    a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
# Figure.show() followed by plt.draw() does not run a GUI event loop, so in
# a plain script the window could vanish immediately (or never appear).
# plt.show() keeps the figure on screen until it is closed.
plt.show()


你可能感兴趣的:(机器学习,Python)