mmdetection——anchor_target解读

anchor_target函数解读
该函数输入参数:

"""Compute regression and classification targets for anchors.
Args:
    anchor_list (list[list]): Multi level anchors of each image.
    valid_flag_list (list[list]): Multi level valid flags of each image.
    gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
    img_metas (list[dict]): Meta info of each image.
    target_means (Iterable): Mean value of regression targets.
    target_stds (Iterable): Std value of regression targets.
    cfg (dict): RPN train configs.
Returns:
    tuple
"""

核心思路:

  1. 将每张图多个尺度anchor合并起来,得出【num_imgs,num_anchors*4】的tensor
  2. 对每张图计算出target
  3. 将得出的以图片数划分的tensors重新划分为以level的主导的tensors,使得得出的tensors和模型预测结果格式对应上

1.拼接单图多尺度anchor到一个tensor

# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
# concat all level anchors and flags to a single tensor
for i in range(num_imgs):
    assert len(anchor_list[i]) == len(valid_flag_list[i])
    anchor_list[i] = torch.cat(anchor_list[i])
    valid_flag_list[i] = torch.cat(valid_flag_list[i])
  1. anchor_target_single
    针对每张图片分别用anchor_target_single得出anchor target
    这儿多图操作同样用的multi_apply完成
# assign gt and sample anchors
anchors = flat_anchors[inside_flags, :]

if sampling:
    assign_result, sampling_result = assign_and_sample(
        anchors, gt_bboxes, gt_bboxes_ignore, None, cfg)
else:
    bbox_assigner = build_assigner(cfg.assigner)
    assign_result = bbox_assigner.assign(anchors, gt_bboxes,
                                         gt_bboxes_ignore, gt_labels)
    bbox_sampler = PseudoSampler()
    sampling_result = bbox_sampler.sample(assign_result, anchors,
                                          gt_bboxes)

anchor_target_single主要涉及的就是assign gt and sample anchors

其中,只对有效anchor去计算target
faster rcnn这种就存在采样的操作,而retinanet这种就对所有负样本算损失,不采样,关于assign和sample的过程另说。

inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                   img_meta['img_shape'][:2],
                                   cfg.allowed_border)
if not inside_flags.any():
    return (None, ) * 6
# assign gt and sample anchors
anchors = flat_anchors[inside_flags, :]

if sampling:
    assign_result, sampling_result = assign_and_sample(
        anchors, gt_bboxes, gt_bboxes_ignore, None, cfg)
else:
    bbox_assigner = build_assigner(cfg.assigner)
    assign_result = bbox_assigner.assign(anchors, gt_bboxes,
                                         gt_bboxes_ignore, gt_labels)
    bbox_sampler = PseudoSampler()
    sampling_result = bbox_sampler.sample(assign_result, anchors,
                                          gt_bboxes)

常见的assign为max_iou_assigner, 依据anchor与gt bbox的iou确定anchor target,assign返回的数据为一个数据类AssignResult,包含num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels,这几项。
num_gts: gt bbox数量
assigned_gt_inds: anchor对应的label。-1:无视,0:负样本,正数:gt bbox对应的index
max_overlaps:anchor与gt bbox的最大iou
labels:pos bbox对应的label

assign完了就是sample了,如果有sample操作,就按照指定的sample类型来,若像retinanet这种没有sample操作的,直接用的PseudoSampler这个采样类,这个类直接将所有有效anchor提取出来。
sampler返回的是SamplingResult对象,包含pos_inds, neg_inds, bboxes, gt_bboxes,assign_result, gt_flags
pos_inds:pos anchor的索引
neg_inds:neg anchor的索引
bboxes:anchors

当有了anchor的target gt之后还需要将bbox转换成delta,所下面代码做的就是计算pos neg anchor对应的delta和权重赋值。

num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)


pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
    pos_bbox_targets = bbox2delta(sampling_result.pos_bboxes,
                                  sampling_result.pos_gt_bboxes,
                                  target_means, target_stds)
    bbox_targets[pos_inds, :] = pos_bbox_targets
    bbox_weights[pos_inds, :] = 1.0
    if gt_labels is None:
        labels[pos_inds] = 1
    else:
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
    if cfg.pos_weight <= 0:
        label_weights[pos_inds] = 1.0
    else:
        label_weights[pos_inds] = cfg.pos_weight
if len(neg_inds) > 0:
    label_weights[neg_inds] = 1.0

最后将有效anchor填充到原来所有的anchor里

# map up to original set of anchors
if unmap_outputs:
    num_total_anchors = flat_anchors.size(0)
    labels = unmap(labels, num_total_anchors, inside_flags)
    label_weights = unmap(label_weights, num_total_anchors, inside_flags)
    bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
    bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

你可能感兴趣的:(目标检测)