Training a Classification Network with mmclassification

Classifying pet dogs with mmclassification

Recently, for development work, I have needed to test model conversion and parsing across different deep learning frameworks, which means frequently training models in each of them. That raises a question: when a converted caffe model performs poorly, is it because my own configuration is wrong, because the input data is set up incorrectly, or because the conversion itself is broken? To answer that, I often need to train a network myself. Training with caffe requires learning its whole workflow (training and deployment), caffe resources online are relatively scarce, and configuring a more modern network by hand is tedious. So the requirements became: train quickly, use a modern network, and stay on a mainstream framework. mmclassification fits all three.

Training with mmclassification

Training with mmclassification is fairly simple, but it does need a sensible set of configs (here I train ShuffleNetV2):

Network structure config

# configs/_base_/models/shufflenet_v2_1x.py
model = dict(
    type='ImageClassifier',
    backbone=dict(type='ShuffleNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=37,  # change this to the number of classes in your task
        in_channels=1024,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

Dataset config, which mainly covers: specifying the data, the input transforms/preprocessing, and the evaluation metric.

# configs/_base_/datasets/pet_bs64_pil_resize.py
# dataset settings
dataset_type = 'Pet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)  # subtract the per-channel mean and divide by the std; these values must be set to match your data
train_pipeline = [  # input-processing pipeline, similar to torchvision.transforms
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224, backend='pillow'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1), backend='pillow'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=64,
    workers_per_gpu=0,
    train=dict(
        type=dataset_type,
        data_prefix='/tmp/pet/train',  # directory prefix for the image files
        ann_file='/tmp/pet/train.txt',  # annotation file with one "filename label" pair per line (e.g. "a.jpg 0"); the image is then resolved as /tmp/pet/train/a.jpg
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='/tmp/pet/val',
        ann_file='/tmp/pet/val.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='/tmp/pet/val',
        ann_file='/tmp/pet/val.txt',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='accuracy')  # evaluation metric
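For reference, this config assumes the images live under /tmp/pet/train and /tmp/pet/val, with train.txt/val.txt and a labels.json (class name to index mapping) next to those folders; the custom Pet dataset shown later reads labels.json from dirname(data_prefix). Below is a minimal, hypothetical helper (prepare_pet_lists.py, not part of mmclassification) that generates these files, assuming Oxford-IIIT-Pet-style file names where everything before the last underscore is the class name:

# prepare_pet_lists.py -- hypothetical helper, not part of mmclassification
import json
import os
from os.path import join

root = '/tmp/pet'

# class name = file name up to the last underscore,
# e.g. american_bulldog_34.jpg -> american_bulldog
class_names = sorted({
    name.rsplit('_', 1)[0]
    for split in ('train', 'val')
    for name in os.listdir(join(root, split))
    if name.endswith('.jpg')
})
label_to_num = {name: idx for idx, name in enumerate(class_names)}

# labels.json sits next to the train/ and val/ folders
with open(join(root, 'labels.json'), 'w') as f:
    json.dump(label_to_num, f, indent=2)

# one "filename label" pair per line, which is what the Pet dataset parses
for split in ('train', 'val'):
    with open(join(root, f'{split}.txt'), 'w') as f:
        for name in sorted(os.listdir(join(root, split))):
            if name.endswith('.jpg'):
                f.write(f'{name} {label_to_num[name.rsplit("_", 1)[0]]}\n')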

Learning rate schedule and optimizer settings

# configs/_base_/schedules/pet_bs1024_linearlr_bn_nowd.py
# optimizer
optimizer = dict(
    type='SGD',
    lr=0.5,
    momentum=0.9,
    weight_decay=0.00004,
    paramwise_cfg=dict(norm_decay_mult=0))
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='poly',
    min_lr=0,
    by_epoch=False,
    warmup='constant',
    warmup_iters=5000,
)
runner = dict(type='EpochBasedRunner', max_epochs=300)
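
These three _base_ fragments are tied together by the top-level config that the training script below points at. That file is not shown in this post; a minimal sketch of what it would contain (default_runtime.py is the stock runtime config shipped with mmclassification):

# configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_pet.py (sketch)
_base_ = [
    '../_base_/models/shufflenet_v2_1x.py',
    '../_base_/datasets/pet_bs64_pil_resize.py',
    '../_base_/schedules/pet_bs1024_linearlr_bn_nowd.py',
    '../_base_/default_runtime.py'
]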

Because we are using a custom dataset here, we need to implement our own Dataset class. Conceptually it is the same as PyTorch's Dataset and DataLoader.

import json

import numpy as np
from os.path import dirname, isfile, join

from .base_dataset import BaseDataset
from .builder import DATASETS
from .imagenet import get_samples  # folder-scanning helper reused from the ImageNet dataset


@DATASETS.register_module()
class Pet(BaseDataset):
    """Pet dataset.

    This implementation is modified from
    https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py  # noqa: E501
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
    CLASSES = []
    def __str__(self):
        return "PetDataset"

    def load_annotations(self):
        label_file = join(dirname(self.data_prefix), 'labels.json')
        assert isfile(label_file), "{} must exist!".format(label_file)
        with open(label_file, 'r') as f:
            label_to_num = json.load(f)
        self.CLASSES = list(label_to_num.keys())  # expose the class names as a member variable
        if self.ann_file is None:
            folder_to_idx = label_to_num
            samples = get_samples(
                self.data_prefix,
                folder_to_idx,
                extensions=self.IMG_EXTENSIONS)
            if len(samples) == 0:
                raise (RuntimeError('Found 0 files in subfolders of: '
                                    f'{self.data_prefix}. '
                                    'Supported extensions are: '
                                    f'{",".join(self.IMG_EXTENSIONS)}'))

            self.folder_to_idx = folder_to_idx
        elif isinstance(self.ann_file, str):
            with open(self.ann_file) as f:
                samples = [x.strip().split(' ') for x in f.readlines()]
        else:
            raise TypeError('ann_file must be a str or None')
        self.samples = samples

        data_infos = []
        for filename, gt_label in self.samples:
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos

The dataset implementation can be modeled on the ImageNet one. Essentially you need to override load_annotations: given the annotation file specified in configs/_base_/datasets/pet_bs64_pil_resize.py, it locates the corresponding image files, and the configured pipeline then processes them into the data the network trains on. The function mainly does the following:

  1. CLASSES: fill in this member variable with the list of class names for your task.
  2. Build the path information the pipeline needs to load each image and turn it into the array data the network consumes. You usually do not decode the image into an np.array yourself, because the pipeline already provides a ready-made transform for loading images (a small difference from torchvision's dataset + dataloader).
  3. Convert each label into an np.array and collect everything into data_infos; see the snippet below.

The loop in load_annotations that assembles data_infos simply pairs each image filename with its numeric label:

for filename, gt_label in self.samples:  # image filename and its numeric label
    info = {'img_prefix': self.data_prefix}
    info['img_info'] = {'filename': filename}
    info['gt_label'] = np.array(gt_label, dtype=np.int64)
    data_infos.append(info)

Note that after implementing your own dataset you must add an import for it in mmcls/datasets/__init__.py, otherwise training will not be able to find it.
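
A sketch of that registration, assuming the class above is saved as mmcls/datasets/pet.py (only the additions are shown; keep the existing entries):

# mmcls/datasets/__init__.py
from .pet import Pet

__all__ = [
    # ... existing dataset names ...
    'Pet',
]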

One-click training script

#!/bin/bash
export PYTHONPATH="${HOME}/mmcv":${PYTHONPATH}
HOME_PATH=${HOME}/mmclassification
CONFIG_FILE=${HOME_PATH}/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_pet.py
IMAGE_FILE=/tmp/pet/val/american_bulldog_34.jpg
CHECKPOINT_FILE=pet_shuffle_20210321-e9ff7d78.pth
mode=$1
if [ "${mode}" = "train" ]; then
    python ${HOME_PATH}/tools/train.py ${CONFIG_FILE} 
else
    python demo/image_demo.py ${IMAGE_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE}
fi 

Train: bash train_pet.sh train
Validation result:

[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 1478/1478, 238.6 task/s, elapsed: 6s, ETA:     0s
2021-03-21 21:58:24,680 - mmcls - INFO - Epoch(val) [128][93]        accuracy_top-1: 76.6576, accuracy_top-5: 95.3992
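
As an alternative to demo/image_demo.py, the trained checkpoint can also be sanity-checked directly through the Python API; a minimal sketch, assuming the mmcls 0.x API (init_model/inference_model) and the paths used in the script above:

from mmcls.apis import inference_model, init_model

config = 'configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_pet.py'
checkpoint = 'pet_shuffle_20210321-e9ff7d78.pth'

model = init_model(config, checkpoint, device='cuda:0')
result = inference_model(model, '/tmp/pet/val/american_bulldog_34.jpg')
print(result)  # prediction dict with the predicted class and score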
