手写体数字识别(Python+TensorFlow)

注意:该作者博客已迁移至https://buxianshan.xyz

先看结果

在MNIST数据集10000张测试图片上的正确率
在这里插入图片描述
测试手写数字图片(20张)
原图
手写体数字识别(Python+TensorFlow)_第1张图片
测试结果
手写体数字识别(Python+TensorFlow)_第2张图片

源文件下载:

  • CSDN下载https://download.csdn.net/download/qq_43479622/11227413(需要5C币)
  • 没有C币的也可以到GitHub下载https://github.com/BuXianShan/Handwritten-Numeral-Recognition

声明:

本文大部分程序参考《TensorFlow实战Google深度学习框架》,很适合深度学习入门的书籍。

解压文件打开后如图

手写体数字识别(Python+TensorFlow)_第3张图片
__pycache__文件夹是Python自动生成的,不用管它,想了解它可以参考https://blog.csdn.net/yitiaodashu/article/details/79023987

MNIST_model文件夹保存了已经训练30000次的模型

picture文件夹存放的是自己手写数字的图片

需要的安装的Python模块

  • tensorflow
  • opencv-python
  • pillow

(这里我建议全部使用pip安装,用国内镜像下载特别快,如果以后计算量大需要使用gpu版本的tensorflow再重新安装gpu版本的。opencv-python和pillow都是关于图像处理的,只有app.py文件使用到了。)

关于MNIST数据集

MNIST是深度学习的经典入门demo,他是由6万张训练图片和1万张测试图片构成的,每张图片都是28*28的灰度图,像素取值为0~1。这些图片是采集的不同的人手写从0到9的数字。TensorFlow将这个数据集和相关操作封装到了库中,每一张图片是一个长度为784的一维数组。

from tensorflow.examples.tutorials.mnist import input_data

便会自动下载封装好的数据集。

mnist_inference.py文件定义了前向传播过程以及神经网络的参数

三层全连接网络结构,通过加入隐藏层实现了多层网络结构。

mnist_train.py定义了神经网络的训练过程

运行mnist_train.py文件便会开始训练模型,MNIST_model文件已经有训练好的模型,你也可以删掉或修改然后重新训练。

mnist_eval.py文件定义了测试过程

运行mnist_eval.py文件就是计算在mnist数据集上测试1万张图片的正确率。
在这里插入图片描述

app.py文件实现了测试自己手写数字的图片

在picture文件夹保存要测试的图片,运行app.py文件即可输出测试结果

以上都是我测试过的文件,可以使用,下面记录了我遇到的困难

我遇到的困难

困难1:不知道模型训练好了怎么测试自己手写的图片

跟着书上开始做的时候,mnist_inference.py、mnist_train.py 和 mnist_eval.py这三个文件已经可以实现训练模型和测试正确率。但是由于这三个文件使用的mnist数据集,我还不知道数据集到底长什么样子,也不知道模型训练好了怎么测试自己手写的图片。就是感觉好像做了一个很厉害的东西,但是不知道怎么用的感觉。
然后把模型搭建的过程重新梳理了一遍,才知道从何下手。输入节点是长度为784的数组,所以得把我的图片转化为长度是784的数组,才能输入到模型里,才能得到结果。
代码请看app.py里的image_prepare()函数,通过使用图像处理库PIL把图片转化为灰度图并且修改尺寸为28*28,然后转化为数组

困难2:测试自己手写图片的正确率太低

在mnist测试数据集上的正确率有98.52%,而测试自己手写数字的正确率几乎为0,大部分数字都被识别成8。
手写体数字识别(Python+TensorFlow)_第4张图片
通过不断测试我总结了以下几个原因:

  1. mnist数据集图片是黑底白字,而我们平时都是白底黑字,所以要对测试图片灰度反转
    修改过后测试如图
    手写体数字识别(Python+TensorFlow)_第5张图片
    已经可以识别几个数字了,但还是很多被识别成了8。原因是自己拍的图片有很多噪点,直接输入给模型就因为噪点太多,被误认为是8。
  2. 二值化来降噪
    原图手写体数字识别(Python+TensorFlow)_第6张图片
    使用opencv二值化图像cv2.threshold(img,127,255,cv2.THRESH_BINARY)
    手写体数字识别(Python+TensorFlow)_第7张图片
    虽然还有少量噪点,但已经有很好的识别效果了。(还可以再调整阈值)
    测试结果
    手写体数字识别(Python+TensorFlow)_第8张图片
    基本上识别了所有数字。
    (测试自己写的数字时很可能因为输入图像没处理好导致识别率太低。)

我使用的版本
python3.7
tensorflow 1.13.1
运行时可能有很多warning,不影响运行结果

总结

本文只大致介绍了使用方法,神经网络的相关基础知识还要多多了解。这里使用的是基于全连接层网络结构的神经网络,对数字识别已经有了不错的效果,但使用卷积神经网络还可以提高正确率(大约99.2%),比如LeNet-5模型。

你可能感兴趣的:(python,手写体数字识别,MNIST数据集,TensorFlow,深度学习)