[MMDet Note] Understanding and Reading the AnchorGenerator Code in MMDetection

Contents

  • Preface
  • I. Overview
  • II. Code walkthrough
    • 1. AnchorGenerator class
    • 2. @property attributes
    • 3. gen_base_anchors method
    • 4. gen_single_level_base_anchors method
    • 5. grid_priors method
    • 6. single_level_grid_priors method
    • 7. valid_flags and single_level_valid_flags methods
  • Summary


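Preface

A walkthrough of the key code of the AnchorGenerator class in mmdetection/mmdet/core/anchor/anchor_generator.py.


I. Overview

The main purpose of the AnchorGenerator class is to generate the anchor boxes that anchor-based detectors need.

The gen_base_anchors method generates the 9 base anchor boxes (3 scales × 3 aspect ratios) for a single anchor point. The grid_priors method then broadcasts these 9 base anchor boxes over the original image, returning a list with one entry per feature level that holds the positions (top-left and bottom-right corner coordinates) of all anchor boxes in original-image coordinates.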
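
To get a feel for how many anchor boxes this produces, here is a rough sketch. The 256×256 input size and the rule "feature-map size = ceil(input size / stride)" are assumptions for illustration only; the real feature-map sizes depend on the backbone and neck.

```python
import math

strides = [8, 16, 32, 64, 128]   # example strides used throughout this note
num_base_anchors = 3 * 3         # 3 scales x 3 aspect ratios per location
img_h, img_w = 256, 256          # hypothetical input size

# Assumed feature-map size per level: ceil(input / stride)
featmap_sizes = [(math.ceil(img_h / s), math.ceil(img_w / s)) for s in strides]

# Each feature-map cell carries num_base_anchors anchor boxes
per_level = [h * w * num_base_anchors for h, w in featmap_sizes]
total = sum(per_level)

print(featmap_sizes)  # [(32, 32), (16, 16), (8, 8), (4, 4), (2, 2)]
print(per_level)      # [9216, 2304, 576, 144, 36]
print(total)          # 12276
```

The finest level dominates the count, since the number of cells shrinks by a factor of 4 each time the stride doubles.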

II. Code walkthrough

1. AnchorGenerator class


@PRIOR_GENERATORS.register_module()
class AnchorGenerator:
    def __init__(self,
                 strides,                       # e.g. [8, 16, 32, 64, 128]
                 ratios,                        # e.g. three anchor aspect ratios (height / width): [0.5, 1.0, 2.0]
                 scales=None,
                 base_sizes=None,
                 scale_major=True,
                 octave_base_scale=None,        # e.g. 4
                 scales_per_octave=None,        # e.g. 3
                 centers=None,
                 center_offset=0.):
        # check center and center_offset
        if center_offset != 0:
            assert centers is None, 'center cannot be set when center_offset' \
                                    f'!=0, {centers} is given.'
        if not (0 <= center_offset <= 1):
            raise ValueError('center_offset should be in range [0, 1], '
                             f'{center_offset} is given.')
        if centers is not None:
            assert len(centers) == len(strides), \
                'The number of strides should be the same as centers, got ' \
                f'{strides} and {centers}'
 
        # self.strides = [(8,8),(16,16),(32,32),(64,64),(128,128)]
        self.strides = [_pair(stride) for stride in strides]
        # base_sizes = [8, 16, 32, 64, 128]
        self.base_sizes = [min(stride) for stride in self.strides
                           ] if base_sizes is None else base_sizes
        assert len(self.base_sizes) == len(self.strides), \
            'The number of strides should be the same as base sizes, got ' \
            f'{self.strides} and {self.base_sizes}'
 
        # octave_base_scale and scales_per_octave cannot be set together with scales
        assert ((octave_base_scale is not None
                 and scales_per_octave is not None) ^ (scales is not None)), \
            'scales and octave_base_scale with scales_per_octave cannot' \
            ' be set at the same time'
        if scales is not None:
            self.scales = torch.Tensor(scales)
            
        # Compute the scales automatically from octave_base_scale and scales_per_octave
        # self.scales = octave_base_scale * [2^0, 2^(1/3), 2^(2/3)] ≈ [4, 5.04, 6.35]
        elif octave_base_scale is not None and scales_per_octave is not None:
            octave_scales = np.array(
                [2**(i / scales_per_octave) for i in range(scales_per_octave)])
            scales = octave_scales * octave_base_scale
            self.scales = torch.Tensor(scales)
        else:
            raise ValueError('Either scales or octave_base_scale with '
                             'scales_per_octave should be set')
        
        # Final values
        self.octave_base_scale = octave_base_scale        # 4
        self.scales_per_octave = scales_per_octave        # 3
        self.ratios = torch.Tensor(ratios)                # [0.5, 1, 2]
        self.scale_major = scale_major                    # True
        self.centers = centers                            # None
        self.center_offset = center_offset                # 0
        self.base_anchors = self.gen_base_anchors()
        # self.scales ≈ [4, 5.04, 6.35]
        # self.strides = [(8,8),(16,16),(32,32),(64,64),(128,128)]
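
The octave-based scale computation above can be checked in isolation. This is a minimal sketch of the same formula in plain Python; the helper name octave_scales is made up for illustration and is not part of MMDetection.

```python
def octave_scales(octave_base_scale, scales_per_octave):
    """Return octave_base_scale * 2**(i / scales_per_octave) for each i.

    Hypothetical helper mirroring the formula in __init__ above.
    """
    return [octave_base_scale * 2 ** (i / scales_per_octave)
            for i in range(scales_per_octave)]


scales = octave_scales(4, 3)
print([round(s, 2) for s in scales])  # [4.0, 5.04, 6.35]
```

The scales are spaced geometrically within one octave, so they are not the integers 4, 5, 6; those are only a convenient rounding.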

2. @property attributes

    @property
    def num_base_anchors(self):
        """list[int]: total number of base anchors in a feature grid"""
        return self.num_base_priors

    @property
    def num_base_priors(self):		# number of base_anchors at each level
        """list[int]: The number of priors (anchors) at a point
        on the feature grid"""
        return [base_anchors.size(0) for base_anchors in self.base_anchors]

    @property
    def num_levels(self):			# number of levels
        """int: number of feature levels that the generator will be applied"""
        return len(self.strides)

3. gen_base_anchors method


def gen_base_anchors(self):
    """
    Generate base_anchors, i.e. the 9 (in this example) anchor boxes of different scales and aspect ratios at a single anchor point.
 
    Returns:
    list(torch.Tensor): base anchor boxes for each feature-map level; len(list) == len(self.strides)
    """
 
    multi_level_base_anchors = []               # stores the base_anchors of each feature level
    for i, base_size in enumerate(self.base_sizes):  # generate base_anchors for each feature level
        center = None
        if self.centers is not None:
            center = self.centers[i]
        multi_level_base_anchors.append(
            # call gen_single_level_base_anchors to produce the base_anchors of the current level
            self.gen_single_level_base_anchors(         
                    base_size,                      # 8 / 16 / 32 / 64 / 128 (loop variable)
                    scales=self.scales,             # ≈ [4, 5.04, 6.35]
                    ratios=self.ratios,             # [0.5,1,2]
                    center=center))                 # None
    return multi_level_base_anchors
    # multi_level_base_anchors = [[stride1_base_anchors], [stride2_base_anchors], ...]
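
Because each level passes its own base_size, the same scales produce larger anchors on coarser levels. The sketch below tabulates the side length of the square (ratio 1.0) anchors per level, assuming the example configuration from the comments above; side_table is an illustrative name only.

```python
# Assumed example configuration, taken from the comments in the code above
base_sizes = [8, 16, 32, 64, 128]
octave_base_scale, scales_per_octave = 4, 3
scales = [octave_base_scale * 2 ** (i / scales_per_octave)
          for i in range(scales_per_octave)]

# For ratio 1.0 the anchor is a square with side base_size * scale
side_table = {
    base_size: [round(base_size * s, 1) for s in scales]
    for base_size in base_sizes
}

for base_size, sides in side_table.items():
    print(base_size, sides)
# first line printed: 8 [32.0, 40.3, 50.8]
```

With these values the smallest anchor is 32 pixels wide on the stride-8 level and 512 pixels wide on the stride-128 level.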

4. gen_single_level_base_anchors method


def gen_single_level_base_anchors(self,
                                  base_size,        # 8 (taking 8 as the example)
                                  scales,           # ≈ [4, 5.04, 6.35]
                                  ratios,           # [0.5,1,2]
                                  center=None):
    """Generate base anchors of a single level.
    """
 
    w = base_size   # w = 8
    h = base_size   # h = 8
    if center is None:
        x_center = self.center_offset * w   # 0
        y_center = self.center_offset * h   # 0
    else:
        x_center, y_center = center
 
    # h_ratios:w_ratios = [0.5:1, 1:1, 2:1]
    h_ratios = torch.sqrt(ratios)        
    w_ratios = 1 / h_ratios
        
    if self.scale_major:   # self.scale_major = True
        # the formulas below yield 9 ws and 9 hs (ratio index varies slowest, scale index fastest)
        ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
    else:
        ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
        hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
 
    # [center_x,center_y,w,h] --> [xmin, ymin, xmax,ymax]
    base_anchors = [
            x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws,
            y_center + 0.5 * hs
        ]
 
    # use torch.stack to turn the 4 coordinate vectors into an (N, 4) tensor
    base_anchors = torch.stack(base_anchors, dim=-1)
 
    return base_anchors
 
    """Example with 2 anchors
    base_size, scales, ratios = 8, [4, 6], [1]
    w, h, h_ratios, w_ratios = 8, 8, [1], [1]
    ws = 8 * 1 * [4, 6] = [32, 48]
    hs = 8 * 1 * [4, 6] = [32, 48]
    base_anchors = [[-16, -24],     # xmin
                    [-16, -24],     # ymin
                    [16, 24],       # xmax
                    [16, 24]]       # ymax
    # after torch.stack(..., dim=-1)
    base_anchors = [[-16., -16.,  16.,  16.],
                    [-24., -24.,  24.,  24.]]
    """

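The width/height arithmetic above can be reproduced without torch. The following is a plain-Python sketch written for this note, not the library code; the function name single_level_base_anchors and its list-based return type are choices made here for illustration.

```python
import math


def single_level_base_anchors(base_size, scales, ratios,
                              center_offset=0.0, scale_major=True):
    """Plain-Python mirror of the arithmetic in gen_single_level_base_anchors.

    Returns a list of [xmin, ymin, xmax, ymax] boxes.
    """
    x_center = center_offset * base_size
    y_center = center_offset * base_size
    h_ratios = [math.sqrt(r) for r in ratios]   # ratio = h / w
    w_ratios = [1.0 / hr for hr in h_ratios]

    if scale_major:
        # same order as (w_ratios[:, None] * scales[None, :]).view(-1):
        # ratio index varies slowest, scale index fastest
        combos = [(wr, hr, s)
                  for wr, hr in zip(w_ratios, h_ratios) for s in scales]
    else:
        # same order as (scales[:, None] * w_ratios[None, :]).view(-1)
        combos = [(wr, hr, s)
                  for s in scales for wr, hr in zip(w_ratios, h_ratios)]

    anchors = []
    for wr, hr, s in combos:
        w = base_size * wr * s
        h = base_size * hr * s
        anchors.append([x_center - 0.5 * w, y_center - 0.5 * h,
                        x_center + 0.5 * w, y_center + 0.5 * h])
    return anchors


two = single_level_base_anchors(8, scales=[4, 6], ratios=[1.0])
print(two)  # [[-16.0, -16.0, 16.0, 16.0], [-24.0, -24.0, 24.0, 24.0]]

nine = single_level_base_anchors(8, scales=[4, 5, 6], ratios=[0.5, 1.0, 2.0])
print(len(nine))  # 9
```

Taking the square root of the ratio keeps the anchor area fixed at (base_size * scale)² for every ratio, while the height-to-width quotient equals the ratio.
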
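5. grid_priors method

This method is structured much like gen_base_anchors, with two differences: 1. it is called later, on demand, whereas gen_base_anchors is called automatically when the AnchorGenerator object is constructed. 2. The returned list holds different content: for each feature map, this method returns the positions of all anchor boxes relative to the original image.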

def grid_priors(self, featmap_sizes, dtype=torch.float32, device='cuda'):
    """Generate grid anchors in multiple feature levels.
    """
 
    assert self.num_levels == len(featmap_sizes)       
    multi_level_anchors = []
    for i in range(self.num_levels):
        anchors = self.single_level_grid_priors(
            featmap_sizes[i], level_idx=i, dtype=dtype, device=device)
        multi_level_anchors.append(anchors)
    return multi_level_anchors
    # multi_level_anchors = [[level1_anchorboxs], [level2_anchorboxs], ...]

6. single_level_grid_priors method


    def single_level_grid_priors(self,
                                 featmap_size,
                                 level_idx,
                                 dtype=torch.float32,
                                 device='cuda'):
        """Generate grid anchors of a single level.
        In other words, generate all anchor boxes of one feature-map level, with coordinates relative to the original image size.
        """
        
        # get the 9 (in this example) base_anchors of different scales and aspect ratios for the current level_idx
        base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
        feat_h, feat_w = featmap_size
        stride_w, stride_h = self.strides[level_idx]
 
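        # most of what follows only reshapes arrays; if that is tedious, just note the final output shape, (K*A, 4)
 
        # walk over every position on the feature map and multiply by the stride to get x, y coordinates in the original image
        # e.g. in 2D, on a stride = 8 feature map, [0,0] maps to [0,0] in the original image and [1,1] maps to [8,8]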
        shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
        shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h
 
        shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
        # shifts are the positions in the original image at which the base_anchors are centered (when center_offset = 0)
        shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
        # first feat_w elements correspond to the first row of shifts
        # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
        # shifted anchors (K, A, 4), reshape to (K*A, 4)
        
        # add the base_anchors coordinates to the positions in the original image to get the anchor positions in the original image
        all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
        all_anchors = all_anchors.view(-1, 4)
 
        return all_anchors
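
The broadcast-and-reshape step is easier to follow as explicit loops. Below is a plain-Python sketch written for this note (the function name single_level_grid is hypothetical). It assumes the row-major ordering that the English comment in the code describes, where the first feat_w entries belong to the first row, so x varies fastest and the A base anchors of one location stay adjacent.

```python
def single_level_grid(base_anchors, featmap_size, stride):
    """Shift every base anchor to every feature-map location.

    base_anchors: list of [xmin, ymin, xmax, ymax]
    featmap_size: (feat_h, feat_w)
    stride:       (stride_w, stride_h)
    Returns K*A boxes, with K = feat_h * feat_w and A = len(base_anchors).
    """
    feat_h, feat_w = featmap_size
    stride_w, stride_h = stride
    all_anchors = []
    for y in range(feat_h):              # rows: y varies slowest
        for x in range(feat_w):          # columns: x varies fastest
            shift_x, shift_y = x * stride_w, y * stride_h
            for x1, y1, x2, y2 in base_anchors:   # A anchors per location
                all_anchors.append([x1 + shift_x, y1 + shift_y,
                                    x2 + shift_x, y2 + shift_y])
    return all_anchors


base = [[-16.0, -16.0, 16.0, 16.0], [-24.0, -24.0, 24.0, 24.0]]
anchors = single_level_grid(base, featmap_size=(2, 3), stride=(8, 8))
print(len(anchors))  # 12  (2 * 3 locations, 2 anchors each)
print(anchors[2])    # [-8.0, -16.0, 24.0, 16.0]  first anchor at x=1, y=0
print(anchors[6])    # [-16.0, -8.0, 16.0, 24.0]  first anchor at x=0, y=1
```

Anchors near the border get negative or out-of-image coordinates, since the boxes are simply shifted and never clipped here.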

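7. valid_flags and single_level_valid_flags methods

The grid_priors method also produces anchor boxes located at or beyond the image boundary, so such anchor boxes need to be told apart from the rest, which is what these methods are for. As the code shows, validity is decided per feature-map cell from pad_shape, and all anchors in the same cell share the same flag.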

    def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
        """For each level's feature map, decide whether each anchor box sits at a valid position: True (1) if so, otherwise False (0).
        """
        assert self.num_levels == len(featmap_sizes)
        multi_level_flags = []
        for i in range(self.num_levels):        # iterate over the feature-map levels
            anchor_stride = self.strides[i]
            feat_h, feat_w = featmap_sizes[i]
            h, w = pad_shape[:2]				# the valid h and w
            # compute the valid height and width on the feature map
            valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h)  
            valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w)
            # flags is a Tensor[bool] of shape H * W * num_base_anchors
            flags = self.single_level_valid_flags((feat_h, feat_w),
                                                  (valid_feat_h, valid_feat_w),
                                                  self.num_base_anchors[i],  # e.g. 9
                                                  device=device)
            multi_level_flags.append(flags)
        # a list of num_levels Tensor[bool], each of shape H * W * num_base_anchors
        return multi_level_flags

    def single_level_valid_flags(self,
                                 featmap_size,
                                 valid_size,
                                 num_base_anchors,
                                 device='cuda'):
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)  # initialized to False
        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        valid = valid_xx & valid_yy
        valid = valid[:, None].expand(valid.size(0),
                                      num_base_anchors).contiguous().view(-1)
        # returns a Tensor[bool] of shape H * W * num_base_anchors recording whether each anchor box in each cell is valid (1 if valid)
        return valid
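
The flag logic can be sketched with plain lists. Both helpers below are illustrative rewrites made for this note (valid_feat_size is a hypothetical name), and the 20×28 pad_shape on a 4×4 feature map is an invented example.

```python
import math


def valid_feat_size(pad_shape, stride, featmap_size):
    """Number of feature-map rows/columns covered by the padded image."""
    h, w = pad_shape[:2]
    stride_w, stride_h = stride
    feat_h, feat_w = featmap_size
    return (min(math.ceil(h / stride_h), feat_h),
            min(math.ceil(w / stride_w), feat_w))


def single_level_valid_flags(featmap_size, valid_size, num_base_anchors):
    """One bool per anchor, in the same row-major order as the anchors."""
    feat_h, feat_w = featmap_size
    valid_h, valid_w = valid_size
    assert valid_h <= feat_h and valid_w <= feat_w
    flags = []
    for y in range(feat_h):
        for x in range(feat_w):
            cell_valid = x < valid_w and y < valid_h
            # every anchor in the cell shares the cell's flag
            flags.extend([cell_valid] * num_base_anchors)
    return flags


valid_size = valid_feat_size(pad_shape=(20, 28), stride=(8, 8),
                             featmap_size=(4, 4))
flags = single_level_valid_flags((4, 4), valid_size, num_base_anchors=2)
print(valid_size)   # (3, 4)
print(len(flags))   # 32
print(sum(flags))   # 24  (3 valid rows * 4 columns * 2 anchors)
```

In this example only the last feature-map row falls outside the padded image, so its 8 anchors are flagged False.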

Summary

This article reflects only my personal understanding; corrections and criticism are welcome.
Reference: MMDet逐行解读之AnchorGenerator (武乐乐~'s blog, CSDN)
