import tensorflow as tf
# Load the MNIST dataset from /tmp/data.
# NOTE(review): tensorflow.examples.tutorials is a TF 1.x-only module; it is not
# shipped with TF 2.x, so this script requires a TF 1.x installation.
from tensorflow.examples.tutorials.mnist import input_data
# one_hot=True so labels are one-hot vectors, matching the [None, 10] label placeholder `y`.
mnist = input_data.read_data_sets('/tmp/data',one_hot=True)
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
参数
# Number of examples per mini-batch (one SGD step per batch).
batch_size = 100
# How many full mini-batches one pass over the training set contains;
# floor division, so any leftover examples are not counted.
n_batch = divmod(mnist.train.num_examples, batch_size)[0]
定义计算图
# Graph inputs. Each row of `x` holds one flattened image (784 values) and each
# row of `y` its one-hot label over 10 classes; the leading None dimension lets
# any batch size be fed.
x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
# Probability that each activation is kept by dropout
# (the training loop feeds 0.5 while training and 1.0 when evaluating).
keep_prob = tf.placeholder(dtype=tf.float32)
def _tanh_dropout_layer(inputs, n_in, n_out):
    """Build one fully-connected tanh layer followed by dropout.

    Weights start from a truncated normal (stddev 0.1) and biases at 0.1.
    Dropout uses the module-level `keep_prob` placeholder.

    Returns:
        (weights, biases, activation, dropped_activation)
    """
    weights = tf.Variable(tf.truncated_normal([n_in, n_out], stddev=0.1))
    biases = tf.Variable(tf.zeros([n_out]) + 0.1)
    activation = tf.nn.tanh(tf.matmul(inputs, weights) + biases)
    return weights, biases, activation, tf.nn.dropout(activation, keep_prob)


# Three hidden layers: 784 -> 2000 -> 2000 -> 1000.
W1, b1, L1, L1_drop = _tanh_dropout_layer(x, 784, 2000)
W2, b2, L2, L2_drop = _tanh_dropout_layer(L1_drop, 2000, 2000)
W3, b3, L3, L3_drop = _tanh_dropout_layer(L2_drop, 2000, 1000)
# Output-layer parameters: 1000 hidden units -> 10 classes.
W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
# Output layer: raw, unnormalized class scores. These are kept separate from
# the softmax because softmax_cross_entropy_with_logits applies softmax itself;
# passing already-softmaxed values would apply it twice and weaken gradients.
logits = tf.matmul(L3_drop,W4)+b4
# Predicted class probabilities (argmax is unchanged by softmax, so accuracy is unaffected).
prediction = tf.nn.softmax(logits)
# Loss: mean softmax cross-entropy over the batch, computed from the raw logits.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# Plain gradient descent on `loss` with a fixed learning rate of 0.2.
_sgd = tf.train.GradientDescentOptimizer(learning_rate=0.2)
train_step = _sgd.minimize(loss)
# Op that initializes all global variables; run once at the start of the session.
init = tf.global_variables_initializer()
# Per-example correctness: does the index of the largest entry along axis 1
# of the one-hot labels match that of the predictions?
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(prediction, axis=1))
# Accuracy: cast the booleans to float32 (True -> 1.0, False -> 0.0) and average.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32))
训练
# Train for 21 epochs of mini-batch SGD. After each epoch, report accuracy on the
# full test and training sets with dropout disabled (keep_prob=1.0).
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for _ in range(n_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            # Dropout active during training: keep each hidden activation with probability 0.5.
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.5})
        # Evaluate once per epoch, after all batches of that epoch have been processed.
        test_acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
        train_acc = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels,keep_prob:1.0})
        print("Iter "+str(epoch)+",Testing accuracy "+str(test_acc)+",Training accuracy "+str(train_acc))
Iter 0,Testing accuracy 0.8683,Training accuracy 0.8614
Iter 1,Testing accuracy 0.9039,Training accuracy 0.8977818
Iter 2,Testing accuracy 0.9155,Training accuracy 0.9084
Iter 3,Testing accuracy 0.9178,Training accuracy 0.9128909
Iter 4,Testing accuracy 0.9223,Training accuracy 0.9183818
Iter 5,Testing accuracy 0.9246,Training accuracy 0.9210727
Iter 6,Testing accuracy 0.9287,Training accuracy 0.9246909
Iter 7,Testing accuracy 0.9314,Training accuracy 0.9265091
Iter 8,Testing accuracy 0.9316,Training accuracy 0.9280546
Iter 9,Testing accuracy 0.935,Training accuracy 0.9311818
Iter 10,Testing accuracy 0.9351,Training accuracy 0.9314182
Iter 11,Testing accuracy 0.9349,Training accuracy 0.93396366
Iter 12,Testing accuracy 0.9376,Training accuracy 0.93474543
Iter 13,Testing accuracy 0.938,Training accuracy 0.9373818
Iter 14,Testing accuracy 0.9395,Training accuracy 0.9387818
Iter 15,Testing accuracy 0.9389,Training accuracy 0.9401091
Iter 16,Testing accuracy 0.9419,Training accuracy 0.94192725
Iter 17,Testing accuracy 0.9421,Training accuracy 0.9425273
Iter 18,Testing accuracy 0.9439,Training accuracy 0.94354546
Iter 19,Testing accuracy 0.9443,Training accuracy 0.9446727
Iter 20,Testing accuracy 0.9435,Training accuracy 0.9451454