训练自定义模型MMDetection的非侵入式配置

       刚开始用MMDetection框架,希望使用自己的模型和训练数据集,因此粗略看了官方教程和网上的一些资料。然而,这些资料大多数都需要修改MMDetection安装目录下的文件,这显然不是一种很好的做法。下面以目标检测为例,我通过自己摸索了一种非侵入式的模型和MMDetection数据配置方法,简要总结如下。

      1、安装和配置MMDetection

      通过源码安装MMDetection,具体安装方式参考官网GET STARTED — MMDetection 3.0.0 documentation。

      为方便在项目中引用,我们创建一个软链接,将mmdetection的根目录映射为当前我们项目工作目录下的一个子目录,假设为“mm/mmdetection/‘。

      这里安装的时mmdet 3.0.0版本。
      2、准备自定义数据集

      官方推荐COCO数据集,但我发现COCO的描述文件将不同的图片耦合到了一起,且当图片比较多的时候,这个文件很大,不利于调试和修改。因此我改用VOC格式,每个图片使用一个描述文件。假设数据目录为’data‘。

      假设目标有7种类型:'Car', 'Bus', 'Truck', 'Motor', 'Bike', 'Rider', 'Person'

      本地项目的目录结构如下:

├─data
│  ├─readme.txt
│  └─VOC2007
│      ├─Annotations
│      ├─ImageSets
│      │  └─Main
│      └─JPEGImages
├─mm
│  ├─config
│  ├─mmdetection
│  └─models

      3、创建数据集配置文件"mm/config/da_dataset.py"


_base_ = ['../mmdetection/configs/_base_/datasets/voc0712.py']

batch_size = 8
num_workers = 4
data_root = 'data/'
imageset_dir = 'VOC2007/ImageSets/Main/'

metainfo = {
    'classes': ('Car', 'Bus', 'Truck', 'Motor', 'Bike', 'Rider', 'Person'),
    'palette': [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192),
                (197, 226, 255), (0, 60, 100), (0, 0, 142)]
}

backend_args = None

train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', scale=(1000, 600), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackDetInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='Resize', scale=(1000, 600), keep_ratio=True),
    # avoid bboxes being resized
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=dict(type='AspectRatioBatchSampler'),
    dataset=dict(
        _delete_=True,
        type={{_base_.dataset_type}},
        metainfo=metainfo,
        data_root=data_root,
        ann_file=imageset_dir + 'train.txt',
        #data_prefix=dict(img=data_root + 'VOC2007/'),
        data_prefix=dict(sub_data_root='VOC2007/'),
        filter_cfg=dict(
            filter_empty_gt=True, min_size=32, bbox_min_size=32),
        pipeline=train_pipeline,
        backend_args=backend_args
    )
)

val_dataloader = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        data_root=data_root,
        test_mode=True,
        metainfo=metainfo,
        ann_file=imageset_dir + 'val.txt',
        #data_prefix=dict(img=data_root + 'VOC2007/'),
        data_prefix=dict(sub_data_root='VOC2007/'),
        pipeline=test_pipeline
   )
)

test_dataloader = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        data_root=data_root,
        test_mode=True,
        metainfo=metainfo,
        ann_file=imageset_dir + 'test.txt',
        data_prefix=dict(img=data_root + 'VOC2007/'),
        pipeline=test_pipeline
   )
)

# Pascal VOC2007 uses `11points` as default evaluate mode, while PASCAL
# VOC2012 defaults to use 'area'.
val_evaluator = dict(type='VOCMetric', metric='mAP', eval_mode='11points')

test_evaluator = val_evaluator


      4、创建自定义模型类,我创建一个Faster-RCNN的子类,放在mm/models/env_models.py中,代码如下:

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmdet.models.detectors.faster_rcnn import FasterRCNN

@MODELS.register_module()
class EnvFaterRCNN(FasterRCNN):
    def __init__(self,
                 backbone: ConfigType,
                 rpn_head: ConfigType,
                 roi_head: ConfigType,
                 train_cfg: ConfigType,
                 test_cfg: ConfigType,
                 neck: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            data_preprocessor=data_preprocessor)

    4、创建基础运行配置文件'mm/config/da_base.py'

_base_ = [
    'da_dataset.py',
    '../mmdetection/configs/_base_/schedules/schedule_1x.py',  # 训练策略
    '../mmdetection/configs/_base_/default_runtime.py'         # 默认运行设置
]

checkpoint_config = dict(interval=10)

log_config = dict(
    interval=100
)

   5、创建特定模型相关的配置文件”mm/config/da_faster-rcnn.py“

_base_ = [
    '../mmdetection/configs/_base_/models/faster-rcnn_r50_fpn.py',
    'da_base.py'            # 默认运行设置
]

custom_imports = dict(
    imports=['env_models'],
    #imports=['mmdet.models.detectors.env_models'], 
    allow_failed_imports=False
)

model = dict(
    type="EnvFaterRCNN",
    roi_head=dict(
        # head 中的 num_classes 以匹配数据集中的类别数
        bbox_head=dict(num_classes=7)
    )
)

max_epochs = 100

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)

param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[8, 11, 50, 70, 90, 99],
        gamma=0.1)
]

   6、训练,在安装了python的命令行环境下执行:

set PYTHONPATH=./mm/models
python mm/mmdetection/tools/train.py mm/config/da_faster-rcnn.py

   上面第一句话把自定义模型源码目录加入到python运行环境中;第二句话执行真正的训练过程。

   至此,已经全部配置完成。注意,我们并没有改动MMDetection的任何源码或目录。

    7、解决问题

    MMDetection的VOC评估数据的类别显示不对,有bug,顺手改了,位于’mmdetection/mmdet/evaluation/metrics/voc_metric.py‘第127行。改为如下形式:

            if dataset_type in ['VOC2007', 'VOC2012']:
#                dataset_name = 'voc'
                dataset_name = self.dataset_meta['classes']

你可能感兴趣的:(Python编程,目标检测,python,mmdetection,非侵入式配置)