mmdetection - anchor-based方法训练流程解析

训练流程图

最终会创建一个runner;调用runner.run时,会根据workflow中指定的是train还是val,分别调用runner.py中的train或val函数。
batch_processor

def batch_processor(model, data, train_mode):
    """Run ``model`` on one batch and package the result for the runner.

    Args:
        model: called as ``model(**data)``; its return value is handed to
            ``parse_losses``.
        data (dict): one batch. ``len(data['img'].data)`` is reported as the
            number of samples.
        train_mode (bool): accepted to match the runner's calling convention
            (the runner passes it as a keyword); not used here.

    Returns:
        dict with keys ``loss`` and ``log_vars`` (the two values returned by
        ``parse_losses``) and ``num_samples``.
    """
    loss, log_vars = parse_losses(model(**data))
    num_samples = len(data['img'].data)
    return dict(loss=loss, log_vars=log_vars, num_samples=num_samples)

mmcv/runner/runner.py
train

def train(self, data_loader, **kwargs):
    """Run one training epoch over ``data_loader``.

    Switches the model to training mode and records ``data_loader`` and the
    iteration budget (``self._max_epochs * len(data_loader)``). For every
    batch it:
      * fires the per-iteration hooks;
      * calls ``self.batch_processor(..., train_mode=True, **kwargs)``;
      * feeds any ``log_vars`` into ``self.log_buffer``.
    ``self._iter`` advances once per batch and ``self._epoch`` once at the
    end of the epoch.

    Raises:
        TypeError: if ``batch_processor`` returns something other than a dict.
    """
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(data_loader)

    self.call_hook('before_train_epoch')
    for step, batch in enumerate(data_loader):
        self._inner_iter = step
        self.call_hook('before_train_iter')
        result = self.batch_processor(
            self.model, batch, train_mode=True, **kwargs)
        if not isinstance(result, dict):
            raise TypeError('batch_processor() must return a dict')
        # Only results that carry log variables are pushed to the log buffer.
        if 'log_vars' in result:
            self.log_buffer.update(result['log_vars'], result['num_samples'])
        # Stored on self, presumably so after_train_iter hooks can read it.
        self.outputs = result
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1

val

def val(self, data_loader, **kwargs):
    """Run one validation pass over ``data_loader``.

    Switches the model to eval mode. For every batch it:
      * fires the ``*_val_iter`` hooks;
      * calls ``self.batch_processor(..., train_mode=False, **kwargs)``
        inside ``torch.no_grad()``;
      * feeds any ``log_vars`` into ``self.log_buffer``.
    Neither ``self._iter`` nor ``self._epoch`` is advanced here.

    Raises:
        TypeError: if ``batch_processor`` returns something other than a dict.
    """
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')

    for step, batch in enumerate(data_loader):
        self._inner_iter = step
        self.call_hook('before_val_iter')
        # Only the forward pass runs with autograd disabled.
        with torch.no_grad():
            result = self.batch_processor(
                self.model, batch, train_mode=False, **kwargs)
        if not isinstance(result, dict):
            raise TypeError('batch_processor() must return a dict')
        if 'log_vars' in result:
            self.log_buffer.update(result['log_vars'], result['num_samples'])
        self.outputs = result
        self.call_hook('after_val_iter')

    self.call_hook('after_val_epoch')

validate目前只在_dist_train中用到。

训练时,实际调用的是 losses = model(**data);验证时,实际由hook运行以下代码:

# Inference call issued during validation: autograd is disabled, and
# return_loss=False makes BaseDetector.forward dispatch to forward_test.
# rescale=True travels through **kwargs down to simple_test / aug_test.
with torch.no_grad():
    result = runner.model(
        return_loss=False, rescale=True, **data_gpu)

其中,TwoStageDetector和SingleStageDetector都继承了BaseDetector,在BaseDetector中,forward函数定义如下:

@auto_fp16(apply_to=('img', ))
def forward(self, img, img_meta, return_loss=True, **kwargs):
    """Single entry point for both training and inference.

    When ``return_loss`` is truthy the call goes to ``forward_train``;
    otherwise it goes to ``forward_test``. ``img``, ``img_meta`` and any
    extra keyword arguments are passed through unchanged. The ``auto_fp16``
    decorator is configured with ``apply_to=('img',)``; its exact
    casting behaviour is defined elsewhere (presumably fp16 handling of
    ``img`` — confirm in mmdet's fp16 utilities).
    """
    handler = self.forward_train if return_loss else self.forward_test
    return handler(img, img_meta, **kwargs)

对于forward_test,其代码如下:

def forward_test(self, imgs, img_metas, **kwargs):
    """Validate test-time inputs and dispatch to simple or augmented testing.

    Args:
        imgs (list): one batched image tensor per test-time augmentation.
            The first batch must currently contain exactly one image
            (``imgs[0].size(0) == 1``).
        img_metas (list): image metas, one entry per augmentation, parallel
            to ``imgs``.
        **kwargs: forwarded unchanged to ``simple_test`` / ``aug_test``
            (for example ``rescale``).

    Returns:
        The result of ``simple_test`` when there is exactly one augmentation,
        otherwise the result of ``aug_test``.

    Raises:
        TypeError: if ``imgs`` or ``img_metas`` is not a list.
        ValueError: if the two lists differ in length, if no augmentation is
            given, or if the first batch holds more than one image.
    """
    for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
        if not isinstance(var, list):
            raise TypeError('{} must be a list, but got {}'.format(
                name, type(var)))

    num_augs = len(imgs)
    if num_augs != len(img_metas):
        raise ValueError(
            'num of augmentations ({}) != num of image meta ({})'.format(
                len(imgs), len(img_metas)))
    if num_augs == 0:
        # Without this guard, imgs[0] below fails with an opaque IndexError.
        raise ValueError('imgs must contain at least one augmentation')

    # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
    # An explicit raise (rather than ``assert``) keeps this check active
    # under ``python -O``. Otherwise a larger batch would silently reach
    # simple_test; e.g. SingleStageDetector.simple_test returns only
    # bbox_results[0].
    imgs_per_gpu = imgs[0].size(0)
    if imgs_per_gpu != 1:
        raise ValueError(
            'only 1 image per GPU is supported in testing, but got {}'.format(
                imgs_per_gpu))

    if num_augs == 1:
        return self.simple_test(imgs[0], img_metas[0], **kwargs)
    return self.aug_test(imgs, img_metas, **kwargs)

由上可以看出,子类需要实现simple_test和aug_test函数。
对于一个检测模型(一阶或者二阶),在其class中,需要重写以下函数:

  • forward_train
  • simple_test
  • aug_test # 非必须

下面以retinanet举个例子,在retinanet的config文件中,model的type是RetinaNet,在mmdet/models/detectors/retinanet.py中,定义了RetinaNet,它的父类是SingleStageDetector,定义在mmdet/models/detectors/single_stage.py中,三个重要函数的代码如下:

def forward_train(self,
                  img,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None):
    """Compute the training losses for one batch.

    The method proceeds in three steps:
      1. extract backbone/neck features with ``self.extract_feat``;
      2. run ``self.bbox_head`` on those features;
      3. call ``self.bbox_head.loss``.
    The loss call receives the head outputs, then ``gt_bboxes``,
    ``gt_labels``, ``img_metas`` and ``self.train_cfg`` as positional
    arguments, plus ``gt_bboxes_ignore`` as a keyword.

    Returns:
        Whatever ``self.bbox_head.loss`` returns.
    """
    head_outs = self.bbox_head(self.extract_feat(img))
    targets = (gt_bboxes, gt_labels, img_metas, self.train_cfg)
    return self.bbox_head.loss(
        *(head_outs + targets), gt_bboxes_ignore=gt_bboxes_ignore)

def simple_test(self, img, img_meta, rescale=False):
    """Run single-scale inference and return the first image's results.

    The head outputs are passed to ``self.bbox_head.get_bboxes`` together
    with ``img_meta``, ``self.test_cfg`` and ``rescale``. Each resulting
    ``(det_bboxes, det_labels)`` pair is converted with ``bbox2result``.
    Only the entry for the first image is returned.
    """
    feats = self.extract_feat(img)
    head_outs = self.bbox_head(feats)
    bbox_list = self.bbox_head.get_bboxes(
        *(head_outs + (img_meta, self.test_cfg, rescale)))

    bbox_results = []
    for det_bboxes, det_labels in bbox_list:
        bbox_results.append(
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes))
    return bbox_results[0]

def aug_test(self, imgs, img_metas, rescale=False):
    """Augmented (multi-view) testing is not implemented for this detector.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

由上可知,计算loss的函数是在head中定义的,RetinaHead定义在mmdet/models/anchor_heads/retina_head.py中,RetinaHead三个关键函数的代码如下:

def _init_layers(self):
    """Build the RetinaNet head layers.

    Creates two towers of ``self.stacked_convs`` ConvModules: one for
    classification, one for box regression. Each conv uses a 3x3 kernel,
    stride 1 and padding 1. The first layer of each tower maps
    ``in_channels -> feat_channels``; later layers map
    ``feat_channels -> feat_channels``.

    Two final 3x3 convs produce the outputs:
      * ``retina_cls``: ``num_anchors * cls_out_channels`` channels;
      * ``retina_reg``: ``num_anchors * 4`` channels.
    """
    self.relu = nn.ReLU(inplace=True)
    self.cls_convs = nn.ModuleList()
    self.reg_convs = nn.ModuleList()

    def _tower_conv(in_chn):
        return ConvModule(
            in_chn,
            self.feat_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

    for layer_idx in range(self.stacked_convs):
        in_chn = self.feat_channels if layer_idx else self.in_channels
        # cls and reg layers are created alternately, keeping the original
        # module construction order.
        self.cls_convs.append(_tower_conv(in_chn))
        self.reg_convs.append(_tower_conv(in_chn))

    self.retina_cls = nn.Conv2d(
        self.feat_channels,
        self.num_anchors * self.cls_out_channels,
        3,
        padding=1)
    self.retina_reg = nn.Conv2d(
        self.feat_channels, self.num_anchors * 4, 3, padding=1)

def init_weights(self):
    """Initialise the head's conv weights.

    All tower convs (classification first, then regression) and both output
    convs get ``normal_init`` with ``std=0.01``. The classification output
    conv also gets its bias set from ``bias_init_with_prob(0.01)``.
    """
    for tower in (self.cls_convs, self.reg_convs):
        for conv_module in tower:
            normal_init(conv_module.conv, std=0.01)
    # Per its name, a bias derived from prior probability 0.01; the exact
    # formula lives in bias_init_with_prob.
    cls_prior_bias = bias_init_with_prob(0.01)
    normal_init(self.retina_cls, std=0.01, bias=cls_prior_bias)
    normal_init(self.retina_reg, std=0.01)

def forward_single(self, x):
    """Run the head on a single feature map.

    ``x`` goes through the classification tower and, separately, through
    the regression tower. The tower outputs are then projected by
    ``retina_cls`` and ``retina_reg`` respectively.

    Returns:
        tuple: ``(cls_score, bbox_pred)``.
    """
    cls_feat = reg_feat = x
    for layer in self.cls_convs:
        cls_feat = layer(cls_feat)
    for layer in self.reg_convs:
        reg_feat = layer(reg_feat)
    return self.retina_cls(cls_feat), self.retina_reg(reg_feat)

其中,_init_layers创建head的结构,init_weights对conv的weight和bias做初始化,forward_single对单个层级的特征图做前向计算,输出分类得分和检测框回归的预测结果。
forward
每个具体方法在其对应的head中定义forward_single(处理单个层级的特征图),最后由anchor_head.py中的forward函数通过multi_apply把各层级的结果组装起来。

from six.moves import map, zip
def multi_apply(func, *args, **kwargs):
    """Apply ``func`` across ``args`` element-wise and transpose the results.

    ``func`` is called once per position of the iterables in ``args`` (as
    with ``map``), with ``kwargs`` bound to every call. Each call is expected
    to return a tuple. The per-call tuples are regrouped so that output ``k``
    collects the ``k``-th return value of every call; ``zip`` truncates to
    the shortest tuple.

    Example (head forward over feature levels)::

        multi_apply(forward_single, feats)
        # per-call results: [(s1_cls, s1_bbox), (s2_cls, s2_bbox), ...]
        # returned:         ([s1_cls, s2_cls, ...], [s1_bbox, s2_bbox, ...])

    Returns:
        tuple of lists, one list per value returned by ``func``. The tuple
        is empty when ``args`` yields no elements.
    """
    # Imported locally so this helper is self-contained.
    from functools import partial

    # Bind the keyword arguments once; map() then supplies the positional ones.
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    # zip(*...) turns per-call tuples into per-output tuples, and
    # map(list, ...) converts each of those into a list.
    return tuple(map(list, zip(*map_results)))


def forward(self, feats):
    """Apply ``self.forward_single`` to every feature level.

    Args:
        feats (list): one feature map per level/stride (expected NCHW
            tensors).

    Returns:
        tuple of lists as produced by ``multi_apply``. For a head whose
        ``forward_single`` returns ``(cls_score, bbox_pred)``, this is
        ``([cls per level], [bbox per level])``.
    """
    per_level_outputs = multi_apply(self.forward_single, feats)
    return per_level_outputs

def forward_single(self, x):
    """Apply the head to a single pyramid level.

    Args:
        x: one element of the ``feats`` list given to ``forward``.

    Returns:
        tuple: ``(cls_score, bbox_pred)``, i.e. the outputs of
        ``retina_cls`` and ``retina_reg`` applied to the classification and
        regression tower features.
    """
    def run_tower(layers, feat):
        # Feed the features through each layer of one tower in order.
        for layer in layers:
            feat = layer(feat)
        return feat

    cls_feat = run_tower(self.cls_convs, x)
    reg_feat = run_tower(self.reg_convs, x)
    cls_score = self.retina_cls(cls_feat)
    bbox_pred = self.retina_reg(reg_feat)
    return cls_score, bbox_pred

loss

你可能感兴趣的:(mmdetection学习笔记)