import mmcv
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
# Config and checkpoint of the Faster R-CNN (R50-FPN, 1x schedule, COCO) model.
config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
# Build the model from a config file and a checkpoint file.
# 'cuda:1' addresses the CUDA device with index 1; init_detector also accepts
# other device strings such as 'cuda:0' or 'cpu'.
model = init_detector(config_file, checkpoint_file, device='cuda:1')
# Test a single image and show the results.
img = 'demo.jpg'  # or img = mmcv.imread(img), which will only load it once
result = inference_detector(model, img)
# Visualize the results in a new window.
show_result_pyplot(model, img, result)
新建一个工程,其目录结构为:
实际上,我把mmdetection看成是第三方库(类比numpy、cv2等),所以我把下载下来的GitHub代码解压到了python的工程文件中,即
不过也有人直接把 mmdetection 本身当作工程文件来使用，这同样没有问题。比如 D2Det 的官方代码就是如此：你可以发现其目录结构和 mmdetection 的源码结构是一样的。
该文件夹下有三个文件:
def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device the model is moved to before it is returned,
            e.g. ``'cpu'``, ``'cuda:0'`` or ``'cuda:1'``. Defaults to
            ``'cuda:0'``.

    Returns:
        nn.Module: The constructed detector, in eval mode, with the config
        stored on it as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        # A config file path is read with mmcv.Config.fromfile(), which
        # returns an mmcv.Config object. If `config` already is an mmcv.Config
        # object, nothing else needs to be done.
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    # NOTE(review): presumably cleared so that building the model does not
    # fetch pretrained backbone weights that `checkpoint` would overwrite
    # anyway -- confirm against build_detector.
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        # str() so that a torch.device('cpu') is treated like the string 'cpu'.
        map_loc = 'cpu' if str(device) == 'cpu' else None
        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
        # Not every checkpoint carries a 'meta' entry; treat a missing or
        # empty one like a checkpoint without class names instead of raising
        # KeyError.
        meta = checkpoint.get('meta') or {}
        if 'CLASSES' in meta:
            model.CLASSES = meta['CLASSES']
        else:
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
函数作用、输入参数、输出参数直接见注解。这里给出几个例子:
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# device may be 'cpu', 'cuda:0', 'cuda:1', etc.
def inference_detector(model, img):
    """Run the detector on one image.

    Args:
        model (nn.Module): The loaded detector. It must carry its config as
            ``model.cfg`` (``init_detector`` stores it there).
        img (str or ndarray): Either an image file path or a loaded image.

    Returns:
        The detection result returned by the model's forward pass
        (``return_loss=False, rescale=True``) for ``img``.
    """
    cfg = model.cfg
    first_param = next(model.parameters())
    device = first_param.device  # model device
    # Build the data pipeline: the first step of the configured test pipeline
    # is swapped for LoadImage(); the remaining steps are kept as configured.
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # Prepare data: run the pipeline on the image and batch it as one sample.
    data = dict(img=img)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if first_param.is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        # Use torchvision ops for CPU mode instead
        for m in model.modules():
            # aligned=False is not implemented on CPU,
            # so set use_torchvision on-the-fly
            if isinstance(m, (RoIPool, RoIAlign)) and not m.aligned:
                m.use_torchvision = True
        warnings.warn('We set use_torchvision=True in CPU mode.')
        # just get the actual data from DataContainer
        data['img_metas'] = data['img_metas'][0].data
    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result
# Build the detector from the `model` section of cfg, forwarding the
# train/test settings stored on the same config.
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
以 configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py 配置文件为例，cfg.model、cfg.train_cfg、cfg.test_cfg 均为字典类型，分别与配置文件中的内容相对应。
DETECTORS = Registry('detector')
def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build a detector from ``cfg`` using the ``DETECTORS`` registry.

    ``train_cfg`` and ``test_cfg`` are handed to ``build`` as default
    arguments.
    """
    default_args = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return build(cfg, DETECTORS, default_args)
各个仓库(Registry)在 train.py 开始 import 模块时就自动创建，并且每个仓库中都放了各种各样的小商品。比如在 detectors/__init__.py 中可以看到仓库 DETECTORS 里的小商品(见下方 detectors/__init__.py 的导入列表)。
BACKBONES = Registry('backbone')
# One Registry object per kind of component, each created with its own name.
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
# detectors/__init__.py
from .atss import ATSS
from .base import BaseDetector
from .cascade_rcnn import CascadeRCNN
from .fast_rcnn import FastRCNN
from .faster_rcnn import FasterRCNN
from .fcos import FCOS
from .fovea import FOVEA
from .fsaf import FSAF
from .gfl import GFL
from .grid_rcnn import GridRCNN
from .htc import HybridTaskCascade
from .mask_rcnn import MaskRCNN
from .mask_scoring_rcnn import MaskScoringRCNN
from .nasfcos import NASFCOS
from .point_rend import PointRend
from .reppoints_detector import RepPointsDetector
from .retinanet import RetinaNet
from .rpn import RPN
from .single_stage import SingleStageDetector
from .two_stage import TwoStageDetector
cfg 是字典类型，因此执行 build_from_cfg()：
def build(cfg, registry, default_args=None):
    """Build one module, or an ``nn.Sequential`` of modules, from ``cfg``.

    A single config is passed straight to ``build_from_cfg``. A list of
    configs is built item by item, in order, and wrapped in ``nn.Sequential``.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)
    built = [build_from_cfg(item, registry, default_args) for item in cfg]
    return nn.Sequential(*built)
在 fcos.py 文件中，装饰器 @DETECTORS.register_module() 的作用就是将创建好的小商品 FCOS 放入仓库 DETECTORS 中。
@DETECTORS.register_module()
class FCOS(SingleStageDetector):
    """Implementation of `FCOS <https://arxiv.org/abs/1904.01355>`_"""

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # FCOS adds nothing of its own here: every argument is forwarded
        # unchanged, in the same order, to SingleStageDetector.
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained)
在 register_module() 函数中，你可以看到它的用法：小商品的名字就是在注册时指定的，而注册时的名字默认就是类名，如 FCOS 或 ResNet 等。所以，要创建自己的检测器，你需要构建一个类，把这个类当作小商品，然后将其放进相应的仓库中(即注册)。
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
obj_type='FCOS'
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It must contain the key ``'type'``, which is
            either a name to look up in ``registry`` or a class. All other
            entries are passed to that class as keyword arguments.
        registry: Object used to resolve a string type via
            ``registry.get(name)``; ``registry.name`` appears in the error
            message when the lookup fails.
        default_args (dict, optional): Keyword arguments added only for keys
            that ``cfg`` does not already define.

    Returns:
        The object obtained by calling the resolved class.

    Raises:
        TypeError: If ``cfg`` is not a dict, ``default_args`` is neither a
            dict nor None, or ``cfg['type']`` is neither a str nor a class.
        KeyError: If ``cfg`` has no ``'type'`` key, or the type name is not
            found in ``registry``.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(f'cfg must contain the key "type", but got {cfg}')
    if default_args is not None and not isinstance(default_args, dict):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()  # work on a copy so the caller's config keeps 'type'
    obj_type = args.pop('type')  # e.g. 'FCOS'
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    if default_args is not None:
        # Values given explicitly in cfg win over the defaults.
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)