走进tensorflow第二步——MNIST手写数字分类问题(基础篇)

继续跟着官方文档走,可这次遇到的插曲有点多,整了一下午。之前用anaconda安装的tensorflow,用spyder运行时出现了意料之外的错误,查资料显示版本问题,各种更新,pip不好使,最后还是用conda安装的,有点凌乱,反正最后不知咋地就整好了,神奇。。

开始吧,大家可以参考官方文档,有略微修改:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-c1ov28so.html

上述链接是w3school的,如果心里有坎也可以看官方的,一样的:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html

当时读文档的时候发现了错误,提交笔记一直没通过,难道有啥黑幕。。大家可以注意一下那个softmax等式,很明显的错误,第二个图(矩阵格式)是对的。

ok,开始吧:

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

相比于源码做了以下修改:

1.由于文档给的自动下载MNIST数据集的代码链接一直打不开,只能去官网下了源码,格式有点新奇,这个可以去官网好好看一下,不过对本实验没啥影响,下载四个文件,不用解压,建议放到代码所属文件夹,我的是这个样子:

走进tensorflow第二步——MNIST手写数字分类问题(基础篇)_第1张图片

2.文档中给的是用tf.initialize_all_variables()来初始化,但这个已经废弃了,需要使用tf.global_variables_initializer()来替换掉之前的,当然你要是使用老版的也没问题,不过会有警告:

initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.

看下最后的效果:

走进tensorflow第二步——MNIST手写数字分类问题(基础篇)_第2张图片

这个准确率的确不怎么高,虽然已经90+了……接下来就搞一下卷积吧

你可能感兴趣的:(AI)