mmdetection框架解读——anchor生成机制

anchor是一组先验框,在faster rcnn中被提出,再次记录mmdetection anchor生成思路。

mmdet中的AnchorGenerator类负责生成anchor,该类构造函数接收三个参数:base size,ratios,scales

base size:anchor大小
ratios:anchor 高宽比
scales:anchor缩放比例

每一处生成ratios*scales个anchor

在构造函数调用gen_base_anchors方法生成基础anchor,即特征图最左上角的anchors。特征图上其余地方的anchor只要在基础anchor的基础上加上平移参数就行了。

基础anchor生成过程:

1、每一处有不同宽高比的anchor,依据ratios对标准anchor宽和高缩放,使得高宽比为ratios并且面积不变。可以做如下转变
mmdetection框架解读——anchor生成机制_第1张图片
算出了所有基础anchor的宽高,再依据base size的anchor计算出的中心坐标即可算出基础anchor的左上角与右下角的坐标。

在mmdet中实现代码如下

def gen_base_anchors(self):
    w = self.base_size
    h = self.base_size
    #  base anchor 中心坐标
    if self.ctr is None:
        x_ctr = 0.5 * (w - 1)
        y_ctr = 0.5 * (h - 1)
    else:
        x_ctr, y_ctr = self.ctr
# 计算宽高相对于base anchor的缩放比例,使得缩放后的anchor面积不变,但是高宽比为ratio
h_ratios = torch.sqrt(self.ratios)
w_ratios = 1 / h_ratios
# 广播操作
if self.scale_major:
    ws = (w * w_ratios[:, None] * self.scales[None, :]).view(-1)
    hs = (h * h_ratios[:, None] * self.scales[None, :]).view(-1)
else:
    ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1)
    hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1)


# yapf: disable
base_anchors = torch.stack(
    [
        x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
        x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
    ],
    dim=-1).round()
# yapf: enable


return base_anchors

特征图所有anchor生成过程:

在生成base anchors之后只需要计算出整张特征图上每个位置处相比于base anchors 的偏移即可得出整个特征图的anchor坐标

def grid_anchors(self, featmap_size, stride=16, device=‘cuda’):
base_anchors = self.base_anchors.to(device)

feat_h, feat_w = featmap_size
# shift_x为横轴偏移,stride表示每一个特征图上的anchor相对于原图跳跃步伐,可参考下面案例理解stride
shift_x = torch.arange(0, feat_w, device=device) * stride
shift_y = torch.arange(0, feat_h, device=device) * stride

# _meshgrid该函数的作用就是计算出整个特征图上所有位置的x,y方向偏移
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
# 拼接得出特征图所有anchor两对角四个坐标的偏移
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
shifts = shifts.type_as(base_anchors)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)

# base_anchors是一个(A,4)的tensor,shift是一个(k,4)的tensor,利用广播机制得出每个位置的anchor
# 得出的anchor坐标分别为对角的坐标,即(xmin,ymin,xmax,ymax)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.view(-1, 4)
# first A rows correspond to A anchors of (0, 0) in feature map,
# then (0, 1), (0, 2), ...
return all_anchors

案例
base size = 9,ratios=1,scales=1
feature map 2×2,默认anchor stride=16

“”"
Examples:
>>> from mmdet.core import AnchorGenerator
>>> self = AnchorGenerator(9, [1.], [1.])
>>> all_anchors = self.grid_anchors((2, 2), device=‘cpu’)
>>> print(all_anchors)
tensor([[ 0., 0., 8., 8.],
[16., 0., 24., 8.],
[ 0., 16., 8., 24.],
[16., 16., 24., 24.]])
“”"

参考: https://mingming97.github.io/2019/03/26/anchor%20in%20object%20detection/

你可能感兴趣的:(目标检测)