mmdetection的github链接:https://github.com/open-mmlab/mmdetection
mmdetection的官方文档:https://mmdetection.readthedocs.io/en/latest/
mmdetection的预训练权重:https://mmdetection.readthedocs.io
在这一部分,将会介绍训练一个detector的主要的细节:
依然是非常传统的使用,我们使用Dataset
和DataLoader
来使用多线程进行数据的加载。 Dataset
将会返回一个与模型前向方法相对应的数据字典。因为在目标检测中数据可能不是相同的尺寸(例如图形尺寸,bbox的尺寸等),我们引入了一个新的DataContainer
类型来帮助收集与分类这些不同尺寸的数据。更多的细节请看here。
数据的传输通道与数据集的定义分离开,通常数据集定义如何处理annotations,数据的pipeline定义准备数据字典的所有的步骤。一个pipeline由一系列的运算组成,每一次运算都会将一个字典作为输入,而且为下一次的转换运算输出一个字典。
这些运算被分为数据读取,数据预处理,格式化,测试时增强四类
Here is an pipeline example for Faster R-CNN.
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
For each operation, we list the related dict fields that are added/updated/removed.
LoadImageFromFile
- add: img, img_shape, ori_shape
LoadAnnotations
- add: gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks, gt_semantic_seg, bbox_fields, mask_fields
LoadProposals
- add: proposals
Resize
- add: scale, scale_idx, pad_shape, scale_factor, keep_ratio - update: img, img_shape, bbox_fields, mask_fields, *seg_fields
RandomFlip
- add: flip - update: img, bbox_fields, mask_fields, *seg_fields
Pad
- add: pad_fixed_size, pad_size_divisor - update: img, pad_shape, mask_fields, seg_fields
RandomCrop
- update: img, pad_shape, gt_bboxes, gt_labels, gt_masks, *bbox_fields
Normalize
- add: img_norm_cfg - update: img
SegRescale
- update: gt_semantic_seg
PhotoMetricDistortion
- update: img
Expand
- update: img, gt_bboxes
MinIoURandomCrop
- update: img, gt_bboxes, gt_labels
Corrupt
- update: img
ToTensor
- update: specified by keys
.
ImageToTensor
- update: specified by keys
.
Transpose
- update: specified by keys
.
ToDataContainer
- update: specified by fields
.
DefaultFormatBundle
- update: img, proposals, gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks, gt_semantic_seg
Collect
- add: img_meta (the keys of img_meta is specified by meta_keys
) - remove: all other keys except for those specified by keys
MultiScaleFlipAug
在框架中,模型基本被分为四类:
我们利用上边的组件,也写了一些常规的感知器pipeline,例如单步法与两步法的感知器。
利用基本的pipeline,模型架构可以通过config文件很简单的自定义。
如果想要来利用一些新的组件,例如path aggregation FPN structure,Path Aggregation Network for Instance Segmentation,我们需要做两个工作:
create a new file in mmdet/models/necks/pafpn.py.
from ..registry import NECKS
@NECKS.register
class PAFPN(nn.Module):
def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=0,
end_level=-1,
add_extra_convs=False):
pass
def forward(self, inputs):
# implementation is ignored
pass
mmdet/models/necks/__init__.py
.from .pafpn import PAFPN
from
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5)
to
neck=dict(
type='PAFPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5)
我们将会开源更多的组件,例如backbones, necks, heads。
为了写一个新的检测pipeline,我们需要继承自BaseDetector
,在这个类中我么们定义了如下的方法:
extract_feat()
: given an image batch of shape (n, c, h, w), extract the feature map(s).forward_train()
: forward method of the training modesimple_test()
: single scale testing without augmentationaug_test()
: testing with augmentation (multi-scale, flip, etc.)TwoStageDetector是一个展示了如何使用的很简单的例子
我们采用了分离式的训练,既可以在单机上训练,也可以多机并行训练。 假设有一个8 gpu的服务器,我们将会启动8进程,每个进程运行在每个gpu上。
每个进程都保持一个孤立的模型,数据提取器,优化器 模型的参数仅仅在开始的时候同步一次。 在一次前向传播和反向传播中,在所有gpu上的梯度都会下降,优化器将会更细所有的参数,因为梯度是同时下降,模型的参数将会保持一致,
For more information, please refer to our technical report.
详细中文文档及使用细代码细节:
小哲lxz:mmdetection中文文档