tensorflow学习笔记(四):利用BP手写体(MNIST)识别


文章目录

  • 一、MNIST数据集
    • 1.简介
    • 2.特点
  • 二、问题描述
  • 三、项目实践
  • 四、网络优化


一、MNIST数据集

1.简介

MNIST数据集是一个手写体数据集,简单说就是一堆这样东西:
tensorflow学习笔记(四):利用BP手写体(MNIST)识别_第1张图片
MNIST的官网地址是MNIST; 通过阅读官网可以知道,这个数据集由四部分组成,分别是:
在这里插入图片描述
也就是一个训练图片集,一个训练标签集,一个测试图片集,一个测试标签集;我们可以看出这个其实并不是普通的文本文件或是图片文件,而是一个压缩文件,下载并解压出来,我们看到的是二进制文件,其中训练图片集的内容部分如此:
tensorflow学习笔记(四):利用BP手写体(MNIST)识别_第2张图片

2.特点

1)、数据集分为两部分:训练集和测试集

训练集:包含60000行的训练数据集

  • List item Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)

测试集:包含10000行的测试数据集

  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含
    10,000 个样本)
  • List item Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

2)、每张图片包含28x28个像素,将这一个数组展成一个向量,长度是28x28=784.因此在MNIST训练数据集mnist.train.images是一个形状为[60000,784]的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。图片中的某个像素的强度值介于0-1之间。
tensorflow学习笔记(四):利用BP手写体(MNIST)识别_第3张图片
tensorflow学习笔记(四):利用BP手写体(MNIST)识别_第4张图片
3)、MNIST数据集的标签是介于0-9的数字,需要把标签转化成"one-hot vector"。一个one-hot向量除了某一位数字是1以外,其余的维度数字都为0,比如标签0将表示为 ( [ 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ] ) ([1,0,0,0,0,0,0,0,0,0]) ([1,0,0,0,0,0,0,0,0,0]),标签3将表示为 ( [ 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ] ) ([0,0,0,1,0,0,0,0,0,0]) ([0,0,0,1,0,0,0,0,0,0])

3)、因此,mnist.train.lables是一个[60000,10]的数字矩阵。
tensorflow学习笔记(四):利用BP手写体(MNIST)识别_第5张图片

二、问题描述

利用Tensorflow搭建BP网络,利用MNIST的训练集训练网络,并用测试集去测试,最后给出准确率。

三、项目实践

代码:


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#载入数据集
mnist=input_data.read_data_sets("MINST_data",one_hot=True)  #将标签转化为one_hot形式,然后数据集命名为MNIST_data

#每个批次的大小
batch_size=100
#计算一共有多少个批次
n_batch=mnist.train.num_examples//batch_size   #//表示整除后取整

#定义两个placeholder
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])

#创建一个简单的神经网络
W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
prediction=tf.nn.softmax(tf.matmul(x,W)+b)

#二次代价函数
 loss=tf.reduce_mean(tf.square(y-prediction))

#使用梯度下降
train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#初始化变量
init=tf.global_variables_initializer()

#结果存放在一个布尔型的列表中
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1)) #tf.equal比较括号中中的两个参数是否一致,如果一致,返回True,否则返回Flasetf.equal
                                                                    #tf.argmax(y,1)返回一维张量中最大的值所在的位置,其中:1表示按行索取
#求准确率
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))  #cast将布尔型的数据转换成float32型的数据;reduce_mean求平均值

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):            #该层循环表示所有的图片训练21次
        for batch in range(n_batch):   #该层循环表示将所有的图片训练一次
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)  #图片的数据保存在batch_xs中,图片的标签保存在batch_ys中
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})

        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter"+str(epoch)+",Testing Accuracy"+str(acc))

结果:

Iter0,Testing Accuracy0.8234
Iter1,Testing Accuracy0.8898
Iter2,Testing Accuracy0.9013
Iter3,Testing Accuracy0.9056
Iter4,Testing Accuracy0.9087
Iter5,Testing Accuracy0.9101
Iter6,Testing Accuracy0.911
Iter7,Testing Accuracy0.9139
Iter8,Testing Accuracy0.9149
Iter9,Testing Accuracy0.9153
Iter10,Testing Accuracy0.9181
Iter11,Testing Accuracy0.9172
Iter12,Testing Accuracy0.9187
Iter13,Testing Accuracy0.9196
Iter14,Testing Accuracy0.9195
Iter15,Testing Accuracy0.9195
Iter16,Testing Accuracy0.9204
Iter17,Testing Accuracy0.9211
Iter18,Testing Accuracy0.9205
Iter19,Testing Accuracy0.9209
Iter20,Testing Accuracy0.9209

四、网络优化

1)改变批次的大小


#每个批次的大小
batch_size=100

这里可以再适当的增大批次;

2)、变量初始化

#创建一个简单的神经网络
W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))

初始化时,这里初始化的全是0,生成正态分布的效果较好,会加速收敛效果;

3)、增加隐藏层

这里只用了一层网络,最后准确率就达到了92.09%,可以适当增加一层网络。注意,不要增加的太多,由于数据集比较简单,如果网络过于复杂会过拟合,此时会需要正则化、dropout等其他方法来抑制过拟合。

4)、修改代价函数

这里使用的是二次代价函数,可以修改为交叉熵代价函数,


#对数似然函数
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))

你可能感兴趣的:(tensorflow实战)