Reading the SA-SSD Code

Table of Contents

    • I. SingleStageDetector
      • 1. Initialization
      • 2. Forward pass
    • II. Auxiliary Network
      • 1. Label generation
      • 2. Loss construction
    • III. SSDRotateHead
      • 1. Forward pass
      • 2. Loss construction
      • 3. Label generation

For more details about the network you can also read this blog post, where the author explains things very thoroughly:

  • 小白科研笔记:简析CVPR2020论文SA-SSD的网络搭建细节

I. SingleStageDetector

This is the overall SA-SSD network, composed of the following parts:

  • backbone
  • neck
  • head
  • extra-head

Each part is analyzed in detail later; first, here is an overview of the whole network (only the function signatures are listed, the bodies are omitted for now):

class SingleStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 extra_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageDetector, self).__init__()

        self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)
        else:
            raise NotImplementedError

        if bbox_head is not None:
            self.rpn_head = builder.build_single_stage_head(bbox_head)

        if extra_head is not None:
            self.extra_head = builder.build_single_stage_head(extra_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained)

    @property
    def with_rpn(self):
        return hasattr(self, 'rpn_head') and self.rpn_head is not None

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)

    def merge_second_batch(self, batch_args):

        return ret

    def forward_train(self, img, img_meta, **kwargs):

        return losses

    def forward_test(self, img, img_meta, **kwargs):

        return results

1. Initialization

Code analysis:

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 extra_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageDetector, self).__init__()

        # Build the backbone
        self.backbone = builder.build_backbone(backbone)

        # Build the neck
        if neck is not None:
            self.neck = builder.build_neck(neck)
        else:
            raise NotImplementedError

        # Build the head
        if bbox_head is not None:
            self.rpn_head = builder.build_single_stage_head(bbox_head)

        # Build the extra head
        if extra_head is not None:
            self.extra_head = builder.build_single_stage_head(extra_head)

        # Store the cfg parameters
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # Initialize the weights
        self.init_weights(pretrained)

The initialization of every part follows the same pattern: if you step into these builder functions, you will find that each component is constructed from its entry in the cfg file, and they all eventually end up in this obj_from_dict function.

# Initialize an object of a class looked up in `parent` from the dict `info`.
# In short, the dict stores the constructor arguments of the class; the core call is getattr.
# obj_from_dict is therefore a generic helper for config-driven initialization.
def obj_from_dict(info, parent=None, default_args=None):
    """Initialize an object from dict.

    The dict must contain the key "type", which indicates the object type, it
    can be either a string or type, such as "list" or ``list``. Remaining
    fields are treated as the arguments for constructing the object.

    Args:
        info (dict): Object types and arguments.
        parent (:class:`module`): Module which may containing expected object
            classes.
        default_args (dict, optional): Default arguments for initializing the
            object.

    Returns:
        any type: Object built from the dict.
    """
    # First, check that info is a dict and that it contains the 'type' key;
    # default_args must also be either a dict or None.
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None

    args = info.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        if parent is not None:
            obj_type = getattr(parent, obj_type)
        else:
            obj_type = sys.modules[obj_type]
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but '
                        f'got {type(obj_type)}')
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)  # unpack the remaining args, i.e. instantiate the class

At first this function was a bit confusing, but looking closer it simply uses the settings in the cfg to find the class to be initialized and then passes the cfg parameters into it. For example:

neck=dict(
    type='SpMiddleFHD',
    output_shape=[40, 1600, 1408],
    num_input_features=4,
    num_hidden_features=64 * 5,
),

This is the neck configuration from the cfg file. First, type='SpMiddleFHD' is used to locate the SpMiddleFHD class, and then the class is instantiated with the remaining parameters from the cfg. At this point,

return obj_type(**args) 

is equivalent to:

return SpMiddleFHD(output_shape=[40, 1600, 1408], num_input_features=4, num_hidden_features=64 * 5)

OK, the other parts are initialized in exactly the same way. The code is built on top of mmdetection, and this is simply how mmdetection does it. Once the pattern is clear, it is also a clean and convenient way to organize your own code.
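To make the mechanism concrete, here is a minimal self-contained sketch of the same pattern. DummyNeck and build_from_cfg below are illustrative stand-ins, not part of SA-SSD or mmdetection:

import types


class DummyNeck:
    def __init__(self, output_shape, num_input_features):
        self.output_shape = output_shape
        self.num_input_features = num_input_features


# a fake "module" acting as the `parent` namespace that holds candidate classes
necks = types.SimpleNamespace(DummyNeck=DummyNeck)


def build_from_cfg(cfg, parent):
    args = cfg.copy()
    obj_type = getattr(parent, args.pop('type'))  # look up the class by its name
    return obj_type(**args)                       # instantiate with the remaining keys


cfg = dict(type='DummyNeck', output_shape=[40, 1600, 1408], num_input_features=4)
neck = build_from_cfg(cfg, necks)
print(type(neck).__name__, neck.output_shape)     # DummyNeck [40, 1600, 1408]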

2. Forward pass

Now let's look at the forward pass; the annotations are included in the code:

# img.shape [B, 3, 384, 1248]
# img_meta: dict
#          img_meta[0]:
#                      img_shape : tuple (375, 1242, 3)
#                      sample_idx
#                      calib
# kwargs:
#       1. anchors           list: len(anchors)      = B
#       2. voxels            list: len(voxels)       = B
#       3. coordinates       list: len(coordinates)  = B
#       4. num_points        list: len(num_points)   = B
#       5. anchors_mask      list: len(anchors_mask) = B
#       6. gt_labels         list: len(gt_labels)    = B
#       7. gt_bboxes         list: len(gt_bboxes)    = B

def forward_train(self, img, img_meta, **kwargs):

    # --------------------------------------------------------------------------
    # from mmdet.datasets.kitti_utils import draw_lidar
    # f = draw_lidar(kwargs["voxels"][0].cpu().numpy(), show=True)  # visualize the full point cloud
    # --------------------------------------------------------------------------

    batch_size = len(img_meta) # B

    ret = self.merge_second_batch(kwargs)

    # vx is basically ret['voxels']
    vx = self.backbone(ret['voxels'], ret['num_points'])

    # x.shape     = [2, 256, 200, 176]
    # conv6.shape = [2, 256, 200, 176]
    # point_misc  : tuple, shape = 3
    #             : 1. point_mean : shape [N, 4], [:, 0] is the batch index
    #             : 2. point_cls  : shape [N, 1]
    #             : 3. point_reg  : shape [N, 3]
    (x, conv6), point_misc = self.neck(vx, ret['coordinates'], batch_size)

    losses = dict()

    aux_loss = self.neck.aux_loss(*point_misc, gt_bboxes=ret['gt_bboxes'])
    losses.update(aux_loss)

    # RPN forward and loss
    if self.with_rpn:

        # rpn_outs    : tuple, size = 3
        #             : 1. box_preds      : shape [N, 200, 176, 14]
        #             : 2. cls_preds      : shape [N, 200, 176,  2]
        #             : 3. dir_cls_preds  : shape [N, 200, 176,  4]
        rpn_outs = self.rpn_head(x)

        # rpn_outs    : tuple, shape = 8
        rpn_loss_inputs = rpn_outs + (ret['gt_bboxes'], ret['gt_labels'], ret['anchors'], ret['anchors_mask'], self.train_cfg.rpn)

        rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)

        losses.update(rpn_losses)

        # guided_anchors.shape :
        #                        [num_of_guided_anchors, 7]
        #                      + [num_of_gt_bboxes,      7]
        #                      ----------------------------
        #                      = [all_num,               7]
        guided_anchors = self.rpn_head.get_guided_anchors(*rpn_outs, ret['anchors'], ret['anchors_mask'], ret['gt_bboxes'], thr=0.1)
    else:
        raise NotImplementedError

    # bbox head forward and loss
    if self.extra_head:
        bbox_score = self.extra_head(conv6, guided_anchors)
        refine_loss_inputs = (bbox_score, ret['gt_bboxes'], ret['gt_labels'], guided_anchors, self.train_cfg.extra)
        refine_losses = self.extra_head.loss(*refine_loss_inputs)
        losses.update(refine_losses)

    return losses

The incoming arguments first pass through the merge_second_batch() function; let's take a look at it:

def merge_second_batch(self, batch_args):
      ret = {}
      for key, elems in batch_args.items():
          if key in [
              'voxels', 'num_points',
          ]:
              ret[key] = torch.cat(elems, dim=0)
          elif key == 'coordinates':
              coors = []
              for i, coor in enumerate(elems): # coor.shape : torch.Size([19480, 3])
                  coor_pad = F.pad(
                      coor, [1, 0, 0, 0],
                      mode='constant',
                      value=i)                # see https://blog.csdn.net/jorg_zhao/article/details/105295686 for an explanation
                  coors.append(coor_pad)
              ret[key] = torch.cat(coors, dim=0)
          elif key in [
              'img_meta', 'gt_labels', 'gt_bboxes',
          ]:
              ret[key] = elems
          else:
              ret[key] = torch.stack(elems, dim=0)
      return ret

It mainly concatenates the batch according to each key, which is straightforward; but note this step:

coor_pad = F.pad( 
           coor, [1, 0, 0, 0],
           mode='constant',
           value=i)               
coors.append(coor_pad)

For the usage of F.pad here, see: F.pad

Its purpose is to prepend one extra column to coordinates (e.g. i = 0, 1, …) that stores the batch index.
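A quick sketch of what that F.pad call does; the toy coor tensor below is made up purely for illustration:

import torch
import torch.nn.functional as F

coor = torch.tensor([[ 5, 100, 200],
                     [ 7, 101, 201]])   # (num_voxels, 3): voxel indices of one sample

# [1, 0, 0, 0] pads one column on the left of the last dim and nothing elsewhere;
# value=i fills that column with the sample's index inside the batch.
i = 1
coor_pad = F.pad(coor, [1, 0, 0, 0], mode='constant', value=i)
print(coor_pad)
# tensor([[  1,   5, 100, 200],
#         [  1,   7, 101, 201]])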

After that comes the loss construction, which consists of three parts in total:

$loss\_all = aux\_loss + rpn\_loss + extra\_head\_loss$

The detailed composition of each loss term is analyzed later.
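For reference, an mmdetection-style trainer typically reduces the returned losses dict to a single scalar roughly like this. This is a sketch of the usual parse_losses helper, not code copied from the SA-SSD repo:

from collections import OrderedDict

import torch


def parse_losses(losses):
    # average each entry, then sum every entry whose key contains 'loss'
    log_vars = OrderedDict()
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            log_vars[name] = value.mean()
        elif isinstance(value, list):
            log_vars[name] = sum(v.mean() for v in value)
    total = sum(v for k, v in log_vars.items() if 'loss' in k)
    return total, log_vars


losses = dict(aux_loss_cls=torch.tensor(0.3),
              rpn_loc_loss=torch.tensor(0.5),
              rpn_cls_loss=torch.tensor(0.2))
total, log_vars = parse_losses(losses)
print(total)  # roughly tensor(1.)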

II. Auxiliary Network

1. Label generation

In the Auxiliary Network, foreground points have to be segmented from background points, so the first step is to generate the foreground/background labels for the points.

def pts_in_boxes3d(pts, boxes3d):
    N = len(pts) 
    M = len(boxes3d)
    pts_in_flag = torch.IntTensor(M, N).fill_(0)
    reg_target = torch.FloatTensor(N, 3).fill_(0)
    points_op_cpu.pts_in_boxes3d(pts.contiguous(), boxes3d.contiguous(), pts_in_flag, reg_target)
    return pts_in_flag, reg_target

Where:

pts_in_flag : [M, N]; if a point lies inside a bbox, the mask is 1

Open question: reg_target : [N, 3] — what are its values, and how are they obtained?

To answer this we need to understand the function points_op_cpu.pts_in_boxes3d, which lives in mmdet/ops/points_op/src/points_op.cpp:

int pts_in_boxes3d_cpu(at::Tensor pts, at::Tensor boxes3d, at::Tensor pts_flag, at::Tensor reg_target){
    // param pts: (N, 3)
    // param boxes3d: (M, 7)  [x, y, z, h, w, l, ry]
    // param pts_flag: (M, N)
    // param reg_target: (N, 3), center offsets

    CHECK_CONTIGUOUS(pts_flag);
    CHECK_CONTIGUOUS(pts);
    CHECK_CONTIGUOUS(boxes3d);
    CHECK_CONTIGUOUS(reg_target);

    long boxes_num = boxes3d.size(0);
    long pts_num = pts.size(0);

    int * pts_flag_flat = pts_flag.data<int>();
    float * pts_flat = pts.data<float>();
    float * boxes3d_flat = boxes3d.data<float>();
    float * reg_target_flat = reg_target.data<float>();

    // memset(assign_idx_flat, -1, boxes_num * pts_num * sizeof(int));
    // memset(reg_target_flat, 0, pts_num * sizeof(float));
    
    // here the tensors are traversed as flattened (1-D) arrays
	
    int i, j, cur_in_flag;
    for (i = 0; i < boxes_num; i++){
        for (j = 0; j < pts_num; j++){
            cur_in_flag = pt_in_box3d_cpu(pts_flat[j * 3], pts_flat[j * 3 + 1], pts_flat[j * 3 + 2], boxes3d_flat[i * 7],
                                          boxes3d_flat[i * 7 + 1], boxes3d_flat[i * 7 + 2], boxes3d_flat[i * 7 + 3],
                                          boxes3d_flat[i * 7 + 4], boxes3d_flat[i * 7 + 5], boxes3d_flat[i * 7 + 6]);
            pts_flag_flat[i * pts_num + j] = cur_in_flag;
            if(cur_in_flag==1){
                reg_target_flat[j*3] = pts_flat[j*3] - boxes3d_flat[i*7];
                reg_target_flat[j*3+1] = pts_flat[j*3+1] - boxes3d_flat[i*7+1];
                reg_target_flat[j*3+2] = pts_flat[j*3+2] - (boxes3d_flat[i*7+2] + boxes3d_flat[i*7+3] / 2.0);
            }
        }
    }
    return 1;
}

Now it is fairly clear what this function does: a double loop checks, for every point in the cloud, whether it falls inside any of the given bboxes; if a point is inside a bbox, reg_target for that point is the point's coordinates minus the bbox center. As a formula:

$reg\_target = P_{i}(x, y, z) - P_{center}(x, y, z)$

OK, that answers the question raised above.
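The same logic in simplified PyTorch, ignoring the yaw rotation that pt_in_box3d_cpu handles in the real C++ kernel; this is for illustration only, and the mapping of w/l to the y/x axes is an assumption:

import torch


def pts_in_boxes3d_simplified(pts, boxes3d):
    """pts: (N, 3); boxes3d: (M, 7) as [x, y, z, h, w, l, ry] with bottom-center z."""
    N, M = pts.shape[0], boxes3d.shape[0]
    pts_in_flag = torch.zeros(M, N, dtype=torch.int32)
    reg_target = torch.zeros(N, 3)
    for i in range(M):
        x, y, z, h, w, l, _ = boxes3d[i]
        center = torch.stack([x, y, z + h / 2.0])       # shift z from bottom to center
        # crude axis-aligned inside test; assumes l spans x and w spans y when ry == 0
        # (the real kernel rotates the point into the box frame by ry first)
        inside = ((pts[:, 0] - x).abs() < l / 2) & \
                 ((pts[:, 1] - y).abs() < w / 2) & \
                 ((pts[:, 2] - center[2]).abs() < h / 2)
        pts_in_flag[i] = inside.int()
        reg_target[inside] = pts[inside] - center       # point minus box center
    return pts_in_flag, reg_target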

2. Loss construction

(Figure from the original post illustrating the auxiliary loss construction.)
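The figure itself does not survive in text form, so here is a rough self-contained sketch of what the auxiliary loss computes: a sigmoid focal loss on the per-point foreground/background classification plus a smooth L1 loss on the predicted center offsets, both normalized by the number of foreground points. The functions below are simplified stand-ins, not the exact implementation in the repo:

import torch
import torch.nn.functional as F


def sigmoid_focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    # element-wise binary focal loss
    prob = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = prob * targets + (1 - prob) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return alpha_t * (1 - p_t) ** gamma * ce


def aux_loss(point_cls, point_reg, pts_labels, center_targets):
    """point_cls: (N,) logits; point_reg: (N, 3); pts_labels: (N,) in {0, 1}."""
    pos = (pts_labels > 0).float()
    num_pos = torch.clamp(pos.sum(), min=1.0)

    # foreground/background segmentation loss
    cls_loss = sigmoid_focal_loss(point_cls, pts_labels.float()).sum() / num_pos
    # center-offset regression loss, only on foreground points
    reg_loss = (F.smooth_l1_loss(point_reg, center_targets, reduction='none')
                .sum(dim=-1) * pos).sum() / num_pos
    return dict(aux_loss_cls=cls_loss, aux_loss_reg=reg_loss)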

III. SSDRotateHead

This is the head of the whole network. Its structure is listed first and analyzed in detail afterwards.

class SSDRotateHead(nn.Module):

    def __init__(self,
                 num_class=1,
                 num_output_filters=768,
                 num_anchor_per_loc=2,
                 use_sigmoid_cls=True,
                 encode_rad_error_by_sin=True,
                 use_direction_classifier=True,
                 box_coder='GroundBox3dCoder',
                 box_code_size=7,
                 ):
        super(SSDRotateHead, self).__init__()
        self._num_class = num_class
        self._num_anchor_per_loc = num_anchor_per_loc
        self._use_direction_classifier = use_direction_classifier
        self._use_sigmoid_cls = use_sigmoid_cls
        self._encode_rad_error_by_sin = encode_rad_error_by_sin
        self._use_direction_classifier = use_direction_classifier
        self._box_coder = getattr(boxCoders, box_coder)()
        self._box_code_size = box_code_size
        self._num_output_filters = num_output_filters

        if use_sigmoid_cls: # True
            num_cls = num_anchor_per_loc * num_class # 2 * 1
        else:
            num_cls = num_anchor_per_loc * (num_class + 1)

        self.conv_cls = nn.Conv2d(num_output_filters, num_cls, 1)
        self.conv_box = nn.Conv2d(
            num_output_filters, num_anchor_per_loc * box_code_size, 1)
        if use_direction_classifier:
            self.conv_dir_cls = nn.Conv2d(
                num_output_filters, num_anchor_per_loc * 2, 1)

    def add_sin_difference(self, boxes1, boxes2):
    
    def get_direction_target(self, anchors, reg_targets, use_one_hot=True):
    
    def prepare_loss_weights(self, labels,
                             pos_cls_weight=1.0,
                             neg_cls_weight=1.0,
                             loss_norm_type='NormByNumPositives',
                             dtype=torch.float32):
                             
    def create_loss(self,
                    box_preds,                        # torch.Size([2, 200, 176, 14])
                    cls_preds,                        # torch.Size([2, 200, 176, 2])
                    cls_targets,                      # torch.Size([2, 70400])
                    cls_weights,                      # torch.Size([2, 70400])
                    reg_targets,                      # torch.Size([2, 70400, 7])
                    reg_weights,                      # torch.Size([2, 70400])
                    num_class,                        # 1
                    use_sigmoid_cls=True,             # True
                    encode_rad_error_by_sin=True,     # True
                    box_code_size=7):                 # 7

    def forward(self, x):               
    	
    def get_guided_anchors(self, box_preds, cls_preds, dir_cls_preds, anchors, anchors_mask, gt_bboxes, thr=.1):

1. Forward pass

First, the forward function:

def forward(self, x):               # torch.Size([2, 256, 200, 176])

    box_preds = self.conv_box(x)    
    cls_preds = self.conv_cls(x)    
    # [N, C, y(H), x(W)]
    
    box_preds = box_preds.permute(0, 2, 3, 1).contiguous()             # torch.Size([2, 200, 176, 14])
    cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()             # torch.Size([2, 200, 176, 2])

    if self._use_direction_classifier:
        dir_cls_preds = self.conv_dir_cls(x)
        dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()  # torch.Size([2, 200, 176, 4])

    return box_preds, cls_preds, dir_cls_preds

The input is the BEV feature map produced by the backbone and neck; it is fed into parallel 1×1 conv branches that predict the bboxes, the object class, and (optionally) the heading direction.
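A tiny usage sketch of the head, assuming the repo's SSDRotateHead is importable and that the config sets num_output_filters=256 so the shapes match the comments above; the import path below is an assumption:

import torch

# assumed import path; adjust to wherever SSDRotateHead lives in the repo
from mmdet.models.single_stage_heads.ssd_rotate_head import SSDRotateHead

head = SSDRotateHead(num_class=1, num_output_filters=256, num_anchor_per_loc=2)
x = torch.randn(2, 256, 200, 176)              # BEV feature map from the neck
box_preds, cls_preds, dir_cls_preds = head(x)
print(box_preds.shape)      # torch.Size([2, 200, 176, 14])
print(cls_preds.shape)      # torch.Size([2, 200, 176, 2])
print(dir_cls_preds.shape)  # torch.Size([2, 200, 176, 4])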

2. Loss construction

Let's see how the loss is constructed:

# input
# box_preds   : torch.Size([2, 200, 176, 14])
# cls_preds   : torch.Size([2, 200, 176, 2])
# gt_bboxes   : list:len(gt_bboxes) = B , gt_bboxes[0].shape = torch.Size([num_of_gt_bboxes, 7])
# anchor      : torch.Size([2, 70400, 7])
# anchor_mask : torch.Size([2, 70400])
# cfg         : from car_cfg.py / train_cfg
def loss(self, box_preds, cls_preds, dir_cls_preds, gt_bboxes, gt_labels, anchors, anchors_mask, cfg):

    batch_size = box_preds.shape[0]

    # ADD----------------------------------------------------------------------------------------------
    add_for_test = False
    add_for_pkl  = False
	
    # visualize the gt_bboxes
    if add_for_test == True:
        bbox3d_for_test = gt_bboxes[0].cpu().numpy()
        draw_gt_boxes3d_for_test(center_to_corner_box3d(bbox3d_for_test), draw_text=True, show=True)
	
    # dump the anchors for visualization
    if add_for_pkl == True:
        pkl_data = {}
        pkl_data['anchors'] = anchors
        pkl_data['anchors_mask'] = anchors_mask

        import pickle
        with open("/home/seivl/pkl_data.pkl", 'wb') as fo:
            pickle.dump(pkl_data, fo)
    #-----------------------------------------------------------------------------------------------

    # create_target_torch is the target-assignment function;
    # the variables below are the arguments passed into it.
    # `targets` are the regression targets.
    labels, targets, ious = multi_apply(create_target_torch,
                                        anchors, gt_bboxes,
                                        anchors_mask, gt_labels,
                                        similarity_fn=getattr(iou3d_utils, cfg.assigner.similarity_fn)(),
                                        box_encoding_fn = second_box_encode,
                                        matched_threshold=cfg.assigner.pos_iou_thr,
                                        unmatched_threshold=cfg.assigner.neg_iou_thr,
                                        box_code_size=self._box_code_size)


    labels = torch.stack(labels,)
    targets = torch.stack(targets)

    # generate the cls and reg weights
    cls_weights, reg_weights, cared = self.prepare_loss_weights(labels)
	
    # generate the cls targets
    cls_targets = labels * cared.type_as(labels)

    # build the loss
    # detailed analysis below
    loc_loss, cls_loss = self.create_loss(
        box_preds=box_preds,
        cls_preds=cls_preds,
        cls_targets=cls_targets,
        cls_weights=cls_weights,
        reg_targets=targets,
        reg_weights=reg_weights,
        num_class=self._num_class,
        encode_rad_error_by_sin=self._encode_rad_error_by_sin,
        use_sigmoid_cls=self._use_sigmoid_cls,
        box_code_size=self._box_code_size,
    )

    loc_loss_reduced = loc_loss / batch_size
    loc_loss_reduced *= 2                     # weight of loc_loss

    cls_loss_reduced = cls_loss / batch_size
    cls_loss_reduced *= 1

    loss = loc_loss_reduced + cls_loss_reduced

    if self._use_direction_classifier:
        # generate the ground-truth dir_labels corresponding to dir_cls_preds
        dir_labels = self.get_direction_target(anchors, targets, use_one_hot=False).view(-1)
        dir_logits = dir_cls_preds.view(-1, 2)
        
        # the weights only keep targets with labels > 0 (i.e. the car class)
        weights = (labels > 0).type_as(dir_logits)
        weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)

        # cross-entropy loss for the heading (direction) prediction
        dir_loss = weighted_cross_entropy(dir_logits, dir_labels,
                                          weight=weights.view(-1),
                                          avg_factor=1.)

        dir_loss_reduced = dir_loss / batch_size
        dir_loss_reduced *= .2
        loss += dir_loss_reduced

    return dict(rpn_loc_loss=loc_loss_reduced, rpn_cls_loss=cls_loss_reduced, rpn_dir_loss=dir_loss_reduced)
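get_direction_target is used in the loss above but only declared in the class listing. In SECOND-style heads it is usually computed roughly as below: the encoded yaw residual is added back onto the anchor yaw, and the sign of the resulting angle picks the direction bin. This is a sketch, not code copied from the SA-SSD repo:

import torch
import torch.nn.functional as F


def get_direction_target(anchors, reg_targets, use_one_hot=True):
    batch_size = reg_targets.shape[0]
    anchors = anchors.view(batch_size, -1, 7)
    rot_gt = reg_targets[..., -1] + anchors[..., -1]   # recover the gt yaw angle
    dir_cls_targets = (rot_gt > 0).long()              # direction bin: 0 or 1
    if use_one_hot:
        dir_cls_targets = F.one_hot(dir_cls_targets, num_classes=2).float()
    return dir_cls_targets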

There is an important function here, create_target_torch, which is used to generate the labels; it is analyzed in detail below.

The actual loss construction function:

def create_loss(self,
                box_preds,                        # torch.Size([2, 200, 176, 14])
                cls_preds,                        # torch.Size([2, 200, 176, 2])
                cls_targets,                      # torch.Size([2, 70400])
                cls_weights,                      # torch.Size([2, 70400])
                reg_targets,                      # torch.Size([2, 70400, 7])
                reg_weights,                      # torch.Size([2, 70400])
                num_class,                        # 1
                use_sigmoid_cls=True,             # True
                encode_rad_error_by_sin=True,     # True
                box_code_size=7):                 # 7

    batch_size = int(box_preds.shape[0])                        # B = 2

    box_preds = box_preds.view(batch_size, -1, box_code_size)   # torch.Size([2, 70400, 7])

    if use_sigmoid_cls:
        cls_preds = cls_preds.view(batch_size, -1, num_class)   # torch.Size([2, 70400, 1])
    else:
        cls_preds = cls_preds.view(batch_size, -1, num_class + 1)

    one_hot_targets = one_hot(
        cls_targets, depth=num_class + 1, dtype=box_preds.dtype) # torch.Size([2, 70400, 2])

    if use_sigmoid_cls:
        one_hot_targets = one_hot_targets[..., 1:]               # torch.Size([2, 70400, 1])
    if encode_rad_error_by_sin:
        box_preds, reg_targets = self.add_sin_difference(box_preds, reg_targets)
        # torch.Size([2, 70400, 7])
        # torch.Size([2, 70400, 7])

    loc_losses = weighted_smoothl1(box_preds, reg_targets, beta=1 / 9., \
                                   weight=reg_weights[..., None], avg_factor=1.)
    cls_losses = weighted_sigmoid_focal_loss(cls_preds, one_hot_targets, \
                                             weight=cls_weights[..., None], avg_factor=1.)

    return loc_losses, cls_losses
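add_sin_difference is called in create_loss above but only declared in the listing. The standard SECOND-style version encodes the angle residual through sin(a − b) = sin a · cos b − cos a · sin b, replacing the yaw channel of the predictions and targets so the smooth L1 effectively regresses sin(r1 − r2). A sketch under that assumption, not copied from the repo:

import torch


def add_sin_difference(boxes1, boxes2):
    # replace the last (yaw) channel: preds carry sin(r1)cos(r2), targets cos(r1)sin(r2)
    rad_pred_encoding = torch.sin(boxes1[..., -1:]) * torch.cos(boxes2[..., -1:])
    rad_tg_encoding = torch.cos(boxes1[..., -1:]) * torch.sin(boxes2[..., -1:])
    boxes1 = torch.cat([boxes1[..., :-1], rad_pred_encoding], dim=-1)
    boxes2 = torch.cat([boxes2[..., :-1], rad_tg_encoding], dim=-1)
    return boxes1, boxes2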

3. Label generation

This mainly happens in the create_target_torch function; the annotations and analysis are as follows.

The purpose of this code is mainly to:

  • generate the labels for the anchors
  • generate the bbox regression targets
  • also return the IoU between each anchor and the gt bboxes

# all_anchors          : torch.Size([70400, 7])
# gt_boxes             : torch.Size([num_of_gt_bbox, 7])
# anchor_mask          : torch.Size(70400,)
# gt_classes           : num_of_gt_bbox eg: 14
# similarity_fn        : 
# box_encoding_fn      : 
# matched_threshold    : 0.6
# unmatched_threshold  : 0.45
# positive_fraction    : None
# norm_by_num_examples : False
# box_code_size        : 7

def create_target_torch(all_anchors,
                        gt_boxes,
                        anchor_mask,
                        gt_classes,
                        similarity_fn,
                        box_encoding_fn,
                        matched_threshold=0.6,
                        unmatched_threshold=0.45,
                        positive_fraction=None,
                        rpn_batch_size=300,
                        norm_by_num_examples=False,
                        box_code_size=7):

    # torch.set_printoptions(threshold=np.inf)
    # _unmap maps data defined on the masked anchors back to the full anchor set
    def _unmap(data, count, inds, fill=0):

        # ----------------------------
        # data  : label
        # count : anchor.shape
        # inds  : mask
        # ---------------------------

        """ Unmap a subset of item (data) back to the original set of items (of
        size count) """
        if data.dim() == 1:
            ret = data.new_full((count,), fill)
            ret[inds] = data
        else:
            new_size = (count,) + data.size()[1:]
            ret = data.new_full(new_size, fill)
            ret[inds, :] = data
        return ret

    # value: 70400
    total_anchors = all_anchors.shape[0]

    # go
    if anchor_mask is not None:
        #inds_inside = np.where(anchors_mask)[0]  # prune_anchor_fn(all_anchors)

        # value: 22007
        anchors = all_anchors[anchor_mask, :]

        if not isinstance(matched_threshold, float):
            matched_threshold = matched_threshold[anchor_mask]
        if not isinstance(unmatched_threshold, float):
            unmatched_threshold = unmatched_threshold[anchor_mask]
    else:
        anchors = all_anchors
        #inds_inside = None

    # 22007
    num_inside = len(torch.nonzero(anchor_mask)) if anchor_mask is not None else total_anchors

    if gt_classes is None:
        gt_classes = torch.ones([gt_boxes.shape[0]], dtype=torch.int64, device=gt_boxes.device)

    # Compute anchor labels:
    # label=1 is positive, 0 is negative, -1 is don't care (ignore)
    # shape = [22007,] value = -1
    labels = torch.empty((num_inside,), dtype=torch.int64, device=gt_boxes.device).fill_(-1)
    gt_ids = torch.empty((num_inside,), dtype=torch.int64, device=gt_boxes.device).fill_(-1)

    if len(gt_boxes) > 0 and anchors.shape[0] > 0:
        # Compute overlaps between the anchors and the gt boxes overlaps
        # compute the IoU between the anchors and the gt bboxes
        anchor_by_gt_overlap = similarity_fn(anchors, gt_boxes)           # torch.Size([22007, 14])

        # add for test
        # for_test_anchor_by_gt_overlap = similarity_fn(anchors[9300:9303,:], gt_boxes)


        # Map from anchor to gt box that has highest overlap
        anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(dim=1)          
        # shape: 22007
        # for each anchor, the index of the gt bbox with the highest IoU
        # argmax over dim=1 (the gt-box dimension), one index per anchor

        # For each anchor, amount of overlap with most overlapping gt box
        anchor_to_gt_max = anchor_by_gt_overlap[torch.arange(num_inside), 
                                                anchor_to_gt_argmax]
        # for each anchor, the highest IoU with any gt bbox

        # Map from gt box to an anchor that has highest overlap
        gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(dim=0)
        # for each gt bbox, the index of the anchor with the highest IoU
        # argmax over dim=0 (the anchor dimension)
        # shape: 14


        # For each gt box, amount of overlap with most overlapping anchor
        gt_to_anchor_max = anchor_by_gt_overlap[
            gt_to_anchor_argmax,
            torch.arange(anchor_by_gt_overlap.shape[1])]
        # for each gt bbox, the highest IoU with any anchor

        # must remove gt which doesn't match any anchor.
        empty_gt_mask = gt_to_anchor_max == 0
        gt_to_anchor_max[empty_gt_mask] = -1

        # Find all anchors that share the max overlap amount
        # (this includes many ties)
        anchors_with_max_overlap = torch.nonzero(
            anchor_by_gt_overlap == gt_to_anchor_max)[:,0]
        # find the anchors that share the max IoU with some gt bbox
        # tensor([ 6287,  7063,  9302,  9530,  9571, 10225, 11481, 13080, 14509, 15080,
        #         15082, 15293, 18273, 18740, 21316], device='cuda:0')

        # for test
        # for_test_anchors_with_max_overlap = torch.nonzero(
        #    for_test_anchor_by_gt_overlap == gt_to_anchor_max)[:, 0]

        # Fg label: for each gt use anchors with highest overlap
        # (including ties)
        gt_inds_force = anchor_to_gt_argmax[anchors_with_max_overlap]
        # 15
        # tensor([ 6, 10, 12, 11, 13,  7,  9,  5,  3,  2,  2,  8,  1,  0,  4],
        #        device='cuda:0')
        # which gt bbox each of those anchors corresponds to

        labels[anchors_with_max_overlap] = gt_classes[gt_inds_force] # assign labels: the max-IoU anchors are set to 1
        gt_ids[anchors_with_max_overlap] = gt_inds_force             # record the index of the matched gt box

        # Fg label: above threshold IOU
        pos_inds = anchor_to_gt_max >= matched_threshold             # anchors whose max IoU is above the threshold
        gt_inds = anchor_to_gt_argmax[pos_inds]                      # indices of the gt bboxes these anchors match
        # here 67 anchors have an IoU with a gt bbox above the threshold
        # tensor([ 6,  6,  6,  6,  6,  6, 10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 11, 11,
        #         12, 11, 11, 13, 13, 11, 13, 13,  7,  7,  7,  7,  7,  9,  9,  9,  9,  5,
        #          5,  5,  5,  5,  3,  3,  3,  3,  2,  2,  2,  2,  8,  8,  8,  8,  1,  1,
        #          1,  1,  1,  0,  0,  0,  0,  0,  4,  4,  4,  4,  4], device='cuda:0')
        labels[pos_inds] = gt_classes[gt_inds]                        # set their labels to 1
        gt_ids[pos_inds] = gt_inds                                    # record the matched gt indices

        # bg_inds = np.where(anchor_to_gt_max < unmatched_threshold)[0]
        bg_inds = torch.nonzero(anchor_to_gt_max < unmatched_threshold)[:, 0]
        # indices of anchors below the unmatched threshold (background)
    else:
        bg_inds = torch.arange(num_inside)

    #fg_inds = np.where(labels > 0)[0]
    fg_inds = torch.nonzero(labels > 0)[:, 0]
    # indices of all foreground anchors
    # tensor([ 6283,  6285,  6287,  6289,  6291,  6498,  6852,  6854,  7061,  7063,
    #          7268,  7270,  8883,  9094,  9300,  9302,  9324,  9326,  9508,  9530,
    #          9532,  9571,  9573,  9736,  9777,  9779,  9827, 10028, 10225, 10227,
    #         10424, 11481, 11483, 11757, 11759, 13078, 13080, 13082, 13084, 13366,
    #         14267, 14509, 14511, 14750, 15078, 15080, 15082, 15084, 15291, 15293,
    #         15295, 15553, 18009, 18269, 18271, 18273, 18275, 18493, 18495, 18738,
    #         18740, 18742, 21312, 21314, 21316, 21318, 21389], device='cuda:0')

    # subsample positive labels if we have too many
    if positive_fraction is not None:
        num_fg = int(positive_fraction * rpn_batch_size)
        if len(fg_inds) > num_fg:
            disable_inds = npr.choice(
                fg_inds, size=(len(fg_inds) - num_fg), replace=False)
            labels[disable_inds] = -1
            #fg_inds = np.where(labels > 0)[0]
            fg_inds = torch.nonzero(labels > 0)[:, 0]

        # subsample negative labels if we have too many
        # (samples with replacement, but since the set of bg inds is large most
        # samples will not have repeats)
        num_bg = rpn_batch_size - np.sum(labels > 0)
        # print(num_fg, num_bg, len(bg_inds) )
        if len(bg_inds) > num_bg:
            enable_inds = bg_inds[npr.randint(len(bg_inds), size=num_bg)]
            labels[enable_inds] = 0
    else:
        if len(gt_boxes) == 0 or anchors.shape[0] == 0:
            labels[:] = 0
        else:
            labels[bg_inds] = 0   # background anchors get label 0
            # re-enable anchors_with_max_overlap
            labels[anchors_with_max_overlap] = gt_classes[gt_inds_force]

    # generate the regression targets
    bbox_targets = torch.zeros(
        (num_inside, box_code_size), dtype=all_anchors.dtype, device=gt_boxes.device) # torch.Size([22007, 7])
	
    # encode the gt boxes against the foreground anchors
    if len(gt_boxes) > 0 and anchors.shape[0] > 0:
        bbox_targets[fg_inds, :] = box_encoding_fn(
            gt_boxes[anchor_to_gt_argmax[fg_inds], :], anchors[fg_inds, :])
    # bbox_targets[fg_inds, :].shape : torch.Size([67, 7])

    bbox_outside_weights = torch.zeros((num_inside,), dtype=all_anchors.dtype, device=gt_boxes.device)

    # uniform weighting of examples (given non-uniform sampling)
    if norm_by_num_examples:
        num_examples = torch.sum(labels >= 0)  # neg + pos
        num_examples = np.maximum(1.0, num_examples)
        bbox_outside_weights[labels > 0] = 1.0 / num_examples
    else:
        bbox_outside_weights[labels > 0] = 1.0

    # Map up to original set of anchors
    if anchor_mask is not None:
        labels = _unmap(labels, total_anchors, anchor_mask, fill=-1)
        bbox_targets = _unmap(bbox_targets, total_anchors, anchor_mask, fill=0)

    return (labels, bbox_targets, anchor_to_gt_max)
    # labels.shape       : torch.Size([70400,])
    # bbox_targets.shape : torch.Size([70400, 7])
    # anchor_to_gt_max   : 22007 

    # About the labels:
    # foreground = 1
    # background = 0
    # ignore     = -1
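The box_encoding_fn used above is second_box_encode; for reference, the SECOND-style encoding of a gt box against its anchor looks roughly like this: center offsets normalized by the anchor's BEV diagonal, sizes as log ratios, yaw as a plain difference. A simplified sketch; the repo's version supports more options:

import torch


def second_box_encode(gt_boxes, anchors):
    """gt_boxes, anchors: (N, 7) as [x, y, z, w, l, h, r]."""
    xa, ya, za, wa, la, ha, ra = torch.split(anchors, 1, dim=-1)
    xg, yg, zg, wg, lg, hg, rg = torch.split(gt_boxes, 1, dim=-1)

    diagonal = torch.sqrt(la ** 2 + wa ** 2)   # anchor diagonal in the BEV plane
    xt = (xg - xa) / diagonal
    yt = (yg - ya) / diagonal
    zt = (zg - za) / ha
    wt = torch.log(wg / wa)
    lt = torch.log(lg / la)
    ht = torch.log(hg / ha)
    rt = rg - ra
    return torch.cat([xt, yt, zt, wt, lt, ht, rt], dim=-1)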

OK, to be continued.
