# Note: this script uses the TensorFlow 1.x API (tf.placeholder, tf.Session)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the dataset
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
# Size of each batch
batch_size = 100
# Compute the total number of batches per epoch
n_batch = mnist.train.num_examples // batch_size
# Define two placeholders: flattened 28x28 images and one-hot labels
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
# Build a simple neural network: a single softmax layer with no hidden layers
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b)
# Quadratic cost (mean squared error)
loss = tf.reduce_mean(tf.square(y-prediction))
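# Cross-entropy (alternative loss). This op applies softmax internally,
# so logits must be the raw pre-softmax output, not prediction.
#loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=tf.matmul(x,W)+b))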
# Train with gradient descent (learning rate 0.2)
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
# Initialize the variables
init = tf.global_variables_initializer()
# Store the per-example results in a boolean tensor
# (argmax returns the index of the largest value along the given axis)
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
# Compute the accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        # Evaluate on the full test set once per epoch
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
Output:
Iter 0,Testing Accuracy 0.8323
Iter 1,Testing Accuracy 0.8712
Iter 2,Testing Accuracy 0.8811
Iter 3,Testing Accuracy 0.8872
Iter 4,Testing Accuracy 0.8946
Iter 5,Testing Accuracy 0.8969
Iter 6,Testing Accuracy 0.8989
Iter 7,Testing Accuracy 0.9014
Iter 8,Testing Accuracy 0.9036
Iter 9,Testing Accuracy 0.9062
Iter 10,Testing Accuracy 0.9064
Iter 11,Testing Accuracy 0.9074
Iter 12,Testing Accuracy 0.9083
Iter 13,Testing Accuracy 0.9094
Iter 14,Testing Accuracy 0.9102
Iter 15,Testing Accuracy 0.9115
Iter 16,Testing Accuracy 0.9123
Iter 17,Testing Accuracy 0.9116
Iter 18,Testing Accuracy 0.9126
Iter 19,Testing Accuracy 0.9129
Iter 20,Testing Accuracy 0.9136
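The script trains with the quadratic cost and leaves the cross-entropy loss commented out. As a rough illustration of how the two differ, here is a minimal pure-Python sketch. It is not part of the original script: the helper names softmax, quadratic_cost, and cross_entropy and the sample logits are made up for this example. It evaluates both losses on two hand-written examples.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def quadratic_cost(y, prediction):
    # Mean of squared differences over all entries, the same formula as
    # tf.reduce_mean(tf.square(y - prediction)).
    diffs = [(t - p) ** 2
             for y_row, p_row in zip(y, prediction)
             for t, p in zip(y_row, p_row)]
    return sum(diffs) / len(diffs)

def cross_entropy(y, prediction):
    # Mean over examples of -sum(y * log(prediction)).
    per_example = [-sum(t * math.log(p) for t, p in zip(y_row, p_row) if t > 0)
                   for y_row, p_row in zip(y, prediction)]
    return sum(per_example) / len(per_example)

# Hypothetical logits for two examples with three classes (illustration only)
logits = [[2.0, 1.0, 0.1],
          [0.1, 0.2, 3.0]]
# One-hot targets; the second example is confidently misclassified
y = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]

prediction = [softmax(row) for row in logits]
q = quadratic_cost(y, prediction)
ce = cross_entropy(y, prediction)
print("quadratic cost:", q)
print("cross-entropy: ", ce)
```

Each squared difference is at most 1, so the quadratic cost is bounded, while the cross-entropy grows without limit as the probability assigned to the true class approaches 0. Confident mistakes are therefore penalized more heavily under cross-entropy.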
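A brief introduction to the MNIST dataset:
MNIST labels are digits from 0 to 9, which we convert to "one-hot vectors". A one-hot vector is 1 in exactly one position and 0 in every other position; for example, label 0 becomes [1,0,0,0,0,0,0,0,0,0]. With the loader used above, mnist.train.labels is a [55000,10] matrix, because read_data_sets by default holds out 5,000 of the 60,000 training images as a validation set.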
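To make the one-hot encoding concrete, here is a small pure-Python sketch. The helpers one_hot and argmax are hypothetical names written for this example and are not part of TensorFlow or the MNIST loader. It encodes a few labels and then recovers them by taking the index of the largest entry, which is the role tf.argmax(y,1) plays in the accuracy computation.

```python
def one_hot(label, num_classes=10):
    # Vector of zeros with a single 1.0 at the index given by the label.
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def argmax(vec):
    # Index of the largest entry in one row.
    return max(range(len(vec)), key=lambda i: vec[i])

labels = [0, 3, 9]
encoded = [one_hot(label) for label in labels]
for label, vec in zip(labels, encoded):
    print(label, vec)

# Taking the argmax of each one-hot row recovers the original labels.
decoded = [argmax(vec) for vec in encoded]
print(decoded)
```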