mmdetection: a walkthrough of anchor_head

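In mmdet, AnchorHead is the base class for prediction heads. It provides _init_layers, init_weights, forward_single, forward, get_anchors, loss and get_bboxes. Together these cover the loss computation used in training and the get_bboxes method used at inference.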

1. forward_single and forward

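forward_single predicts on a single scale. It produces the classification scores and the box regression predictions.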

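def forward_single(self, x):
    # x: feature map of one level, shape (N, C, H, W)
    cls_score = self.conv_cls(x)  # (N, num_anchors * cls_out_channels, H, W)
    bbox_pred = self.conv_reg(x)  # (N, num_anchors * 4, H, W)
    return cls_score, bbox_pred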
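The following pure-Python sketch makes the shapes concrete by computing what forward_single returns on each level. It assumes, as in mmdet v1's AnchorHead, that conv_cls and conv_reg are 1x1 convolutions, so H and W are preserved. Under that assumption conv_cls emits num_anchors * cls_out_channels channels and conv_reg emits num_anchors * 4. The helper names and the example sizes are hypothetical.

```python
def head_output_shapes(batch_size, featmap_sizes, num_anchors, cls_out_channels):
    # Hypothetical helper: shape bookkeeping only. It returns one
    # (cls_score_shape, bbox_pred_shape) pair per level, assuming 1x1 convs
    # that keep the spatial size of each feature map.
    shapes = []
    for h, w in featmap_sizes:
        cls_shape = (batch_size, num_anchors * cls_out_channels, h, w)
        reg_shape = (batch_size, num_anchors * 4, h, w)  # 4 deltas per anchor
        shapes.append((cls_shape, reg_shape))
    return shapes


def total_anchors(featmap_sizes, num_anchors):
    # Every spatial location on every level carries num_anchors anchors.
    return sum(h * w * num_anchors for h, w in featmap_sizes)


# Hypothetical example: 2 images, three levels, 9 anchors per location, 80 channels.
sizes = [(100, 152), (50, 76), (25, 38)]
for cls_shape, reg_shape in head_output_shapes(2, sizes, 9, 80):
    print(cls_shape, reg_shape)
print(total_anchors(sizes, 9))
```

These shapes match the ones given in the get_bboxes docstring further down.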

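With an FPN, prediction runs on several scales. To handle multiple inputs, the authors wrote a helper called multi_apply. It calls func on each element of the input lists and collects results shaped like [(cls, bbox), (cls, bbox)]. A final zip then regroups predictions of the same kind, producing ([cls1, cls2], [bbox1, bbox2]).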

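from functools import partial


def multi_apply(func, *args, **kwargs):
    # Bind shared keyword arguments, then call func once per element of args
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    # zip(*...) regroups [(cls, bbox), ...] into ([cls, ...], [bbox, ...])
    return tuple(map(list, zip(*map_results)))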
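Here is a small runnable demonstration of the regrouping. In mmdet v1, AnchorHead.forward is essentially multi_apply(self.forward_single, feats). The loss code uses the same pattern, passing several per-level lists plus a shared keyword argument. toy_forward_single and toy_loss_single are hypothetical stand-ins that replace the convolutions and losses with trivial operations so the output is easy to read.

```python
from functools import partial


def multi_apply(func, *args, **kwargs):
    # Same helper as above: bind kwargs, map over the per-level inputs,
    # then transpose [(a1, b1), (a2, b2)] into ([a1, a2], [b1, b2]).
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))


# Hypothetical stand-in for forward_single: it tags its input instead of
# running convolutions, so the regrouping is visible.
def toy_forward_single(feat):
    return "cls_" + feat, "bbox_" + feat


# Hypothetical stand-in for loss_single: two per-level inputs plus one
# shared keyword argument, mirroring how loss() calls multi_apply.
def toy_loss_single(cls_score, bbox_pred, num_total_samples):
    return cls_score / num_total_samples, bbox_pred / num_total_samples


feats = ["P3", "P4", "P5"]
cls_scores, bbox_preds = multi_apply(toy_forward_single, feats)
print(cls_scores)  # ['cls_P3', 'cls_P4', 'cls_P5']
print(bbox_preds)  # ['bbox_P3', 'bbox_P4', 'bbox_P5']

loss_cls, loss_bbox = multi_apply(
    toy_loss_single, [1.0, 3.0], [2.0, 6.0], num_total_samples=4)
print(loss_cls, loss_bbox)  # [0.25, 0.75] [0.5, 1.5]
```

Because the return value is a tuple of lists, callers can unpack it directly into one list per output kind.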

2. loss

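Computing the loss involves three parts: generating the anchors, generating the anchor targets, and the loss functions themselves.
The code for the first two parts is as follows.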

anchor_list, valid_flag_list = self.get_anchors(
    featmap_sizes, img_metas, device=device)

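Anchor generation was covered in an earlier post. Here anchor_list holds all anchors of every scale for each image. It is a nested list indexed as anchor_list[img][level], where each element is a tensor of shape (num_anchors_at_that_level, 4). valid_flag_list has the same nesting and marks which anchors are valid. An anchor is invalid when it sits in the extra padding added to fit images of different sizes into one batch, outside the image's own padded area.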
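This pure-Python sketch builds the nested anchor_list[img][level] and valid_flag_list[img][level] structure. It loosely follows mmdet v1's get_anchors, which derives each image's valid feature region from img_meta['pad_shape'] as ceil(pad / stride) capped by the feature size. All function names are hypothetical. Using one set of base anchors centered at the origin for every level is a simplification; real heads use per-level scales and their own anchor centers.

```python
import math


def grid_anchors(base_anchors, feat_h, feat_w, stride):
    # Shift every base anchor to every feature-map location. Order is
    # location-major (row, then column) and anchor-minor.
    anchors = []
    for y in range(feat_h):
        for x in range(feat_w):
            sx, sy = x * stride, y * stride
            for x1, y1, x2, y2 in base_anchors:
                anchors.append((x1 + sx, y1 + sy, x2 + sx, y2 + sy))
    return anchors


def valid_flags(feat_h, feat_w, valid_h, valid_w, num_base_anchors):
    # A location is valid if it lies inside the top-left (valid_h, valid_w)
    # region; all anchors at that location share the same flag.
    flags = []
    for y in range(feat_h):
        for x in range(feat_w):
            flags.extend([y < valid_h and x < valid_w] * num_base_anchors)
    return flags


def get_anchors_sketch(featmap_sizes, strides, base_anchors, pad_shapes):
    # Returns anchor_list[img][level] and valid_flag_list[img][level].
    anchor_list, valid_flag_list = [], []
    for pad_h, pad_w in pad_shapes:  # one (h, w) pad shape per image
        img_anchors, img_flags = [], []
        for (feat_h, feat_w), stride in zip(featmap_sizes, strides):
            img_anchors.append(grid_anchors(base_anchors, feat_h, feat_w, stride))
            # Feature cells covered by this image's own padded area.
            valid_h = min(math.ceil(pad_h / stride), feat_h)
            valid_w = min(math.ceil(pad_w / stride), feat_w)
            img_flags.append(valid_flags(feat_h, feat_w, valid_h, valid_w,
                                         len(base_anchors)))
        anchor_list.append(img_anchors)
        valid_flag_list.append(img_flags)
    return anchor_list, valid_flag_list


# Two images batched into a 32x32 tensor; the second one is only 24 px tall.
base = [(-4, -4, 4, 4), (-8, -4, 8, 4)]
anchor_list, valid_flag_list = get_anchors_sketch(
    [(4, 4), (2, 2)], [8, 16], base, [(32, 32), (24, 32)])
print(len(anchor_list), len(anchor_list[0]), len(anchor_list[0][0]))  # 2 2 32
print(sum(valid_flag_list[1][0]))  # 24 of the 32 level-0 anchors are valid
```

In the second image, the bottom row of the stride-8 level lies entirely in batch padding, so its 8 anchors are flagged invalid.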

label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(
    anchor_list,
    valid_flag_list,
    gt_bboxes,
    img_metas,
    self.target_means,
    self.target_stds,
    cfg,
    gt_bboxes_ignore_list=gt_bboxes_ignore,
    gt_labels_list=gt_labels,
    label_channels=label_channels,
    sampling=self.sampling)
if cls_reg_targets is None:
    return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (
    num_total_pos + num_total_neg if self.sampling else num_total_pos)

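The next key step is generating the anchor targets. This involves two operations, the assigner and the sampler, which deserve their own walkthrough. cls_reg_targets holds each anchor's label, label weight, bbox target, bbox weight and related information. These tensors are regrouped level-first: a list over levels whose elements have shape [batch_size, num_anchors_at_that_level, ...].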
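The level-first regrouping can be sketched in pure Python. mmdet v1 does this with a helper named images_to_levels, which stacks the per-image target tensors and slices them along the anchor dimension. The list-based function below is a hypothetical equivalent that only mirrors that idea. The labels are made-up example values.

```python
def images_to_levels_sketch(per_image_targets, num_level_anchors):
    # per_image_targets[i]: flat list over all anchors of image i, with the
    # levels concatenated in order. num_level_anchors[l]: anchors on level l.
    # Returns level_targets[l][i]: the part of image i that belongs to level l.
    level_targets = []
    start = 0
    for n in num_level_anchors:
        end = start + n
        level_targets.append([t[start:end] for t in per_image_targets])
        start = end
    return level_targets


labels_per_image = [[1, 0, 0, 2, 0],   # image 0: 3 anchors on level 0, 2 on level 1
                    [0, 0, 3, 0, 1]]   # image 1
labels_list = images_to_levels_sketch(labels_per_image, [3, 2])
print(labels_list)  # [[[1, 0, 0], [0, 0, 3]], [[2, 0], [0, 1]]]
```

After this regrouping, each entry of labels_list lines up with one level's cls_score, which is what lets multi_apply compute the loss level by level.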

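With the targets in hand, the loss can be computed. multi_apply is used again here: the loss is computed separately on each level, and the per-level results are then aggregated.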

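Classification commonly uses cross-entropy loss and regression uses smooth L1 loss. Both are configurable through self.loss_cls and self.loss_bbox.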
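As a minimal sketch of how these losses combine with per-anchor weights and avg_factor (loss_single passes avg_factor=num_total_samples), the pure-Python functions below sum a weighted loss and divide by avg_factor. They assume the mmdet convention that label_weights zero out ignored anchors and that bbox_weights are nonzero only for positive anchors. The function names are hypothetical. The classification part uses softmax cross-entropy; heads with use_sigmoid_cls would use a binary or focal loss instead.

```python
import math


def smooth_l1(pred, target, beta=1.0):
    # Quadratic for |diff| < beta, linear beyond it.
    diff = abs(pred - target)
    return 0.5 * diff * diff / beta if diff < beta else diff - 0.5 * beta


def weighted_smooth_l1(preds, targets, weights, avg_factor, beta=1.0):
    # preds/targets/weights: one 4-element row per anchor. Rows whose weights
    # are zero (negative anchors) contribute nothing.
    total = 0.0
    for p_row, t_row, w_row in zip(preds, targets, weights):
        for p, t, w in zip(p_row, t_row, w_row):
            total += w * smooth_l1(p, t, beta)
    return total / avg_factor


def weighted_cross_entropy(logits, labels, weights, avg_factor):
    # Softmax cross-entropy per anchor, scaled by its weight, summed, and
    # divided by avg_factor rather than by the number of rows.
    total = 0.0
    for row, label, w in zip(logits, labels, weights):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += w * (log_sum_exp - row[label])
    return total / avg_factor


print(smooth_l1(0.5, 0.0), smooth_l1(3.0, 1.0))  # 0.125 1.5
print(weighted_cross_entropy([[0.0, 0.0], [0.0, 0.0]], [0, 1], [1.0, 1.0], 2))
```

Dividing by num_total_samples instead of the raw anchor count keeps the loss scale independent of how many anchors were ignored.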

def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                bbox_targets, bbox_weights, num_total_samples, cfg):
    # classification loss
    labels = labels.reshape(-1)
    label_weights = label_weights.reshape(-1)
    cls_score = cls_score.permute(0, 2, 3,
                                  1).reshape(-1, self.cls_out_channels)
    loss_cls = self.loss_cls(
        cls_score, labels, label_weights, avg_factor=num_total_samples)
    # regression loss
    bbox_targets = bbox_targets.reshape(-1, 4)
    bbox_weights = bbox_weights.reshape(-1, 4)
    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
    loss_bbox = self.loss_bbox(
        bbox_pred,
        bbox_targets,
        bbox_weights,
        avg_factor=num_total_samples)
    return loss_cls, loss_bbox
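The permute(0, 2, 3, 1).reshape(-1, C) step in loss_single is what makes each prediction row line up with the flattened labels. This pure-Python sketch reproduces that reordering on nested lists. It shows that rows come out in (n, h, w, anchor) order and that anchor a owns channels a*C to a*C + C - 1. This matches the location-major, anchor-minor order in which grid anchors are usually generated; that ordering is an assumption about the anchor generator. The function name is hypothetical.

```python
def permute_and_flatten(score, num_channels):
    # Equivalent of score.permute(0, 2, 3, 1).reshape(-1, num_channels) for a
    # nested list score[n][k][h][w] whose channel axis k has size A * C.
    rows = []
    for n in range(len(score)):
        num_k = len(score[n])
        height, width = len(score[n][0]), len(score[n][0][0])
        for h in range(height):
            for w in range(width):
                # permute: gather the channel vector at location (h, w)
                channels = [score[n][k][h][w] for k in range(num_k)]
                # reshape: split it into A consecutive groups of C channels
                for start in range(0, num_k, num_channels):
                    rows.append(channels[start:start + num_channels])
    return rows


# Tag every element with its (n, k, h, w) index to make the ordering visible.
N, A, C, H, W = 1, 2, 3, 2, 2
score = [[[[(n, k, h, w) for w in range(W)] for h in range(H)]
          for k in range(A * C)] for n in range(N)]
rows = permute_and_flatten(score, C)
print(len(rows))  # 8 rows = N * H * W * A
print(rows[5])    # anchor 1 at (h=1, w=0): channels 3, 4, 5
```

In general, row ((n * H + h) * W + w) * A + a holds the scores of anchor a at location (h, w) of image n.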

3. get_bboxes
get_bboxes converts the model outputs into predicted boxes and labels.
Its docstring describes the arguments and return value as follows:

Transform network output for a batch into labeled boxes.

Args:
    cls_scores (list[Tensor]): Box scores for each scale level
        Has shape (N, num_anchors * num_classes, H, W)
    bbox_preds (list[Tensor]): Box energies / deltas for each scale
        level with shape (N, num_anchors * 4, H, W)
    img_metas (list[dict]): size / scale info for each image
    cfg (mmcv.Config): test / postprocessing configuration
    rescale (bool): if True, return boxes in original image space

Returns:
    list[tuple[Tensor, Tensor]]: each item in result_list is 2-tuple.
        The first item is an (n, 5) tensor, where the first 4 columns
        are bounding box positions (tl_x, tl_y, br_x, br_y) and the
        5-th column is a score between 0 and 1. The second item is a
        (n,) tensor where each item is the class index of the
        corresponding box.
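  1. Generate the multi-scale anchors.
  2. Loop over the images. For each image, get_bboxes_single converts the multi-scale predictions into the final results: it combines the multi-scale anchors with the predicted offsets to decode the predicted bboxes, then applies post-processing such as NMS.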
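Step 2 can be sketched in pure Python as decode, score threshold, then greedy NMS. This is a simplified, single-class illustration under stated assumptions. delta2bbox applies the standard (dx, dy, dw, dh) decoding after undoing the target_means/target_stds normalization, but ignores details such as max_ratio clamping and the library's pixel-offset convention. The library version also keeps only the top nms_pre candidates per level, clamps boxes to the image, optionally rescales, and runs per-class NMS via multiclass_nms. get_bboxes_single_sketch and iou are hypothetical names.

```python
import math


def delta2bbox(anchor, delta, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
    # Undo the target normalization, then decode against one (x1, y1, x2, y2) anchor.
    dx, dy, dw, dh = (d * s + m for d, s, m in zip(delta, stds, means))
    px, py = (anchor[0] + anchor[2]) / 2, (anchor[1] + anchor[3]) / 2
    pw, ph = anchor[2] - anchor[0], anchor[3] - anchor[1]
    gx, gy = px + pw * dx, py + ph * dy              # shifted center
    gw, gh = pw * math.exp(dw), ph * math.exp(dh)    # rescaled size
    return (gx - gw / 2, gy - gh / 2, gx + gw / 2, gy + gh / 2)


def iou(a, b):
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_thr):
    # Greedy NMS: visit boxes by descending score, keep a box only if it does
    # not overlap any already-kept box by more than iou_thr.
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= iou_thr for j in keep):
            keep.append(i)
    return keep


def get_bboxes_single_sketch(anchors, deltas, scores, score_thr, iou_thr):
    # Inputs are already flattened over all levels; one class only.
    boxes = [delta2bbox(a, d) for a, d in zip(anchors, deltas)]
    cand = [i for i, s in enumerate(scores) if s > score_thr]
    keep = nms([boxes[i] for i in cand], [scores[i] for i in cand], iou_thr)
    # Each row is (x1, y1, x2, y2, score), like the (n, 5) output described above.
    return [(*boxes[cand[k]], scores[cand[k]]) for k in keep]


anchors = [(0, 0, 10, 10), (1, 1, 11, 11), (50, 50, 60, 60)]
deltas = [(0.0, 0.0, 0.0, 0.0)] * 3
print(get_bboxes_single_sketch(anchors, deltas, [0.9, 0.8, 0.3], 0.05, 0.5))
```

The second anchor overlaps the first with IoU 81/119 (about 0.68), so NMS suppresses it, leaving two detections.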
