mmdetection学习系列(1)——SSD网络

1. 概述

本文是本人自学mmdetection的第一篇文章,因为最近一段时间在做目标检测相关的内容,为了更好地研究领域内相关知识,特意花了不少时间熟悉mmdetection框架(https://github.com/open-mmlab/mmdetection)。边啃代码的同时边通过知乎openMMlab社区(https://www.zhihu.com/people/openmmlab)来了解其框架结构。刚开始看时由于对目标检测的整个流程还不算十分熟悉,而由于mmdetection是可以适用于多种网络的,因此其编写的代码是高度抽象化的,导致初次看时十分难以理解,曾经多次想要放弃。但是后来通过Visual Studio的debug模式逐个模块拆解了解SSD网络后,对于其他各种网络也更加容易上手了。
openMMLab设计精妙,不可能在短时间内熟悉各种模块,我学习的目的是为了尽快熟悉目标检测的训练和推理流程,因此只针对SSD的训练和推理流程作大致讲解。

2. 训练流程

2.1 训练准备

mmdetection安装好后,可以用

python tools/train.py --config /data1/lujd/object_detection/mmdetection-2.11.0/configs/ssd/ssd300_coco.py

开始训练。我是用Visual Studio中的debug模式对代码进行解析的,详细设置可以自行查找。
在train.py中代码通过导入cfg文件做了很多初始化工作,包括设置工作路径、日志记录、数据集构建、模型初始化等等。模型的初始化是在代码第158行进行的。

model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

整个SSD目标检测模型包括backbone、neck和head网络,其中每一个小网络在构建时还有很多具体细节。这些细节大多会在模型训练时得到体现,因此为了更好地了解网络是如何通过输入图片与label得出loss,我们将跳过模型初始化部分,待训练过程中碰到了相关部分再回来讲述。
在train.py的最后,代码通过train_detector函数,输入模型、数据集以及配置文件对模型进行训练

train_detector(
       model,
       datasets,
       cfg,
       distributed=distributed,
       validate=(not args.no_validate),
       timestamp=timestamp,
       meta=meta)

train_detector是定义在mmdet/apis/train.py中,查看代码发现,其主要是用于构建dataloader、optimizer以及mmlab特有的runner组件,runner的介绍可以查看https://zhuanlan.zhihu.com/p/355272459,简单来说就是一个负责训练全流程的一个工作类。因此下一步我们将进入文件最后一行

runner.run(data_loaders, cfg.workflow)

对于不同的网络,将有可能调用不同的runner。在SSD的默认设置中,runner.run函数将调用mmcv/runner/epoch_based_runner.py中的EpochBasedRunner类中的run函数。
函数做了一些训练准备工作后,可以看到在最后一段

for _ in range(epochs):
                   if mode == 'train' and self.epoch >= self._max_epochs:
                       break
                   epoch_runner(data_loaders[i], **kwargs)

对于每一个epoch,都会调用一次epoch_runner函数,其中此函数是在前面的

epoch_runner = getattr(self, mode)

查看发现此处意为调用train函数,因此下一步将进入此函数。
train函数位于EpochBasedRunner类中,做的工作不多,核心代码为

self.run_iter(data_batch, train_mode=True, **kwargs)

对于dataloader中的每一个batch,将执行此函数。查看发现在run_iter函数中,核心代码为

outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)

train_step函数定义在mmcv/parallel/data_parallel.py中,在最后一行开始进入到模型的前向计算

return self.module.train_step(*inputs[0], **kwargs[0])

train_step函数定义在mmdet/models/detectors/base.py中,其位于BaseDetector类中,通过调用自身的forward函数,输入初始化后的图片数据,得到返回的loss值,再将loss值进行整理后,即得到返回的output字典。
查看BaseDetector类的forward函数发现,其根据目前状态是训练还是推理将调用不同的函数,目前模型正在训练状态中,因此将调用命令

return self.forward_train(img, img_metas, **kwargs)

此函数在mmdet/models/detectors/single_stage.py中的SingleStageDetector类中调用,其中只执行3行命令

super(SingleStageDetector, self).forward_train(img, img_metas)#1
x = self.extract_feat(img)#2
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)#3

首先将调用父类的forward_train函数,将图片的大小信息添加到img_metas中。img_metas是一个字典列表,包含了每张图片的各种元信息,如大小、路径、初始化操作等。
之后x = self.extract_feat(img)是调用模型的backbone及neck(如果有),将图片转化为feature map。之后将结合grountruth的bounding box和label计算得出损失。文章的下一步将重点讲述这个部分。

2.2 Backbone

SSD默认的BackBone是VGG16,其前向推理定义在mmdet/models/backbones/ssd_vgg.py的forward函数中。将输入x不断通过self.features和self.extra层进行计算获取feature map。其中self.features包含的Sequences即为VGG16模块,对于此模块及默认输入的3*300*300的图片,进行3次下采样后默认输出的是[512,38,38]和[1024,19,19]的feature map。
self.extra模块由7个Conv模块组成,其中的第1、3、5、7个模块组成输出结果,feature map尺寸分别为[512,10,10],[256,5,5],[256,3,3],[256,1,1]。不同的feature map尺寸用于检测不同尺度的目标。

2.3 Head

接下来是模型Head模块的前向进行

losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)#3

SSD默认的Head模块是BaseDenseHead,位于mmdet/models/dense_heads/base_dense_head.py,fowrard_train的定义为

def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        outs = self(x)
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        if proposal_cfg is None:
            return losses
        else:
            proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
            return losses, proposal_list

函数主要工作为,对经过backbone后的的6个feature map进行前向传播,得出outs,然后再通过outs,ground truth box,ground truth label在self.loss函数中得出损失函数值。如果是two-stage算法的话,还会通过get_bboxes,根据输出的outs张量得出feature map对应的proposal_list。two-stage算法将在下一个系列中写一个关于faster-rcnn和cascade-rcnn。

2.3.1 前向传播

self(x)调用的是SSDHead类的forward函数,定义在mmdet/models/dense_heads/ssd_head.py中,

def forward(self, feats):
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

输入参数feats就是图片输入backbone后得出的feature map列表,每一个feature的尺寸为

for i in range(6):
    print(feats[i].shape[1:])
    
torch.Size([512, 38, 38])
torch.Size([1024, 19, 19])
torch.Size([512, 10, 10])
torch.Size([256, 5, 5])
torch.Size([256, 3, 3])
torch.Size([256, 1, 1])

self.reg_convs和self.cls_convs分别是Head类初始化时定义好的6个Module。分别对每一个feature map进行运算,得出每一个cell的方框坐标和类别坐标。

print(self.reg_convs)
ModuleList(
  (0): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (2): Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
print(self.cls_convs)
ModuleList(
  (0): Conv2d(512, 324, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): Conv2d(1024, 486, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (2): Conv2d(512, 486, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): Conv2d(256, 486, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): Conv2d(256, 324, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): Conv2d(256, 324, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

在reg_convs中,每一个方框对应4个坐标,而由于每一个feature map中的cell是对应4个或者6个不同尺寸的anchors,因此输出的维度为16或者24。对于cls_convs模块,在coco数据集上每一个anchor预测的是80+1(背景)类,因此输出通道为324或者486。

2.3.2 损失

SSDHead的损失函数定义为

def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.anchor_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        num_images = len(img_metas)
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list,
                                      -1).view(num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list,
                                     -2).view(num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list,
                                     -2).view(num_images, -1, 4)

        # concat all level anchors to a single tensor
        all_anchors = []
        for i in range(num_images):
            all_anchors.append(torch.cat(anchor_list[i]))

        # check NaN and Inf
        assert torch.isfinite(all_cls_scores).all().item(), \
            'classification scores become infinite or NaN!'
        assert torch.isfinite(all_bbox_preds).all().item(), \
            'bbox predications become infinite or NaN!'

        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_anchors,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

函数输入参数是预测的方框坐标和label分数,真实的方框坐标和label。利用self.get_anchors函数生成每个feature map,每个cell的anchor坐标,再用self.get_targets函数,计算方框坐标和label与anchor的关系,生成target。最后再通过self.loss_single函数计算target与预测方框和label的损失值。下边将会详细讲解loss的计算过程。

2.3.2.1 anchor生成

初次通过mmdetection学习anchor生成模块的时候比较难懂,因为有各个模块互相嵌套,因此这里会写得比较详细。
self.get_anchor定义在mmdet/models/dense_heads/anchor_head.py中

def get_anchors(self, featmap_sizes, img_metas, device='cuda'):

		num_imgs = len(img_metas)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = self.anchor_generator.grid_anchors(
            featmap_sizes, device)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = self.anchor_generator.valid_flags(
                featmap_sizes, img_meta['pad_shape'], device)
            valid_flag_list.append(multi_level_flags)

        return anchor_list, valid_flag_list

输入的是feature map的尺寸列表,这里一共有6个不同尺寸的feature map。利用self.anchor_generator.grid_anchors生成每个feature map对应的anchors列表,然后再用anchor_generator.valid_flags计算每一个anchor是否可用,因为有些在feature map边缘的anchor其尺寸超出了图片范围,这些anchor的话将不被采用。
anchor_generator.grid_anchors定义在mmdet/core/anchor/anchor_generator.py

def grid_anchors(self, featmap_sizes, device='cuda'):
		assert self.num_levels == len(featmap_sizes)
        multi_level_anchors = []
        for i in range(self.num_levels):
            anchors = self.single_level_grid_anchors(
                self.base_anchors[i].to(device),
                featmap_sizes[i],
                self.strides[i],
                device=device)
            multi_level_anchors.append(anchors)
        return multi_level_anchors

函数输入的是6个元素的feature map列表,输出的是这个6个feature map对应的anchor列表,每一个feature map的anchor是通过self.single_level_grid_anchors函数生成的。
函数定义为

def single_level_grid_anchors(self,
                                  base_anchors,
                                  featmap_size,
                                  stride=(16, 16),
                                  device='cuda'):
                                  
		feat_h, feat_w = featmap_size
        shift_x = torch.arange(0, feat_w, device=device) * stride[0]
        shift_y = torch.arange(0, feat_h, device=device) * stride[1]

        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
        shifts = shifts.type_as(base_anchors)
        # first feat_w elements correspond to the first row of shifts
        # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
        # shifted anchors (K, A, 4), reshape to (K*A, 4)

        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.view(-1, 4)
        # first A rows correspond to A anchors of (0, 0) in feature map,
        # then (0, 1), (0, 2), ...
        return all_anchors

函数输入的是base_anchors,feature map尺寸以及stride步长。base_anchor是在此类初始化时生成的坐标为(0,0)的cell的anchor尺寸。通过feature map的尺寸以及步长,生成了一个对应原图的位移量列表,然后将base_anchor的尺寸加上这些位移量元素,即得到所有anchor的列表。下边将会讲述一下base_anchors的生成。

base_anchors在AnchorGenerator初始化时就已经调用gen_base_anchors函数生成了。函数定义为

def gen_base_anchors(self):
		multi_level_base_anchors = []
        for i, base_size in enumerate(self.base_sizes):
            center = None
            if self.centers is not None:
                center = self.centers[i]
            multi_level_base_anchors.append(
                self.gen_single_level_base_anchors(
                    base_size,
                    scales=self.scales,
                    ratios=self.ratios,
                    center=center))
        return multi_level_base_anchors

可以看到函数主要是利用self.gen_single_level_base_anchors生成base anchor的,因此需要查看一下这个函数。

def gen_single_level_base_anchors(self,
                                      base_size,
                                      scales,
                                      ratios,
                                      center=None):
		w = base_size
        h = base_size
        if center is None:
            x_center = self.center_offset * (w - 1)
            y_center = self.center_offset * (h - 1)
        else:
            x_center, y_center = center

        h_ratios = torch.sqrt(ratios)
        w_ratios = 1 / h_ratios
        if self.scale_major:
            ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
            hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
        else:
            ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
            hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)

        # use float anchor and the anchor's center is aligned with the
        # pixel center
        base_anchors = [
            x_center - 0.5 * (ws - 1), y_center - 0.5 * (hs - 1),
            x_center + 0.5 * (ws - 1), y_center + 0.5 * (hs - 1)
        ]
        base_anchors = torch.stack(base_anchors, dim=-1).round()

        return base_anchors

函数的输入参数为base_size,scale和ratios。这几个值都是在SSDAnchorGenerator初始化时通过读取配置文件并经过一些计算获取的。
函数的流程比较简单,首先读取center值作为anchor的坐标,然后通过ratios值得到anchor的长和宽,再乘以scale得到最终的长宽值,再转化为左上和右下的坐标就好了。
下边解析一下base_size,scale和ratios值是如何生成的,此3值在SSDAnchorGenerator在初始化函数中生成。

def __init__(self,
                 strides,
                 ratios,
                 basesize_ratio_range,
                 input_size=300,
                 scale_major=True):
        assert len(strides) == len(ratios)
        assert mmcv.is_tuple_of(basesize_ratio_range, float)

        self.strides = [_pair(stride) for stride in strides]
        self.input_size = input_size
        self.centers = [(stride[0] / 2., stride[1] / 2.)
                        for stride in self.strides]
        self.basesize_ratio_range = basesize_ratio_range

        # calculate anchor ratios and sizes
        min_ratio, max_ratio = basesize_ratio_range
        min_ratio = int(min_ratio * 100)
        max_ratio = int(max_ratio * 100)
        step = int(np.floor(max_ratio - min_ratio) / (self.num_levels - 2))
        min_sizes = []
        max_sizes = []
        for ratio in range(int(min_ratio), int(max_ratio) + 1, step):
            min_sizes.append(int(self.input_size * ratio / 100))
            max_sizes.append(int(self.input_size * (ratio + step) / 100))
        if self.input_size == 300:
            if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
                min_sizes.insert(0, int(self.input_size * 7 / 100))
                max_sizes.insert(0, int(self.input_size * 15 / 100))
            elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
                min_sizes.insert(0, int(self.input_size * 10 / 100))
                max_sizes.insert(0, int(self.input_size * 20 / 100))
            else:
                raise ValueError(
                    'basesize_ratio_range[0] should be either 0.15'
                    'or 0.2 when input_size is 300, got '
                    f'{basesize_ratio_range[0]}.')
        elif self.input_size == 512:
            if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
                min_sizes.insert(0, int(self.input_size * 4 / 100))
                max_sizes.insert(0, int(self.input_size * 10 / 100))
            elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
                min_sizes.insert(0, int(self.input_size * 7 / 100))
                max_sizes.insert(0, int(self.input_size * 15 / 100))
            else:
                raise ValueError('basesize_ratio_range[0] should be either 0.1'
                                 'or 0.15 when input_size is 512, got'
                                 f' {basesize_ratio_range[0]}.')
        else:
            raise ValueError('Only support 300 or 512 in SSDAnchorGenerator'
                             f', got {self.input_size}.')

        anchor_ratios = []
        anchor_scales = []
        for k in range(len(self.strides)):
            scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
            anchor_ratio = [1.]
            for r in ratios[k]:
                anchor_ratio += [1 / r, r]  # 4 or 6 ratio
            anchor_ratios.append(torch.Tensor(anchor_ratio))
            anchor_scales.append(torch.Tensor(scales))

        self.base_sizes = min_sizes
        self.scales = anchor_scales
        self.ratios = anchor_ratios
        self.scale_major = scale_major
        self.center_offset = 0
        self.base_anchors = self.gen_base_anchors()

函数输入为stride,ratios和basesize_ratio_range,此3值都是在配置文件中读取的,分别为

basesize_ratio_range=(0.15, 0.9),
strides=[8, 16, 32, 64, 100, 300],
ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]])

base_sizes的赋值是由min_sizes而来的,查看代码得知,将(0.15,0.9)中间平均多插5个值得出列表(0.15,0.3,0.45,0.6,0.75,0.9),再分别与300相乘得出min_size和max_size的列表。
在最后的for循环中可以看到ratio和scale的赋值情况,每一层的anchor其scale都是1和 m a x _ s i z e / m i n _ s i z e \sqrt{max\_size/min\_size} max_size/min_size ,而anchor_ratio分别为读取到的ratios及其倒数。
由此,SSD网络中的Anchor生成部分基本讲述完毕。

2.3.2.2 target生成

讲述完Anchor的生成机制后,下一步需要看看如何利用Anchor和标签生成可计算的target。
get_target函数定义在mmdet/models/dense_heads/anchor_head.py

def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False):
                    num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        # concat all level anchors to a single tensor
        concat_anchor_list = []
        concat_valid_flag_list = []
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            concat_anchor_list.append(torch.cat(anchor_list[i]))
            concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        results = multi_apply(
            self._get_targets_single,
            concat_anchor_list,
            concat_valid_flag_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
        rest_results = list(results[7:])  # user-added return values
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        res = (labels_list, label_weights_list, bbox_targets_list,
               bbox_weights_list, num_total_pos, num_total_neg)
        if return_sampling_results:
            res = res + (sampling_results_list, )
        for i, r in enumerate(rest_results):  # user-added return values
            rest_results[i] = images_to_levels(r, num_level_anchors)

        return res + tuple(rest_results)

其中的输入参数为:

  • anchor_list 每张图片的所有anchor坐标。
  • valid_flag_list 与anchor一一对应的flag,标记该anchor是否可用
  • gt_bboxes_list 方框ground truth
  • img_metas 图片的元素
  • gt_bboxes_ignore_list 忽略的方框标签
  • gt_labels_list 标签ground truth

首先,将每一张图片的anchor_list和valid_flag_list平铺成一个二维张量。每一个anchor_list是一个有6个元素的列表,对应6个level的feature map。每个元素是一个(num_anchor,4)的张量,分别为每个level的anchor数为: 38 ∗ 38 ∗ 4 = 5776 38*38*4=5776 38384=5776 19 ∗ 19 ∗ 6 = 2166 19*19*6=2166 19196=2166 10 ∗ 10 ∗ 6 = 600 10*10*6=600 10106=600 5 ∗ 5 ∗ 6 = 150 5*5*6=150 556=150 3 ∗ 3 ∗ 4 = 36 3*3*4=36 334=36 1 ∗ 1 ∗ 4 = 4 1*1*4=4 114=4,总共5776个anchor。
之后用self._get_targets_single函数,计算得到编码好的label,box和正负例的一个列表。因此函数的关键就在于这个_get_targets_single函数,其定义为

def _get_targets_single(self,
                            flat_anchors,
                            valid_flags,
                            gt_bboxes,
                            gt_bboxes_ignore,
                            gt_labels,
                            img_meta,
                            label_channels=1,
                            unmap_outputs=True):
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 7
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        assign_result = self.assigner.assign(
            anchors, gt_bboxes, gt_bboxes_ignore,
            None if self.sampling else gt_labels)
        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            else:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class since v2.5.0
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)  # fill bg label
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds, sampling_result)

函数首先先通过valid_flags筛选掉不合适的anchor,再用self.assigner.assign函数,将ground truth box的坐标与anchor匹配起来,返回匹配的结果assign_result。下方的代码跟这个结果有很多联系,因此我们先看一下这个self.assigner.assign函数。

def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        overlaps = self.iou_calculator(gt_bboxes, bboxes)
        assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
        return assign_result

函数首先计算ground truth方框与每个AnchorBox的IOU,返回overlaps。我这里测试时一共有2个ground truth方框,8732个Anchor,所以返回一个(2,8732)的张量。之后通过self.assign_wrt_overlaps函数返回assign_result。

def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)

        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = overlaps.new_zeros((num_bboxes, ))
            if num_gts == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = overlaps.new_full((num_bboxes, ),
                                                    -1,
                                                    dtype=torch.long)
            return AssignResult(
                num_gts,
                assigned_gt_inds,
                max_overlaps,
                labels=assigned_labels)

        # for each anchor, which gt best overlaps with it
        # for each anchor, the max iou of all gts
        max_overlaps, argmax_overlaps = overlaps.max(dim=0)
        # for each gt, which anchor best overlaps with it
        # for each gt, the max iou of all proposals
        gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)

        # 2. assign negative: below
        # the negative inds are set to be 0
        if isinstance(self.neg_iou_thr, float):
            assigned_gt_inds[(max_overlaps >= 0)
                             & (max_overlaps < self.neg_iou_thr)] = 0
        elif isinstance(self.neg_iou_thr, tuple):
            assert len(self.neg_iou_thr) == 2
            assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
                             & (max_overlaps < self.neg_iou_thr[1])] = 0

        # 3. assign positive: above positive IoU threshold
        pos_inds = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1

        if self.match_low_quality:
            # Low-quality matching will overwrite the assigned_gt_inds assigned
            # in Step 3. Thus, the assigned gt might not be the best one for
            # prediction.
            # For example, if bbox A has 0.9 and 0.8 iou with GT bbox 1 & 2,
            # bbox 1 will be assigned as the best target for bbox A in step 3.
            # However, if GT bbox 2's gt_argmax_overlaps = A, bbox A's
            # assigned_gt_inds will be overwritten to be bbox B.
            # This might be the reason that it is not used in ROI Heads.
            for i in range(num_gts):
                if gt_max_overlaps[i] >= self.min_pos_iou:
                    if self.gt_max_assign_all:
                        max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
                        assigned_gt_inds[max_iou_inds] = i + 1
                    else:
                        assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)

函数接收两个参数,ground truth label以及ground truth box与AnchorBox的IOU张量。
首先创建默认值为-1的assigned_gt_inds变量,维度是8732,用于分配每一个AnchorBox所属的gtbox下标。
然后对overlap张量的第0维和第一维求最大值得到max_overlaps, argmax_overlaps,每一个Anchor与gtbox的最大IOU以及对应的box下标;gt_max_overlaps, gt_argmax_overlaps,每一个ground truth box与Anchor的最大IOU,以及对应Anchor下标。
然后根据负样本阈值,将IOU小于该阈值的Anchor的gtbox下标设为0,表示其为背景。根据正样本阈值,分配IOU大于该阈值的Anchor的gtbox下标。
接着为了防止有些gtbox由于与每个anchor的IOU都比较低没有在以上的策略匹配到,对于每一个gtbox都分配一个与之IOU最大的anchor作为正样本,使每一个gtbox都至少能够分配上一个anchor。
最后创建一个默认值为-1,长度为num AnchorBox的变量assigned_labels,用于分配正样本的label值,将上一步给assigned_gt_inds分配了gtbox的Anchor分配好gtlabel。最后将结果组合成AssignResult类返回。

现在我们可以返回_get_targets_single函数了,上边说了一大堆都是关于self.assigner.assign函数,这个是描述AnchorBox与gtBox分配流程的一个函数,对于目标检测来说是一个核心。
下一阶段是对分配了正样本的Anchor进行编码,编码完成后才能计算loss。编码的函数在self.bbox_coder.encode,编码的规则比较常规
d x = ( g x − p x ) / p x d y = ( g y − p y ) / p y d w = l o g ( g w / p w ) d h = l o g ( g h / d h ) dx =(gx - px)/px \quad dy = (gy-py)/py \\ dw =log(gw/pw) \quad dh=log(gh/dh) dx=(gxpx)/pxdy=(gypy)/pydw=log(gw/pw)dh=log(gh/dh)
最后代码将返回编码好的方框坐标以及正负样本的下标等。
之后返回到get_targets函数中,最后要做的事情就比较简单了,将之前flatten过的label和box的列表通过images_to_levels函数重新变成6个对应不同feature map的列表,将编码好的label、gtbox以及正负样本数等返回。
最后我们返回到loss函数,在分配好每个Anchor的标签值后,剩下要做的事情已经比较简单,就是将每张图片的所有Anchor的预测分数all_cls_score(8732,81),预测方框坐标all_bbox_preds(8732,4),所有Anchor坐标all_anchors(8732,),所有Anchor分配gtlabel坐标all_labels(8732),所有Anchor分配的gtbox坐标all_bbox_targets(8732, 4)送到self.loss_single函数中计算得出损失。

loss_single函数定义如下

def loss_single(self, cls_score, bbox_pred, anchor, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        pos_inds = ((labels >= 0) &
                    (labels < self.num_classes)).nonzero().reshape(-1)
        neg_inds = (labels == self.num_classes).nonzero().view(-1)

        num_pos_samples = pos_inds.size(0)
        num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
        if num_neg_samples > neg_inds.size(0):
            num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)

        loss_bbox = smooth_l1_loss(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=self.train_cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox

首先计算预测的label和gtlabel的cross_entropy loss。然后根据正样本数量以及从配置中读取的neg_pos_ratio得出负样本数量,我在运行时正样本数量为35,读取的neg_pos_ratio为3,所以负样本数量为105。然后在所有负样本的cross_entropy loss中挑选最大的105个出来,与正样本的loss相加作为最后的分类损失。
最后通过smooth_l1_loss计算正样本的box和btbox的损失,注意由于bbox_weights的存在,计算loss的时候会筛选掉负样本的anchor。
至此,loss计算完毕,SSD的训练过程的前向传播也讲述完毕。下边讲一下模型的推理部分。

3. 推理流程

模型的推理流程可以在demo文件夹中运行

py image_demo.py  demo.jpg  ../configs/ssd/ssd300_coco.py  ../checkpoints/ssd300_coco_20200307-a92d2092.pth

推理的流程主要是在mmdet/apis/inference.py中的inference_detector函数中完成的。在用test_pipeline对图片进行初始化后即送进model中返回result。
下一步调用的函数是位于mmdet/models/detectors/base.py的BaseDetector类中的forward_test函数,定义如下

def forward_test(self, imgs, img_metas, **kwargs):
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(imgs)}) '
                             f'!= num of image meta ({len(img_metas)})')

        # NOTE the batched image size information may be useful, e.g.
        # in DETR, this is needed for the construction of masks, which is
        # then used for the transformer_head.
        for img, img_meta in zip(imgs, img_metas):
            batch_size = len(img_meta)
            for img_id in range(batch_size):
                img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:])

        if num_augs == 1:
            # proposals (List[List[Tensor]]): the outer list indicates
            # test-time augs (multiscale, flip, etc.) and the inner list
            # indicates images in a batch.
            # The Tensor should have a shape Px4, where P is the number of
            # proposals.
            if 'proposals' in kwargs:
                kwargs['proposals'] = kwargs['proposals'][0]
            return self.simple_test(imgs[0], img_metas[0], **kwargs)

函数主要的推理还是在最后一行self.simple_test函数中,定义在mmdet/models/detectors/single_stage.py SingleStageDetector中

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            imgs (list[torch.Tensor]): List of multiple images
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        # get origin input shape to support onnx dynamic shape
        if torch.onnx.is_in_onnx_export():
            # get shape as tensor
            img_shape = torch._shape_as_tensor(img)[2:]
            img_metas[0]['img_shape_for_onnx'] = img_shape
        bbox_list = self.bbox_head.get_bboxes(
            *outs, img_metas, rescale=rescale)
        # skip post-processing when exporting to ONNX
        if torch.onnx.is_in_onnx_export():
            return bbox_list

        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results

首先图片经过模型的backbone和head模块进行前向推理,得到所以Anchor的坐标和标签预测结果。之后用self.head.get_bboxes函数得到预测的box坐标

def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=False,
                   with_nms=True):
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)

        device = cls_scores[0].device
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_anchors = self.anchor_generator.grid_anchors(
            featmap_sizes, device=device)

        cls_score_list = [cls_scores[i].detach() for i in range(num_levels)]
        bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)]

        if torch.onnx.is_in_onnx_export():
            assert len(
                img_metas
            ) == 1, 'Only support one input image while in exporting to ONNX'
            img_shapes = img_metas[0]['img_shape_for_onnx']
        else:
            img_shapes = [
                img_metas[i]['img_shape']
                for i in range(cls_scores[0].shape[0])
            ]
        scale_factors = [
            img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0])
        ]

        if with_nms:
            # some heads don't support with_nms argument
            result_list = self._get_bboxes(cls_score_list, bbox_pred_list,
                                           mlvl_anchors, img_shapes,
                                           scale_factors, cfg, rescale)
        else:
            result_list = self._get_bboxes(cls_score_list, bbox_pred_list,
                                           mlvl_anchors, img_shapes,
                                           scale_factors, cfg, rescale,
                                           with_nms)
        return result_list

代码的核心在最后一行_get_bboxes函数

    def _get_bboxes(self,
                    cls_score_list,
                    bbox_pred_list,
                    mlvl_anchors,
                    img_shapes,
                    scale_factors,
                    cfg,
                    rescale=False,
                    with_nms=True):
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
        batch_size = cls_score_list[0].shape[0]
        # convert to tensor to keep tracing
        nms_pre_tensor = torch.tensor(
            cfg.get('nms_pre', -1),
            device=cls_score_list[0].device,
            dtype=torch.long)

        mlvl_bboxes = []
        mlvl_scores = []
        for cls_score, bbox_pred, anchors in zip(cls_score_list,
                                                 bbox_pred_list, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            cls_score = cls_score.permute(0, 2, 3,
                                          1).reshape(batch_size, -1,
                                                     self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(0, 2, 3,
                                          1).reshape(batch_size, -1, 4)
            anchors = anchors.expand_as(bbox_pred)
            # Always keep topk op for dynamic input in onnx
            if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
                                       or scores.shape[-2] > nms_pre_tensor):
                from torch import _shape_as_tensor
                # keep shape as tensor and get k
                num_anchor = _shape_as_tensor(scores)[-2].to(
                    nms_pre_tensor.device)
                nms_pre = torch.where(nms_pre_tensor < num_anchor,
                                      nms_pre_tensor, num_anchor)

                # Get maximum scores for foreground classes.
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(-1)
                else:
                    # remind that we set FG labels to [0, num_class-1]
                    # since mmdet v2.0
                    # BG cat_id: num_class
                    max_scores, _ = scores[..., :-1].max(-1)

                _, topk_inds = max_scores.topk(nms_pre)
                batch_inds = torch.arange(batch_size).view(
                    -1, 1).expand_as(topk_inds)
                anchors = anchors[batch_inds, topk_inds, :]
                bbox_pred = bbox_pred[batch_inds, topk_inds, :]
                scores = scores[batch_inds, topk_inds, :]

            bboxes = self.bbox_coder.decode(
                anchors, bbox_pred, max_shape=img_shapes)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)

        batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
        if rescale:
            batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
                scale_factors).unsqueeze(1)
        batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)

        # Set max number of box to be feed into nms in deployment
        deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
        if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
            # Get maximum scores for foreground classes.
            if self.use_sigmoid_cls:
                max_scores, _ = batch_mlvl_scores.max(-1)
            else:
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0
                # BG cat_id: num_class
                max_scores, _ = batch_mlvl_scores[..., :-1].max(-1)
            _, topk_inds = max_scores.topk(deploy_nms_pre)
            batch_inds = torch.arange(batch_size).view(-1,
                                                       1).expand_as(topk_inds)
            batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds]
            batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds]
        if self.use_sigmoid_cls:
            # Add a dummy background class to the backend when using sigmoid
            # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
            # BG cat_id: num_class
            padding = batch_mlvl_scores.new_zeros(batch_size,
                                                  batch_mlvl_scores.shape[1],
                                                  1)
            batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)

        if with_nms:
            det_results = []
            for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes,
                                                  batch_mlvl_scores):
                det_bbox, det_label = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                     cfg.score_thr, cfg.nms,
                                                     cfg.max_per_img)
                det_results.append(tuple([det_bbox, det_label]))
        else:
            det_results = [
                tuple(mlvl_bs)
                for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores)
            ]
        return det_results

函数首先设定了每一个anchor level的最大能够得到的目标数量为1000。对于某一个Anchor level,先找出每个Anchor最大的预测为非背景的分数,然后再找出前1000个预测分数最高的Anchor,并且将这些Anchor的分数和预测框坐标都保存好。对于后几层的feature map,其预测的目标数量是小于1000的,这些目标不进行排序全部保存下来。
最后保存下来的有2790个检测框,最后用multiclass_nms删除掉大部分重复的检测框。
multiclass_nms定义在mmdet/core/post_processing/bbox_nms.py中

def multiclass_nms(multi_bboxes,
                   multi_scores,
                   score_thr,
                   nms_cfg,
                   max_num=-1,
                   score_factors=None,
                   return_inds=False):
    num_classes = multi_scores.size(1) - 1
    # exclude background category
    if multi_bboxes.shape[1] > 4:
        bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
    else:
        bboxes = multi_bboxes[:, None].expand(
            multi_scores.size(0), num_classes, 4)

    scores = multi_scores[:, :-1]

    labels = torch.arange(num_classes, dtype=torch.long)
    labels = labels.view(1, -1).expand_as(scores)

    bboxes = bboxes.reshape(-1, 4)
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)

    if not torch.onnx.is_in_onnx_export():
        # NonZero not supported  in TensorRT
        # remove low scoring boxes
        valid_mask = scores > score_thr
    # multiply score_factor after threshold to preserve more bboxes, improve
    # mAP by 1% for YOLOv3
    if score_factors is not None:
        # expand the shape to match original shape of score
        score_factors = score_factors.view(-1, 1).expand(
            multi_scores.size(0), num_classes)
        score_factors = score_factors.reshape(-1)
        scores = scores * score_factors

    if not torch.onnx.is_in_onnx_export():
        # NonZero not supported  in TensorRT
        inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
        bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
    else:
        # TensorRT NMS plugin has invalid output filled with -1
        # add dummy data to make detection output correct.
        bboxes = torch.cat([bboxes, bboxes.new_zeros(1, 4)], dim=0)
        scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
        labels = torch.cat([labels, labels.new_zeros(1)], dim=0)

    if bboxes.numel() == 0:
        if torch.onnx.is_in_onnx_export():
            raise RuntimeError('[ONNX Error] Can not record NMS '
                               'as it has not been executed this time')
        if return_inds:
            return bboxes, labels, inds
        else:
            return bboxes, labels

    dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)

    if max_num > 0:
        dets = dets[:max_num]
        keep = keep[:max_num]

    if return_inds:
        return dets, labels[keep], keep
    else:
        return dets, labels[keep]

函数主要工作是筛选掉box的分数过低的框(阈值0.02),然后将剩余框的方框坐标、预测类别以及类别分数放到batched_nms函数中进行运算。
最终的nms计算是在编译好的模块上运算的,没有python的源代码,因此就不再往下溯源了。

4. 小结

本文利用SSD网络为例,整体上过了一遍mmdetection的训练和推理过程。对整个SSD网络的细节基本弄透彻。与其之后的one-stage算法相比,SSD由于没有FPN模块,因此feature map密度高的层可能没有提取到足够信息,导致小目标的检测不够准确。其优点为引入了多尺度预测,很好利用了各层的feature map的特征;采取了不同长宽比的Anchor,使之能够匹配不同长宽的目标;在训练的时候引入了困难样本挖掘,同时平衡了正负样本,使训练更好收敛。
文章主要是针对mmdetection代码进行流程讲解的,由于代码嵌套复杂可能看起来会有些乱,文中有描述不准或者错误的地方在所难免,欢迎交流指正。
之后有时间的话会写一下两个two-stage算法Faster-RCNN和Cascade R-CNN。

你可能感兴趣的:(机器(深度)学习,pytorch,目标检测,计算机视觉)