【AI】使用MMPreTrain进行图像分类

之前已经配置好了MMLab的开发环境,有兴趣的可以去看一下【AI】MMLab环境搭建
接下来我们使用MMPretrain进行图像分类来尝试一下熟肉的感觉。
MMPreTrain的官方文档地址是中文文档,Github地址是https://github.com/open-mmlab/mmpretrain
MMPreTrain是从MMClassification和MMSelfSup发展而来的,是一个预训练模型的工具包,非常适合开箱即用的操作。

1.准备工作

  1. 下载源代码,安装依赖
    可以直接使用git clone命令从github上下载官方源代码,也可以直接去github上下载压缩包
    下载完成后我们可以从命令行进入项目,要进行两步操作,一是将MMPreTrain安装为包,二是安装项目的依赖
git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain

# 将MMPreTrain安装为本地依赖
pip install -U openmim 
mim install -e .
# 安装MMPreTrain本身的依赖
pip install -r requirements.txt

安装MMPreTrain的有版本依赖,要注意版本对应,具体可以参考https://mmpretrain.readthedocs.io/zh-cn/latest/notes/faq.html中有关安装的描述。

  1. 准备数据源
    数据源采用之前用过的花朵分类数据即可

2.直接上手跑

mmpretrain的项目结构比较复杂,我们可以先不去管,直接去运行tools文件夹中的train.py, 运行它需要一个参数,就是在配置文件夹configs中的文件地址,我们以resnet18为例进行操作,配置文件选择resnet18_8xb32_in1k.py,将这个文件的绝对路径配置到ide中的运行参数中。
【AI】使用MMPreTrain进行图像分类_第1张图片
【AI】使用MMPreTrain进行图像分类_第2张图片

点击运行按钮,系统会报错,然后退出,不用担心,因为我们现在什么东西都没有做,报错是正常的。我们做这一步的目的是为了获取完整的配置文件:系统会在tools文件夹的work_dir目录下面生成一个包含所有配置项的文件,我们将这个文件拷贝到我们习惯的目录下面就可以使用了
【AI】使用MMPreTrain进行图像分类_第3张图片

3. 修改参数训练模型

用ide打开刚才生成的文件,对立面的参数进行修改,然后进行模型训练

  1. 修改模型分类个数
model = dict(
    backbone=dict(
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch',
        type='ResNet'),
    head=dict(
        in_channels=512,
        loss=dict(loss_weight=1.0, type='CrossEntropyLoss'),
        num_classes=102, # 修改为我们需要的分类类别
        topk=(
            1,
            5,
        ),
        type='LinearClsHead'),
    neck=dict(type='GlobalAveragePooling'),
    type='ImageClassifier')
  1. 修改dataloader
    以训练集为例进行修改
train_dataloader = dict(
    batch_size=32,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type='CustomDataset',
        data_prefix='../demo/flower_data/train', # 修改为训练集的路径
        #data_root='path/to/data_root',
        #ann_file='meta/train.txt',      # 如果有标注文件,相对于 `data_root` 的标注文件路径
        #data_prefix='train', 
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(scale=224, type='RandomResizedCrop'),
            dict(direction='horizontal', prob=0.5, type='RandomFlip'),
            dict(type='PackInputs'),
        ],
        split='train',
        type='ImageNet'),
    num_workers=5,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(shuffle=True, type='DefaultSampler'))

同理将验证集和测试集的配置修改一下

4.修改参数运行

将运行配置文件替换为我们修改后的配置文件,然后点击运行,即可等待训练完成了。
其他参数修改后续再做介绍。

你可能感兴趣的:(人工智能,人工智能,分类,数据挖掘)