首先是示例代码:
def multibox_prior(data, sizes, ratios):
"""生成以每个像素为中心具有不同形状的锚框"""
in_height, in_width = data.shape[-2:]
device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
boxes_per_pixel = (num_sizes + num_ratios - 1)
size_tensor = torch.tensor(sizes, device=device) # 存放scale的tensor
ratio_tensor = torch.tensor(ratios, device=device) # 存放宽高比的tensor
# 为了将锚点移动到像素的中心,需要设置偏移量。
# 因为一个像素的高为1且宽为1,我们选择偏移我们的中心0.5
offset_h, offset_w = 0.5, 0.5
steps_h = 1.0 / in_height # 在y轴上缩放步长
steps_w = 1.0 / in_width # 在x轴上缩放步长
# 生成锚框的所有中心点
center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij') #torch.meshgrid生成网格,之后(shift_y[i],shift_x[i])就是一对可选参数
shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)
# 生成“boxes_per_pixel”个高和宽,
# 之后用于创建锚框的四角坐标(xmin,xmax,ymin,ymax)
w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
sizes[0] * torch.sqrt(ratio_tensor[1:])))\
* in_height / in_width # 处理矩形输入
h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
sizes[0] / torch.sqrt(ratio_tensor[1:])))
# w和h分别是anchor box的宽和高
# 除以2来获得半高和半宽
anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
in_height * in_width, 1) / 2
# 每个中心点都将有“boxes_per_pixel”个锚框,
# 所以生成含所有锚框中心的网格,重复了“boxes_per_pixel”次
out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
dim=1).repeat_interleave(boxes_per_pixel, dim=0)
print(anchor_manipulations)
output = out_grid + anchor_manipulations
return output.unsqueeze(0)
首先是开始的几行:
in_height, in_width = data.shape[-2:]
device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
boxes_per_pixel = (num_sizes + num_ratios - 1)
size_tensor = torch.tensor(sizes, device=device) # 存放scale的tensor
ratio_tensor = torch.tensor(ratios, device=device) # 存放宽高比的tensor
这几行获取了输入图片的高度和宽度,设置了设备、不同缩放比的个数、不同宽高比的个数和每个像素的锚框数量。
offset_h, offset_w = 0.5, 0.5
steps_h = 1.0 / in_height # 在y轴上缩放步长
steps_w = 1.0 / in_width # 在x轴上缩放步长
center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)
center_h和center_w表示的是中心点的横坐标和纵坐标(此处是百分比,也就是说都是在0~1之间的值)
接下来就是torch.meshgrid函数,该函数的作用是生成网格,可以用于生成坐标。函数输入两个数据类型相同的一维张量,两个输出张量的行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数,当两个输入张量数据类型不同或维度不是一维时会报错。
示例如下:
t1 = torch.tensor([1,2,3])
t2 = torch.tensor([2,3,4])
torch.meshgrid(t1,t2, indexing='ij')
输出:
(tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]]),
tensor([[2, 3, 4],
[2, 3, 4],
[2, 3, 4]]))
使用torch.meshgrid生成网格后,将 s h i f t y shift_y shifty和 s h i f t x shift_x shiftx均拉为1维,这样对于每一个i, s h i f t y [ i ] shift_y[i] shifty[i]和 s h i f t x [ i ] shift_x[i] shiftx[i]就是一个锚框中心点的坐标了。
接下来是困扰我许久的代码:
w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
sizes[0] * torch.sqrt(ratio_tensor[1:])))\
* in_height / in_width # 处理矩形输入
h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
sizes[0] / torch.sqrt(ratio_tensor[1:])))
关键在于给出的公式和代码实现不一样,网上的解释是: r r r并不是锚框的宽高比,而是锚框的宽高比与图像的宽高比之比:
r 是指锚框的宽高比与图像的宽高比之比即 w’/h’ = w/h*r,s是图像尺寸缩放因子即w’h’ = whs^2,联立求解即可得文中的锚框宽高即w’ = ws×sqrt( r ), h’ = hs/sqrt( r )
anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
in_height * in_width, 1) / 2
out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
dim=1).repeat_interleave(boxes_per_pixel, dim=0)
print(anchor_manipulations)
output = out_grid + anchor_manipulations
接下来的代码中,anchor_manipulations的生成使用到了stack函数,在未指定维度时默认dim = 0,在此处可以理解为在垂直方向上堆叠行向量。
示例如下:
t1 = torch.tensor([1,2,3])
t2 = torch.tensor([2,3,4])
torch.stack((t1,t2))
输出:
tensor([[1, 2, 3],
[2, 3, 4]])
repeat函数在此处是沿着列的方向重复这个张量。
示例如下:
t1 = torch.tensor([1,2,3])
t2 = torch.tensor([2,3,4])
t = torch.stack((t1,t2))
t.repeat(3,1)
输出:
tensor([[1, 2, 3],
[2, 3, 4],
[1, 2, 3],
[2, 3, 4],
[1, 2, 3],
[2, 3, 4]])
anchor_manipulations最终是一个大小为([图片中的像素点数*每个像素点为中心的锚框数, 4]),其中的每一行都为(-半宽,-半高,半宽,半高)
接下来out_grid的每一行都是(中心点x坐标,中心点y坐标,中心点x坐标,中心点y坐标),每一个这样的行都会重复 每个像素点为中心的锚框数 (次),只需要将out_grid 和anchor_manipulations相加,就可以得到每一个锚框的左上和右下的x坐标。