Tensorflow入门之Mnist数字识别

# -*- coding: utf-8 -*-
"""
Created on Sun Jul 22 11:51:44 2018

@author: KK
"""
#用于下载mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

import tensorflow as tf
sess = tf.InteractiveSession()

#为输入图像和目标输出类别创建节点,来开始构建计算图
#x为输入数据,y_为实际标签
#None表示其值大小不定,在这里作为第一个维度值,用以指代batch的大小,意即x的数量不定
x = tf.placeholder("float",shape=[None,784])
y_ = tf.placeholder("float",shape=[None,10])

#为模型定义权重和偏置
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#初始化变量
sess.run(tf.initialize_all_variables())

#实现回归模型,计算每个分类的softmax概率值
#y为预测标签
y = tf.nn.softmax(tf.matmul(x,W)+b)

#交叉熵(损失函数)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

#梯度下降算法,步长为0.01
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

for i in range(1000):
    #选取batch
    batch = mnist.train.next_batch(50)
    #执行train_step操作,即梯度下降算法
    train_step.run(feed_dict={x:batch[0],y_:batch[1]})

#argmax给出某个tensor对象在某一维上的其数据最大值所在的索引值,返回值为布尔数组
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

#将布尔值转换为浮点数来代表对、错,然后取平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print ('the accuracy is:%f'%(accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels})))

 

你可能感兴趣的:(Tensorflow学习)