使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)

开源代码:https://github.com/xxcheng0708/Pytorch_Image_Classifier_Template​​​​​ 

使用pytorch框架搭建一个图像分类模型通常包含以下步骤:

1、数据加载DataSet,DataLoader,数据转换transforms

2、构建模型、模型训练

3、模型误差分析

下面依次来看一下上述几个步骤的实现方法:

一、数据加载、数据增强

a)、有时候torchvision.transform中提供的数据转换方法不能满足项目需要,需要自定义数据转换方法进行数据增强,以下InvertTransform类实现了__init__和__call__方法对图像的像素值进行翻转:

使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)_第1张图片

b)、数据加载可以参考pytorch数据集加载之DataSet和DataLoader。

c)、使用torchvision.transforms.Compose组合多种数据转换方法进行数据转换:

使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)_第2张图片

其中,train_transform用于训练集,val_transform用于验证集,训练集和验证集要进行相同的数据转换操作。并且在pytorch中提供的transform数据转换方法有的处理目标是Image对象,有的处理目标是tensor,并且经过处理后的数据维度变为[N,C,H,W]。 

d)、使用Dataset和DataLoader进行数据加载

  二、模型构建

 在pytorch中提供了常用的分类网络的创建接口以及预训练权重,一般情况下直接使用预训练权重来初始化backbone网络,修改网络的输出层来适配自己的数据集,仅训练网络输出层或者输出层及其之前几层就可以。模型构建方法实例如下:

使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)_第3张图片

在上述代码中,创建了MobileNet_V2网络模型,并且同时下载了预训练模型参数,同时修改了网络中的classifier模块,将输出层的维度修改为自己的数据集类别数量,同时将classifier模块的参数和模型的其他参数进行区分,模型除classifier 模块参数之外的参数使用预训练模型参数进行初始化,classifier模块参数使用随机初始化。然后在优化器中对参数进行分组训练,classifier模块的参数需要重点训练,使用较大的学习率,其他模块的参数是预训练模型参数,并且处于浅层网络中,参数仅需要稍微修改就可以,使用较小的学习率。

三、模型训练

使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)_第4张图片

使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)_第5张图片

模型训练的主要流程是,从DataLoader中分批加载数据送入模型,将模型预测结果与真实结果使用定义的loss计算损失,基于计算的损失进行梯度反向传播进行参数优化。

四、模型评估,模型准确性评估

使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)_第6张图片

五、学习曲线

使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)_第7张图片

根据迭代训练过程中保存训练集、验证集的准确率、损失值、学习率绘制学习曲线,来判断模型的学习情况,是否出现过拟合、欠拟合等。

六、误差分析

吴恩达大佬不止一次强调过误差分析的重要性,在分类模型训练过程中进行误差分析,可以清晰的看到误差来源,哪些样本容易被误识别,这些样例有什么规律,从而进行数据调整提升模型性能。在最近吴恩达大佬的一次讲座中,再次强调了误差分析的重要性,讲座中提出当你的模型性能遇到瓶颈时,是以模型为中心调整模型呢?还是以数据为中心调整数据呢?大佬的结论是调整模型几乎不会带来性能的提升,而调整数据能带来模型性能的大幅提升。下面就看一下图像分类模型的误差分析示例代码:

使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)_第8张图片

 该方法根据模型预测结果以及样本的真实标签,计算样本的正确与否,对于预测错误的样本,保存样本的真实标签,预测标签,从而能够清晰的看到那些被分类错误的样本都被误分类成了什么。

就写到这吧!有疑问欢迎随时交流。

你可能感兴趣的:(pytorch,深度学习,pytorch,图像分类)