mmdetection/mmdet/core/anchor/anchor_generator.py中AnchorGenerator类的关键代码解读。
AnchorGenerator类的主要目的是生成anchor-based检测器(Detector)所需要的anchor_box。
通过【gen_base_anchors】方法生成单个anchor点的9种(3种尺寸、3种宽高比)基anchor_box,再调用【grid_priors】方法将这9种基anchor_box在原图的尺寸上进行广播,得到一个list,其中包括所有原图尺寸上的anchor_box位置信息(左上角坐标与右下角坐标)。
@PRIOR_GENERATORS.register_module()
class AnchorGenerator:
def __init__(self,
             strides,
             ratios,
             scales=None,
             base_sizes=None,
             scale_major=True,
             octave_base_scale=None,
             scales_per_octave=None,
             centers=None,
             center_offset=0.):
    """Validate the configuration and precompute the base anchors.

    Args:
        strides: One stride per feature level, e.g. [8, 16, 32, 64, 128].
            Every entry is normalised with ``_pair``.
        ratios: Anchor aspect ratios, e.g. [0.5, 1.0, 2.0].
        scales: Explicit anchor scales. Must not be combined with the
            ``octave_base_scale`` / ``scales_per_octave`` pair.
        base_sizes: Base anchor size per level. When None, the smaller
            component of each stride pair is used.
        scale_major: Ordering flag consumed by
            ``gen_single_level_base_anchors``.
        octave_base_scale: Base scale of the octave, e.g. 4.
        scales_per_octave: Number of scales per octave, e.g. 3.
        centers: Optional explicit anchor centre per level. Only allowed
            when ``center_offset`` is 0.
        center_offset: Centre offset as a fraction of the base size,
            must lie in [0, 1].

    Raises:
        AssertionError: On inconsistent argument combinations or lengths.
        ValueError: If ``center_offset`` is outside [0, 1].
    """
    # Explicit centers and a non-zero offset are mutually exclusive.
    if center_offset != 0:
        assert centers is None, \
            f'center cannot be set when center_offset!=0, {centers} is given.'
    if not (0 <= center_offset <= 1):
        raise ValueError(
            f'center_offset should be in range [0, 1], {center_offset} is given.')
    if centers is not None:
        assert len(centers) == len(strides), \
            f'The number of strides should be the same as centers, got {strides} and {centers}'

    # e.g. [8, 16] -> [(8, 8), (16, 16)]
    self.strides = [_pair(stride) for stride in strides]

    # Default base size of a level is the smaller stride component.
    if base_sizes is None:
        self.base_sizes = [min(stride_pair) for stride_pair in self.strides]
    else:
        self.base_sizes = base_sizes
    assert len(self.base_sizes) == len(self.strides), \
        f'The number of strides should be the same as base sizes, got {self.strides} and {self.base_sizes}'

    # Exactly one of the two ways of specifying scales must be used.
    has_octave = (octave_base_scale is not None
                  and scales_per_octave is not None)
    has_scales = scales is not None
    assert has_octave != has_scales, \
        'scales and octave_base_scale with scales_per_octave cannot be set at the same time'

    if has_scales:
        self.scales = torch.Tensor(scales)
    elif has_octave:
        # scales = octave_base_scale * 2 ** (i / scales_per_octave);
        # with base 4 and 3 per octave this is roughly [4, 5.04, 6.35].
        octave_steps = np.array(
            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
        self.scales = torch.Tensor(octave_steps * octave_base_scale)
    else:
        # Only reachable when the assertion above is disabled.
        raise ValueError(
            'Either scales or octave_base_scale with scales_per_octave should be set')

    # Remaining configuration is stored as given.
    self.octave_base_scale = octave_base_scale
    self.scales_per_octave = scales_per_octave
    self.ratios = torch.Tensor(ratios)
    self.scale_major = scale_major
    self.centers = centers
    self.center_offset = center_offset
    # Must come last: it reads the attributes assigned above.
    self.base_anchors = self.gen_base_anchors()
# self.scales = [4,5,6]
# self.strides = [(8,8),(16,16),(32,32),(64,64),(128,128)]
@property
def num_base_anchors(self):
    """list[int]: Number of base anchors per feature level.

    Alias that simply forwards to ``num_base_priors``.
    """
    per_level_counts = self.num_base_priors
    return per_level_counts
@property
def num_base_priors(self):
    """list[int]: Number of priors (anchors) at one feature-grid point,
    one entry per level.

    Each count is the size of dimension 0 of that level's base-anchor
    tensor in ``self.base_anchors``.
    """
    counts = []
    for level_anchors in self.base_anchors:
        counts.append(level_anchors.size(0))
    return counts
@property
def num_levels(self):
    """int: Number of feature levels, i.e. the length of ``self.strides``."""
    level_strides = self.strides
    return len(level_strides)
def gen_base_anchors(self):
    """Generate the base anchors of every feature level.

    For each entry of ``self.base_sizes`` this delegates to
    ``gen_single_level_base_anchors`` with the shared ``self.scales`` and
    ``self.ratios``. The explicit centre of that level is passed when
    ``self.centers`` is set, otherwise ``None``.

    Returns:
        list: One result of ``gen_single_level_base_anchors`` per level,
        in the order of ``self.base_sizes``.
    """
    return [
        self.gen_single_level_base_anchors(
            level_size,
            scales=self.scales,
            ratios=self.ratios,
            center=None if self.centers is None else self.centers[level])
        for level, level_size in enumerate(self.base_sizes)
    ]
# multi_level_base_anchors = [[stride1_base_anchors], [stride2_base_anchors], ...]
def gen_single_level_base_anchors(self,
                                  base_size,
                                  scales,
                                  ratios,
                                  center=None):
    """Generate the base anchors of a single feature level.

    Args:
        base_size: Side length used as both base width and base height.
        scales (torch.Tensor): 1-D tensor of anchor scales.
        ratios (torch.Tensor): 1-D tensor of height/width ratios; the
            height factor is ``sqrt(ratio)`` and the width factor is its
            reciprocal.
        center: Optional ``(x, y)`` centre. When None the centre is
            ``self.center_offset * base_size`` on both axes.

    Returns:
        torch.Tensor: Boxes as ``[x_min, y_min, x_max, y_max]`` stacked on
        the last dimension. With ``self.scale_major`` true the ratio index
        varies slowest and the scale index fastest; otherwise the order is
        reversed.
    """
    width = height = base_size

    if center is not None:
        cx, cy = center
    else:
        cx = self.center_offset * width
        cy = self.center_offset * height

    # h : w == ratio, while h * w stays equal to base_size ** 2.
    h_factor = torch.sqrt(ratios)
    w_factor = 1 / h_factor

    if self.scale_major:
        # rows -> ratios, columns -> scales
        all_w = (width * w_factor.unsqueeze(1) * scales.unsqueeze(0)).view(-1)
        all_h = (height * h_factor.unsqueeze(1) * scales.unsqueeze(0)).view(-1)
    else:
        # rows -> scales, columns -> ratios
        all_w = (width * scales.unsqueeze(1) * w_factor.unsqueeze(0)).view(-1)
        all_h = (height * scales.unsqueeze(1) * h_factor.unsqueeze(0)).view(-1)

    # (centre, size) -> corner representation.
    half_w = 0.5 * all_w
    half_h = 0.5 * all_h
    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    return torch.stack(corners, dim=-1)
"""以2个anchor为例
base_size, scales, ratios = 8, [4,6], 1
w, h, h_ratios, w_ratios = 8, 8, 1, 1
ws = 8 * 1 * [4, 6] = [32, 48]
hs = 8 * 1 * [4, 6] = [32, 48]
base_anchors = [[-16, -24],
[-16, -24],
[16, 24],
[16, 24]]
# torch.stack之后
base_anchors = [[-16., -16., 16., 16.],
[-24., -24., 24., 24.]]
"""
下面的【grid_priors】方法与【gen_base_anchors】方法的生成方式类似,区别是:1、【grid_priors】是后期调用使用的,而【gen_base_anchors】是在实例化AnchorGenerator类时(在__init__中)自动调用的。2、返回的列表内容不同,【grid_priors】返回每个特征图上相对于原图的所有anchor_box的位置。
def grid_priors(self, featmap_sizes, dtype=torch.float32, device='cuda'):
    """Generate grid anchors for all feature levels.

    Args:
        featmap_sizes: One feature-map size per level; its length must
            equal ``self.num_levels``.
        dtype: Forwarded to ``single_level_grid_priors``.
        device: Forwarded to ``single_level_grid_priors``.

    Returns:
        list: The result of ``single_level_grid_priors`` for each level,
        in level order.

    Raises:
        AssertionError: If the number of sizes differs from the number of
            levels.
    """
    assert self.num_levels == len(featmap_sizes)
    return [
        self.single_level_grid_priors(
            featmap_sizes[level], level_idx=level, dtype=dtype, device=device)
        for level in range(self.num_levels)
    ]
# multi_level_anchors = [[level1_anchorboxs], [level2_anchorboxs], ...]
def single_level_grid_priors(self,
                             featmap_size,
                             level_idx,
                             dtype=torch.float32,
                             device='cuda'):
    """Generate the grid anchors of one feature level.

    The base anchors of ``level_idx`` are replicated at every feature-map
    position, with each position scaled by the level's stride so that the
    resulting coordinates are expressed in input-image pixels.

    Args:
        featmap_size: ``(height, width)`` of the feature map.
        level_idx: Index into ``self.base_anchors`` and ``self.strides``.
        dtype: dtype of the returned anchors.
        device: Device of the returned anchors.

    Returns:
        torch.Tensor: Anchors viewed as ``(-1, 4)``; the base anchors of
        one grid position are adjacent in the output.
    """
    level_anchors = self.base_anchors[level_idx].to(device=device, dtype=dtype)
    rows, cols = featmap_size
    step_x, step_y = self.strides[level_idx]

    # Feature-map index * stride = pixel offset in the input image,
    # e.g. with stride 8 the cell (1, 1) maps to pixel (8, 8).
    offsets_x = torch.arange(cols, device=device).to(dtype) * step_x
    offsets_y = torch.arange(rows, device=device).to(dtype) * step_y
    grid_x, grid_y = self._meshgrid(offsets_x, offsets_y)

    # The same (x, y) offset is applied to both corners of a box.
    offsets = torch.stack([grid_x, grid_y, grid_x, grid_y], dim=-1)

    # (1, A, 4) + (K, 1, 4) broadcasts to (K, A, 4), then flattens to
    # (K * A, 4).
    shifted = level_anchors.unsqueeze(0) + offsets.unsqueeze(1)
    return shifted.view(-1, 4)
由于grid_priors方法会产生一些位置在边界甚至超出边界的anchor_box,因此要对这类anchor_box进行区分,于是有了下面的valid_flags方法。
def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
    """Compute the anchor validity flags of every feature level.

    For each level the number of valid rows/columns is
    ``ceil(pad_size / stride)`` capped at the feature-map size; the actual
    flags are produced by ``single_level_valid_flags``.

    Args:
        featmap_sizes: One ``(height, width)`` per level; its length must
            equal ``self.num_levels``.
        pad_shape: Shape whose first two entries are used as the valid
            height and width.
        device: Forwarded to ``single_level_valid_flags``.

    Returns:
        list: The result of ``single_level_valid_flags`` for each level.

    Raises:
        AssertionError: If the number of sizes differs from the number of
            levels.
    """
    assert self.num_levels == len(featmap_sizes)
    flags_per_level = []
    for level in range(self.num_levels):
        stride_pair = self.strides[level]
        feat_h, feat_w = featmap_sizes[level]
        pad_h, pad_w = pad_shape[:2]
        # Index 1 of the stride pair is applied to height, index 0 to
        # width.
        rows_inside = int(np.ceil(pad_h / stride_pair[1]))
        cols_inside = int(np.ceil(pad_w / stride_pair[0]))
        valid_extent = (min(rows_inside, feat_h), min(cols_inside, feat_w))
        level_flags = self.single_level_valid_flags(
            (feat_h, feat_w),
            valid_extent,
            self.num_base_anchors[level],
            device=device)
        flags_per_level.append(level_flags)
    return flags_per_level
def single_level_valid_flags(self,
                             featmap_size,
                             valid_size,
                             num_base_anchors,
                             device='cuda'):
    """Compute the anchor validity flags of one feature level.

    A grid cell is valid when its column index is below ``valid_w`` and
    its row index is below ``valid_h``; every cell's flag is repeated
    ``num_base_anchors`` times.

    Args:
        featmap_size: ``(height, width)`` of the feature map.
        valid_size: ``(valid_h, valid_w)``; must not exceed
            ``featmap_size``.
        num_base_anchors: Number of base anchors per grid cell.
        device: Device on which the flags are created.

    Returns:
        torch.Tensor: 1-D bool tensor. Presumably of length
        ``feat_h * feat_w * num_base_anchors`` -- this relies on
        ``self._meshgrid`` returning flattened grids, which is defined
        outside this block.

    Raises:
        AssertionError: If the valid size exceeds the feature-map size.
    """
    feat_h, feat_w = featmap_size
    valid_h, valid_w = valid_size
    assert valid_h <= feat_h and valid_w <= feat_w

    # Start with everything invalid, then switch on the leading part.
    col_ok = torch.zeros(feat_w, dtype=torch.bool, device=device)
    row_ok = torch.zeros(feat_h, dtype=torch.bool, device=device)
    col_ok[:valid_w] = True
    row_ok[:valid_h] = True

    grid_cols, grid_rows = self._meshgrid(col_ok, row_ok)
    cell_ok = grid_cols & grid_rows

    # Repeat every cell's flag once per base anchor, then flatten.
    per_anchor = cell_ok.unsqueeze(1).expand(cell_ok.size(0), num_base_anchors)
    return per_anchor.contiguous().view(-1)
本文仅代表个人理解,若有不足,欢迎批评指正。
参考:MMDet逐行解读之AnchorGenerator_武乐乐~的博客-CSDN博客