个人博客文章链接:http://www.huqj.top/article?id=168
接着上一篇所说的 BP神经网络,现在用它来实现一个手写体数字的识别程序,训练素材来自吴恩达机器学习课程,我把打包好上传到了网盘上:
1 2 |
|
训练数据一共有5000条,10个数字(0~9,为了和matlab适配,0在这里统一用10表示),每个数字各500个手写体图片,像素统一处理为20*20,其中pics中是5000张图片, data是一个.mat文件,可以直接加载到matlab中,包含两个变量X(5000x400 double矩阵)和y(5000x1 int矩阵)。
可以看到,训练数据的输入是400个像素点的灰度值,虽然图片是20x20的,但是为了处理方便将其转换成1x400的输入,可以用matlab中的reshape函数进行转换。而对于输出而言,这可以看作一个多元分类问题,一共有10种分类,所以输出可以转换成一个10维向量。定义好输入输出格式之后,再考虑下神经网络的架构,平衡性能和效率,最终选择的架构是一个25元隐含层的BP网络。另外,为了衡量最终的模型效果,我们需要从5000个数据中抽取一部分作为测试集,这里我每个数字选了10条数据作为测试数据集,不过理论上训练集和测试集的比例可以达到 7:3
利用之前编写好的BP网络训练函数和一些附加函数(sigmoid,预测函数等),最终的手写体识别训练程序如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
|
因为一开始不知道要迭代多少次,所以设置成了一个循环的结构,可以根据训练误差决定继续训练或者结束训练,然后将模型权重保存下来,下次可以接着训练。
如果想要在matlab中画出图片,可以将这一行的注释去掉:
1 |
|
然后绘出所有测试集的图片如下:
运行程序反复迭代上万次之后,在测试集上的准确率稳定在92%左右,这可能也是受模型和数据集的限制。而且这个模型只是用于黑底白字的图片,用我自己的手写数字测试效果并不太好(可能与我的图片处理有关),最高只能达到 7/10 的准确率,后续会持续考虑改进模型。
完整代码下载地址: https://download.csdn.net/download/qq_32216775/10897369