基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第1张图片

ねぇ        呐

私に気付いてよ   快点注意到我吧

もう そんな事   那种事 一定

望んでも      再去奢求

しょうがないだろ  也无可奈何吧

——真野あゆみ《Bipolar emotion》(作詞:Mitsu)

本文将以mnist为原材料,用ResNet模型,实现一个正确率超过99%的手写数字分类器。


材料准备

首先我们下载MNIST数据集——可以从我的这篇文章中找到资源:

刘冬煜:关于图像识别的相关图片资源​zhuanlan.zhihu.com
基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第2张图片

接着,我们搭建一个较深层的网络结构:

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第3张图片
一个可行的网络结构示意图
	CNN.set_Input_Layer(1);						/*  0 */		//28x28
	CNN.add_Convolution_Layer(16, 5, 5, Valid);						//24x24
	CNN.add_BatchNorm_Layer(16);					/*  2 */
	CNN.add_Activation_Layer(&relu_actfunc, 16);
	CNN.add_Pooling_Layer(16, 2, 2, Max_Pooling);						//12x12
	CNN.add_Convolution_Layer(40, 5, 5, Valid);						//8x8
	CNN.add_Activation_Layer(&relu_actfunc, 40);			/*  6 */	//------.
	CNN.add_Convolution_Layer(40, 3, 3, Same);					//	|
	CNN.add_Activation_Layer(&relu_actfunc, 40);					//	|
	CNN.add_Convolution_Layer(40, 3, 3, Same);					//<-----'
		CNN.create_short_cut(6, 9);
	CNN.add_Activation_Layer(&relu_actfunc, 40);
	CNN.add_Convolution_Layer(64, 5, 5, Same);
	CNN.add_BatchNorm_Layer(64);					/* 12 */
	CNN.add_Activation_Layer(&relu_actfunc, 64);
	CNN.add_Convolution_Layer(64, 5, 5, Valid);						//4x4
	CNN.add_Activation_Layer(&relu_actfunc, 64);
	CNN.add_Spatial_Pyramid_Pooling_Layer(320, Average_Pooling, 2);				//1x1
	CNN.add_Fully_Connected_Layer(10);
	CNN.add_Softmax_Layer(10);					/* 18 */

材料准备完成,下一步将是数据的处理。

预处理

首先,我们将60000张手写数字图片读入到显存之中,通过公式

的整数RGB转化为
的浮点数作为输入。

由于对于原图每个像素点,RGB值都相等(灰度图),我们索性将三个通道合并为了一个通道作为输入。

损失函数

这是一个单标签多分类问题,我们可以使用交叉熵公式。这里我保留了反向概率项:

其中

是期望输出向量,有且仅有一个1,其余9个元素都是0;而
是实际输出向量。

通过交叉熵损失函数,我们能提高正确标签对应神经元的输出,降低错误标签对应神经元的输出。

其他超参数

由于采用了Adam法更新参数,我使用初始学习率

,一阶动量
,二阶动量

ResNet往往不使用Dropout,这个工程也不例外,因而我采用了L2正则化。初始的惩罚为:权值惩罚因子

,偏置量惩罚因子
。而后期更是加大了这个惩罚因子。

Batch size我设置得小了一点(因为我的显存比较小),只有60。

训练过程

起初误差下降得非常好:

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第4张图片
初始的误差曲线

但训练了一段时间后,明显感觉“后劲不足”:

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第5张图片
误差较高,下降缓慢

于是我果断将学习率砍掉一半,并提高了两个动量项,使用学习率

,一阶动量
,二阶动量
,权值惩罚因子
,偏置量惩罚因子
,继续训练。

于是乎误差又稳定地下降了不少:

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第6张图片
调整后的误差曲线

于是我每迭代500轮,都会将网络的学习率砍掉一半,并缓缓提升动量和惩罚因子。

大约训练到第3000轮的时候,我测试了一下网络,其中训练集结果和测试集结果如下:

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第7张图片
训练集结果

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第8张图片
测试集结果

接下来就要加入“提前终止”了——每训练200轮跑一次验证集,如果验证集的准确率有所下降,提前终止训练。

结果到第400轮的时候就结束了,结果如下:

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第9张图片
训练集结果

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第10张图片
测试集结果

最终我把错误预测的结果挑了出来:

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第11张图片
被错误预测的结果

有些被错误预测的结果甚至可以迷惑一个正常人类了……

我们甚至还可以将特征可视化出来,看图片那些成分相应每一个标签:

基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别_第12张图片
每一张图片对于每个标签的相应

可以横向比较——对于每张手写数字图片,正确标签响应位置与其他标签是有明显不同的。

这个网络的二进制文件我也将会发到专栏的GitHub上。


开学了。

今天带着学妹逛校园,好开心。

同时这也让我感受到了一些哲理。

我懂得一个人来上海举目无亲的寂寞,但我却不希望学弟学妹继承这份失落。

同样经历过实验失败的痛苦,我也不希望有人重蹈覆辙。

专栏依旧不定期更新,敬请期待。

你可能感兴趣的:(基于stm32f429的手写识别_【工程分析】基于ResNet的手写数字识别)