Tensorflow03:搭建Auto-encoder和decoder

【网络设计】
采用全连接网络:
3层编码,784->256->128
3层解码,128->256->784

输入:mnist手写图片
输出:由网络还原出来的图片
目标:还原度越高越好

因此我们可以总结出,最简单的Auto-encoder和decoder其实就是特殊结构的全连接神经网络

【代码展示】

#定义数据
mnist = input_data.read_data_sets('./mnist', one_hot=True)
n_input=784
n_hidden_1=256
n_hidden_2=128

#定义批个数和学习速率,这些决定了学习成果
batch_size=100
lr=0.001
training_epoches=200
display_epoches=10

total_batch=mnist.count()/batch_size
#输入,一个batch的图片
tf_x=tf.placeholder(tf.float32,shape=[None,28*28])
examples_to_show=7

#定义网络参数
weights={
    'encoder_w1':tf.Variable(tf.random_normal([n_input,n_hidden_1])),
    'encoder_w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])),

    'decoder_w1':tf.Variable(tf.random_normal([n_hidden_2,n_hidden_1])),
    'decoder_w2': tf.Variable(tf.random_normal([n_hidden_1,n_input]))

}
biases={
    'encoder_b1':tf.Variable(tf.random_normal([n_hidden_1])),
    'encoder_b2':tf.Variable(tf.random_normal([n_hidden_2])),

    'decoder_b1':tf.Variable(tf.random_normal([n_hidden_2])),
    'decoder_b2': tf.Variable(tf.random_normal([n_hidden_1,n_input]))
}

#定义网络的运算和连接方式
def encoder(x):
    layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['encoder_w1']),biases['encoder_b1']))
    layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['encoder_w2']),biases['encoder_b2']))
    return layer_2

def decoder(x):
    layer_1=tf.nn.sigmoid(tf.add(tf.matmul(x,weights['decoder_w1']),biases['decoder_b1']))
    layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['decoder_w2']),biases['decoder_b2']))
    return layer_2

encoder_op=encoder(tf_x)
decoder_op=decoder(encoder_op)

y_pred=decoder_op
y_true=tf_x

#定义学习方式
cost=tf.reduce_mean(tf.pow(y_true-y_pred,2))
optimizer=tf.train.AdamOptimizer(lr).minimize(cost)

init=tf.initialize_all_variables()

#训练
with tf.Session()as sess:
    sess.run(init)
    total_batch
    for i in range(training_epoches):
        for j in range(total_batch):
            batch_x, batch_y = mnist.train.nextbatch(batch_size)
            _,c=sess.run([cost,optimizer],feed_dict={tf_x:batch_x})
            if(j%display_epoches==0):
                print("Epoch:%04d"%(j+1),"cost=","{:,%.9f}".format(c))
    print("Optimize Finished!")
    encode_decode=sess.run(y_pred,feed_dict={tf_x:mnist.test.images[:examples_to_show]})
    f,a=plt.subplots(2,10,figsize=(10,2))
    for i in range(examples_to_show):
        a[0][i].imshow(np.reshape(mnist.test.images[i],(28,28)))
        a[1][i].imshow(np.reshape(encode_decode[i],(28,28)))
    plt.show()

【注意】
1、采用AdamOptimizer,效果最好
2、解码和编码网络架构是对称的
3、learningRate(lr)是个很重要的参数

你可能感兴趣的:(Tensorflow03:搭建Auto-encoder和decoder)