欢迎大家来到图像分类专栏,本篇基于Pytorch完成一个多类别图像分类实战。
作者 | 郭冰洋
编辑 | 言有三
1 简介
实现一个完整的图像分类任务,大致需要分为五个步骤:
1、选择开源框架
目前常用的深度学习框架主要包括tensorflow、caffe、pytorch、mxnet等;
2、构建并读取数据集
根据任务需求搜集相关图像搭建相应的数据集,常见的方式包括:网络爬虫、实地拍摄、公共数据使用等。随后根据所选开源框架读取数据集。
3、框架搭建
选择合适的网络模型、损失函数以及优化方式,以完成整体框架的搭建
4、训练并调试参数
通过训练选定合适超参数
5、测试准确率
在测试集上验证模型的最终性能
本文利用Pytorch框架,按照上述结构实现一个基本的图像分类任务,并详细阐述其中的细节及注意事项。
2 数据集
本次实战选择的数据集为Kaggle竞赛中的细胞数据集,共包含9961个训练样本,2491个测试样本,可以分为嗜曙红细胞、淋巴细胞、单核细胞、中性白细胞4个类别,图片大小为320x240。
Pytorch中封装了相应的数据读取的类函数,通过调用torch.utils.data.Datasets函数,则可以实现读取功能。
__init__()模块用来定义相关的参数,__len__()模块用来获取训练样本个数,__getitem__()模块则用来获取每张具体的图片,在读取图片时其可以通过opencv库、PIL库等进行读取,具体代码如下:
# 数据集
class dataset(data.Dataset):
# 参数预定义
此外,需要定义图像增强模块,即上述代码中的transform,通常采取的操作为翻转、剪切等,关于图像增强的具体介绍可以参考公众号前作。
【技术综述】深度学习中的数据增强方法都有哪些?
需要特别强调的是对图像进行去均值处理,很多同学不明白为何要减去均值,其主要的原因是图像作为一种平稳的数据分布,通过减去数据对应维度的统计平均值,可以消除公共部分,以凸显个体之间的特征和差异。进行去均值前后操作后的图像对比如下:
3 框架搭建
本次实战主要选取了VGG16、Resnet50、InceptionV4三个经典网络,也是对前篇文章的一个总结。
损失函数则选择交叉熵损失函数:【技术综述】一文道尽softmax loss及其变种
优化方式选择SGD、Adam优化两种:【模型训练】SGD的那些变种,真的比SGD强吗
完整代码获取方式:发送关键词“多类别分类”给公众号
4 训练及参数调试
初始学习率设置为0.01,batch size设置为8,衰减率设置为0.00001,迭代周期为15,在不同框架组合下的最佳准确率和最低loss如下图所示:
可以发现在验证集上Resnet-50+SGD+Cross Entropy的组合下取得了99%左右的准确率,相反VGG-16结果则稍微差一些。
最佳组合下的准确率走势曲线如下图所示:
5 测试
对上述模型分别在测试集上进行测试,所获得的结果如下图所示,整体精度比训练集上约下降了一个百分点:
关于代码,可以参考有三AI开源的12大深度学习开源框架使用的项目:
【完结】给新手的12大深度学习开源框架快速入门项目
总结
以上就是整个多类别图像分类实战的过程,由于时间限制,本次实战并没有对多个数据集进行训练,因此没有列出同一模型在不同数据集上的表现。
有三AI夏季划
有三AI夏季划进行中,欢迎了解并加入,系统性成长为中级CV算法工程师。
转载文章请后台联系
侵权必究
往期精选
【技术综述】你真的了解图像分类吗?
【技术综述】多标签图像分类综述
【图像分类】分类专栏正式上线啦!初入CV、AI你需要一份指南针!