This article is a set of notes from working with mmdetection, recording my personal understanding of the code and its design philosophy.
Training is done via tools/train.py. Adding the following code lets you run it in debug mode:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4"  # pin the run to a single GPU
args = ['./configs/cascade_mask_rcnn_r101_fpn_1x.py',
        '--gpus', '1',
        '--work_dir', 'cascade_mask_rcnn_r101_fpn_1x']
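For this list to take effect it has to reach argparse. Assuming the stock tools/train.py, whose parse_args() builds an ArgumentParser and calls parser.parse_args() (an assumption about your local copy), one small tweak lets the script run under a debugger with no command-line arguments:

# hypothetical tweak inside tools/train.py's parse_args():
args = parser.parse_args(args)  # parse the hard-coded list instead of sys.argv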
model = build_detector(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
Here, cfg is the config file and DETECTORS is a global object created in models/registry; it is an instance of the Registry class. A Registry holds a _name and a _module_dict attribute; at creation time only _name is set, to a string such as 'detector'. Every detector-related class is preceded by the @DETECTORS.register_module decorator, which registers the class, keyed by its name (the class's __name__ attribute), in DETECTORS's _module_dict.
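The registry pattern itself is small enough to sketch. The following is a simplified reimplementation to illustrate the mechanism, not mmdet's actual code:

class Registry(object):
    def __init__(self, name):
        self._name = name          # e.g. 'detector'
        self._module_dict = {}     # maps class name -> class object

    def get(self, key):
        return self._module_dict.get(key, None)

    def register_module(self, cls):
        # used as a decorator: stores the class under its __name__
        self._module_dict[cls.__name__] = cls
        return cls

DETECTORS = Registry('detector')

@DETECTORS.register_module
class CascadeRCNN(object):
    def __init__(self, num_stages=3, **kwargs):
        self.num_stages = num_stages

print(DETECTORS.get('CascadeRCNN'))  # <class '__main__.CascadeRCNN'>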
build calls build_from_cfg, which first pops the object type obj_type out of cfg, uses get to fetch the corresponding class from the registry (the Registry object), and uses inspect to verify that obj_type is indeed a class. It then instantiates obj_type (the class) with args (i.e. the remaining cfg entries) as arguments.
import inspect

import mmcv


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')  # the object's name, e.g. CascadeRCNN
    if mmcv.is_str(obj_type):
        obj_type = registry.get(obj_type)
        if obj_type is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif not inspect.isclass(obj_type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)  # instantiate and return
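To make the mapping concrete, here is an illustration of what happens when a config dict goes through build_from_cfg; the field values are made up for the example:

cfg = dict(type='CascadeRCNN', num_stages=3)  # illustrative fields only
# pops 'type', looks up DETECTORS._module_dict['CascadeRCNN'],
# then calls CascadeRCNN(num_stages=3, train_cfg=None, test_cfg=None)
model = build_from_cfg(cfg, DETECTORS,
                       dict(train_cfg=None, test_cfg=None))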
The dataset is built with the same pattern: build calls build_from_cfg and instantiates whatever the cfg describes, the only difference being that this time the cfg is the dataset's cfg.
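For example, assuming a DATASETS registry analogous to DETECTORS (the exact registry name depends on the mmdet version), a dataset cfg goes through the same code path; the field values here are placeholders:

data_cfg = dict(
    type='CocoDataset',  # class name looked up in the dataset registry
    ann_file='data/coco/annotations/instances_train2017.json',
    img_prefix='data/coco/train2017/')
dataset = build_from_cfg(data_cfg, DATASETS)  # -> a CocoDataset instance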
train_detector(
    model,
    train_dataset,
    cfg,
    distributed=distributed,
    validate=args.validate,
    logger=logger)
In the non-distributed case, train_detector calls _non_dist_train, which wraps the model for data parallelism and builds the data_loader, the optimizer, and the runner.
################## the _non_dist_train part
# prepare data loaders
data_loaders = [
    build_dataloader(
        dataset,
        cfg.data.imgs_per_gpu,
        cfg.data.workers_per_gpu,
        cfg.gpus,
        dist=False)
]
################## the build_dataloader function
from functools import partial

from torch.utils.data import DataLoader

# GroupSampler, DistributedGroupSampler, DistributedSampler, collate and
# get_dist_info come from mmdet/mmcv; their imports are omitted for brevity


def build_dataloader(dataset,
                     imgs_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     **kwargs):
    shuffle = kwargs.get('shuffle', True)
    if dist:
        rank, world_size = get_dist_info()
        if shuffle:
            sampler = DistributedGroupSampler(dataset, imgs_per_gpu,
                                              world_size, rank)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False)
        batch_size = imgs_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None
        batch_size = num_gpus * imgs_per_gpu
        num_workers = num_gpus * workers_per_gpu

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
        pin_memory=False,
        **kwargs)

    return data_loader
build_dataloader mainly creates two objects, sampler and collate (the latter via functools.partial): the former is the sampler, which draws indices, and the latter is the collator, which assembles samples into output batches. After that, the stock PyTorch DataLoader does the rest. The sampler variants account for parallel operation. Besides supporting batch construction for Sequence and Mapping inputs, collate, more importantly, can batch data of type DataContainer, a new type introduced in mmdet that supports multiple kinds of data.
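DataLoader's collate_fn receives only the list of samples, so partial is used to pre-bind samples_per_gpu. A toy illustration of that binding pattern (not mmcv's actual collate, which additionally handles DataContainer, Sequence and Mapping):

from functools import partial

def collate(batch, samples_per_gpu=1):
    # toy stand-in: split the batch into per-GPU chunks
    return [batch[i:i + samples_per_gpu]
            for i in range(0, len(batch), samples_per_gpu)]

collate_fn = partial(collate, samples_per_gpu=2)
print(collate_fn([1, 2, 3, 4]))  # [[1, 2], [3, 4]]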
runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                cfg.log_level)
############ the definition of batch_processor
def batch_processor(model, data, train_mode):
    losses = model(**data)
    loss, log_vars = parse_losses(losses)
    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
    return outputs
batch_processor calls the model to obtain the losses (the model's forward returns losses rather than the network's raw outputs), then applies some light post-processing to them via parse_losses.
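Roughly, parse_losses averages each loss term and sums every entry whose key contains 'loss' into a single scalar for backprop. A simplified sketch of that behavior (details may differ from the mmdet version you have):

from collections import OrderedDict

import torch

def parse_losses(losses):
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(v.mean() for v in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))
    # only keys containing 'loss' contribute to the backprop target
    loss = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = loss
    return loss, log_vars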
The Runner's initialization is essentially just initialization of the model, optimizer, work_dir, and so on.
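In outline it looks something like this (a simplified sketch, not mmcv's actual code, which also sets up logging and hook registration):

import logging

class Runner(object):
    def __init__(self, model, batch_processor, optimizer=None,
                 work_dir=None, log_level=logging.INFO):
        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer  # mmcv also accepts an optimizer cfg dict
        self.work_dir = work_dir
        self._hooks = []            # filled later by register_hook()
        self._epoch = 0
        self._iter = 0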
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
# excerpt of run(); some parts omitted
def run(self, data_loaders, workflow, max_epochs, **kwargs):
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    self._max_epochs = max_epochs
    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')

    while self.epoch < max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):
                epoch_runner = getattr(self, mode)
            elif callable(mode):  # custom train()
                epoch_runner = mode
            for _ in range(epochs):
                if mode == 'train' and self.epoch >= max_epochs:
                    return
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
workflow describes the training workflow, e.g. [('train', 2), ('val', 1)]. Inside run, getattr retrieves the epoch_runner, normally runner.train or runner.val. The former is the usual training procedure: it first calls self.model.train() to leave eval mode, then runs an ordinary training epoch, invoking call_hook at several points along the way.
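To show where those hook calls land, here is a condensed sketch of what a train epoch runner looks like (simplified from mmcv's Runner.train; details may differ in your version):

def train(self, data_loader, **kwargs):
    self.model.train()  # leave eval mode
    self.mode = 'train'
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(data_loader):
        self.call_hook('before_train_iter')
        outputs = self.batch_processor(self.model, data_batch,
                                       train_mode=True)
        self.outputs = outputs  # hooks (e.g. the optimizer hook) read this
        self.call_hook('after_train_iter')
        self._iter += 1
    self.call_hook('after_train_epoch')
    self._epoch += 1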
def call_hook(self, fn_name):
    for hook in self._hooks:
        getattr(hook, fn_name)(self)
call_hook simply invokes the method named fn_name on every registered hook in turn, passing the runner itself, so each hook can read runner state or modify it (logging, stepping the optimizer, adjusting the learning rate, and so on).
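A minimal hypothetical hook makes the mechanism concrete. In practice a hook should subclass mmcv's Hook base class, which defines empty before_run/after_run/before_epoch/... methods so that call_hook never hits a missing attribute; this sketch only shows the two callbacks it cares about:

import time

class TimerHook(object):  # hypothetical; subclass mmcv's Hook in real code
    def before_run(self, runner):
        self._start = time.time()

    def after_run(self, runner):
        print('total wall time: {:.1f}s'.format(time.time() - self._start))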
The runner controls the training, validation, and testing process.
The dataloader is responsible for loading the data.
Operations such as anchor generation and anchor matching are all hidden inside the model, which is further divided into:
anchor-related heads: anchor_head
backbones: backbones
RoI-related heads: bbox_heads
the detector itself: detectors
loss functions: losses
mask-related heads: mask_heads
modules for further feature extraction on top of the backbone (e.g. FPN): necks
plug-ins such as attention mechanisms: plugins
RoI extractors: roi_extractors
shared heads (e.g. a ResNet res5 layer used as a shared RoI head in some configs): shared_heads
Everything concerned with technical detail lives in these model components, and they in turn call functions from core. The config file wires them all together, as sketched below.
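As an illustration of how these pieces meet the registry machinery, a heavily abbreviated model cfg might look like this (field values are placeholders, not a working config):

model = dict(
    type='CascadeRCNN',                       # resolved via DETECTORS
    backbone=dict(type='ResNet', depth=101),  # resolved via BACKBONES
    neck=dict(type='FPN', out_channels=256),  # resolved via NECKS
    # ... bbox_head, mask_head etc. are resolved via their own registries
)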