【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率

背景

学习深度学习的框架,积累调参经验

数据集

5类花卉图像数据,分别是向日葵、郁金香、玫瑰、蒲公英、雏菊,每类花卉在700到1000张左右,图像尺寸大小不统一,常见尺寸是320x240,数据并不干净,有些混杂的图片。
任务是利用CNN方法对其进行分类识别。

模型记录

1.基本CNN模型进行分类

  • 卷积层1:32个卷积核、大小5x5、步数1,激活函数ReLU,最大池化、步数2,输入150x150x3
  • 卷积层2:64个卷积核、大小3x3、步数1,激活函数ReLU,最大池化、步数2
  • 卷积层3:96个卷积核、大小3x3、步数1,激活函数ReLU,最大池化、步数2
  • 卷积层4:96个卷积核、大小3x3、步数1,激活函数ReLU,最大池化、步数2
  • 全连接层1:512个神经元,激活函数ReLU
  • 全连接层2:256个神经元,激活函数ReLU
  • 分类层:5个神经元,激活函数softmax

batch_size=16,epoch=50。
效果如下,并不理想,出现了过拟合,识别率在65%左右,val_loss不减反增,令人畏惧。
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第1张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第2张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第3张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第4张图片

2.数据增强

在上面的基础上,加上了图像增强处理,具体方式是,随机旋转范围10度,随机缩放0.9~1.1范围,水平和竖直偏移为范围0.2。
epoch还是50,效果如下,识别率在77%左右,loss_val下降后反弹,波动较大,训练loss持续下降,似乎可以继续下降:
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第5张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第6张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第7张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第8张图片

3.模型改进

在上面的基础上,对cnn模型进行改进,添加了权重衰减、Dropout层、批正则化层(BN)

  • 卷积层1:32个卷积核、大小5x5、步数1,激活函数ReLU,最大池化、步数2,输入150x150x3
  • 加入BN层,介于卷积层和池化层之间
  • 卷积层2:64个卷积核、大小3x3、步数1,激活函数ReLU,最大池化、步数2
  • 加入BN层,介于卷积层和池化层之间
  • 卷积层3:96个卷积核、大小3x3、步数1,激活函数ReLU,最大池化、步数2
  • 加入BN层,介于卷积层和池化层之间
  • 卷积层4:96个卷积核、大小3x3、步数1,激活函数ReLU,最大池化、步数2
  • 加入BN层,介于卷积层和池化层之间
  • 全连接层1:512个神经元,激活函数ReLU,加入权重衰减
  • 加入Dropout层,0.3
  • 全连接层2:256个神经元,激活函数ReLU,加入权重衰减
  • 加入Dropout层,0.3
  • 分类层:5个神经元,激活函数softmax

epoch150,效果如下,发现最后的识别率在83%左右。
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第9张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第10张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第11张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第12张图片

4.自调整学习率

规则是,如果在5个epoch内,val识别率还是不增加,就让lr变为原来的0.1倍

epoch50,batch_size增加,效果如下,loss趋于平坦,识别率在85%左右,说明合适的学习率很重要,但是最后lr变成了1e-10甚至以下,loss平坦,说明lr越小,收敛越没有效果。
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第13张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第14张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第15张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第16张图片

5.增加cnn网络深度

方法是在每个卷积层后面repeat同样数量的2个卷积层。效果就不上传了,只提升了0.5%~1%左右的识别率。

总结

1.卷积核尺寸不能超过图像尺寸,否则训练下来没什么效果
2.过拟合时候,采用合适的数据增强十分有效,而且对于训练效果也很有帮助,数据增强中,水平翻转比竖直翻转更有效,因为没有倒立的花朵,其他方法也可以多尝试下。
3.在预处理阶段,让X_train归一化会具有一定效果
4.在网络改进中,BN层(批正则化,让数据分布均匀)的加入和dropout层(随机失效,避免过拟合)的效果尤为明显
5.深度网络似乎并没有什么很明显的提升,甚至可能会导致梯度问题,但是深度网络的理解能力肯定要比非深度网络要好。
6.学习率的选择非常重要,越大,收敛越快,但达不到最优的点,越小,收敛越慢,可能会导致局部最优。在训练中自适应调整是一种很好的方法。
7.基本cnn:65%,数据增强:77%左右,模型改进83%左右,自适应调整学习率,85%左右,深度cnn网络,没有很明显的提升。

附个图

【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第17张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第18张图片
【调参记录】基于CNN对5类花卉植物数据一步步提升分类准确率_第19张图片
参考链接:
https://www.kaggle.com/alxmamaev/flowers-recognition(数据集来源)
https://zhuanlan.zhihu.com/p/29534841(受启发的文章)

要代码的人比较多,所以公开数据集+代码的百度网盘链接:
链接:https://pan.baidu.com/s/1-0u55gO38V0PckMZ0IyYSA
提取码:wkhi

你可能感兴趣的:(Tensorflow)