机器学习笔记 tensorflow mnist上实现CNN网络

mnist上的一个普通cnn例子,采用两层卷积和池化层加一层全连接,为了防止过拟合在全连接层用了dropout,是一个十分简单的例子

import tensorflow as tf
import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

x = tf.placeholder("float", [None, 784])
y_ = tf.placeholder("float", [None, 10])

x_image = tf.reshape(x, [-1, 28, 28, 1])
# layer_one
filter_one = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))
filter_one_bias = tf.Variable(tf.zeros([32])+0.1)

filter_one_h = tf.nn.relu(tf.nn.conv2d(x_image, filter_one, strides=[1, 1, 1, 1], padding="SAME")+filter_one_bias)
filter_one_out = tf.nn.max_pool(filter_one_h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
# layer_two

filter_two = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1))
filter_two_bias = tf.Variable(tf.zeros([64])+0.1)

filter_two_h = tf.nn.relu(tf.nn.conv2d(filter_one_out, filter_two, strides=[1, 1, 1, 1], padding="SAME")+filter_two_bias)
filter_two_out = tf.nn.max_pool(filter_two_h, ksize=[1, 2, 2,1], strides=[1, 2, 2, 1], padding="SAME")

# full_connect

full_connect_w = tf.Variable(tf.truncated_normal([7*7*64, 1024], stddev=0.1))
full_connect_bias = tf.Variable(tf.zeros([1024])+0.1)

full_connect_out = tf.nn.relu(tf.matmul(tf.reshape(filter_two_out, [-1, 7*7*64]), full_connect_w)+full_connect_bias)

# drop_out
keep_drop = tf.placeholder("float")
drop_out = tf.nn.dropout(full_connect_out, keep_drop)

# softmax_loss
w = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
b = tf.Variable(tf.zeros([10])+0.1)

loss = tf.nn.softmax(tf.matmul(drop_out, w)+b)

cross_entropy = -tf.reduce_sum(y_ * tf.log(loss)) #计算交叉熵
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) #使用adam优化器来以0.0001的学习率来进行微调
correct_prediction = tf.equal(tf.argmax(loss,1), tf.argmax(y_,1)) #判断预测标签和实际标签是否匹配
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))

sess = tf.Session() #启动创建的模型

sess.run(tf.global_variables_initializer()) #初始化变量

for i in range(5000): #开始训练模型,循环训练5000次
    batch = mnist.train.next_batch(50) #batch大小设置为50
    if i % 100 == 0:
        train_accuracy = accuracy.eval(session = sess,
                                       feed_dict = {x:batch[0], y_:batch[1], keep_drop:1.0})
        print("step %d, train_accuracy %g" %(i, train_accuracy))
    train_step.run(session = sess, feed_dict = {x:batch[0], y_:batch[1],
                   keep_drop:0.5}) #神经元输出保持不变的概率 keep_prob 为0.5

print("test accuracy %g" %accuracy.eval(session = sess,
      feed_dict = {x:mnist.test.images, y_:mnist.test.labels,
                   keep_drop:1.0})) #神经元输出保持不变的概率 keep_prob 为 1,即不变,一直保持输出

机器学习笔记 tensorflow mnist上实现CNN网络_第1张图片

你可能感兴趣的:(机器学习)