这是自己的第一篇博客,写的不好请见谅,就是想把自己这段时间所学的相关tensorflow的相关东西做个总结,帮助自己记忆,也希望能对别人有点帮助。
通过Tensorflow做的关于Mnist的数字识别,具体的完整流程请见极客网的流程。我只想记录一下自己的实现的CNN部分。
这一部分没什么要说的,导入Tensorflow
#coding=utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
用于对变量赋初值,为了避免0梯度问题,我们在设置偏置值时使用不为0的小常数,这里取0.1。在tensorflow中的constant用于指定常量值,0.1是常数的具体取值,shape是数据格式。当value = 1,shape=[1,2]时,[1,1]为初值结果。truncated_normal决定了产生的数据的随机模型
def weight_variable(shape):
initial = tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1,shape=shape)
# 0.1是常数的值,shape是数据格式
return tf.Variable(initial)
def conv2d(x,W):
#conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,data_format=None, name=None)
return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding="SAME") #由于采用的是SAME,所以用5x5卷积以后还是28X28
def max_pool_2x2(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
x = tf.placeholder("float",[None,784])
y_ = tf.placeholder("float",[None,10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
W_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32])
x_img = tf.reshape(x,[-1,28,28,1])
h_conv1 = tf.nn.relu(conv2d(x_img,W_conv1)+b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
W_conv2 = weight_variable([5,5,32,64]) #理解此处为什么是32
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1,W_conv2)+b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
W_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1) + b_fc1)
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob)#?
W_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2)
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())
for i in range(20000):
batch = mnist.train.next_batch(50)
if i % 100 == 0:
train_accuracy = accuracy.eval(feed_dict={
x:batch[0], y_:batch[1], keep_prob:1.0})
print ("step %d, training accuracy %g" % (i, train_accuracy))
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob:0.5})
print ("test accuracy %g" % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob:1.0}))