详解目标检测之anchor-box生成

无论前端backbone如何,总会输出一个feature map 如何在这个feature map上得到pred的 box呢?

第一步先得到 feature map上所有的 中心点x,y

举例:
以特征图大小(4, 4 ,5, 6)表示,4为batch,4 为channel, 5为h, 6为w

#以feature的大小5*6 生成单位网格
import torch
yy, xx =torch.meshgrid(torch.arange(5), torch.arange(6)
print(yy,xx)
#得到如下
tensor([[0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4]]) 
tensor([[0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5]])
#堆叠起来
mesh = torch.stack([xx, yy], dim=0)
print(mesh)
#输出如下,shape为[2, 5, 6]
tensor([[[0, 1, 2, 3, 4, 5],
         [0, 1, 2, 3, 4, 5],
         [0, 1, 2, 3, 4, 5],
         [0, 1, 2, 3, 4, 5],
         [0, 1, 2, 3, 4, 5]],

        [[0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4, 4]]])
 #增加一个batch 位置的维度   shape为  [4, 2, 5, 6] 
mesh = mesh.unsqueeze(0).repeat(4,1,1,1).float() 
print(mesh)
tensor([[[[0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.]],

         [[0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4., 4.]]],


        [[[0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.]],

         [[0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4., 4.]]],


        [[[0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.]],

         [[0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4., 4.]]],


        [[[0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.],
          [0., 1., 2., 3., 4., 5.]],

         [[0., 0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4., 4.]]]])

#### 然后 以预设定的anchor box 比率生成

anchor_wh=[
[ 5.0174, 15.0521],
[ 7.0833, 21.2500],
[10.0347, 24.7917],
[20.0694, 18.8889]]

anchor_offset_mesh = anchor_wh.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, nGh,nGw)
print(anchor_offset_mesh)
tensor([[[[ 5.0174,  5.0174,  5.0174,  5.0174,  5.0174,  5.0174],
          [ 5.0174,  5.0174,  5.0174,  5.0174,  5.0174,  5.0174],
          [ 5.0174,  5.0174,  5.0174,  5.0174,  5.0174,  5.0174],
          [ 5.0174,  5.0174,  5.0174,  5.0174,  5.0174,  5.0174],
          [ 5.0174,  5.0174,  5.0174,  5.0174,  5.0174,  5.0174]],

         [[15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521],
          [15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521],
          [15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521],
          [15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521],
          [15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521]]],


        [[[ 7.0833,  7.0833,  7.0833,  7.0833,  7.0833,  7.0833],
          [ 7.0833,  7.0833,  7.0833,  7.0833,  7.0833,  7.0833],
          [ 7.0833,  7.0833,  7.0833,  7.0833,  7.0833,  7.0833],
          [ 7.0833,  7.0833,  7.0833,  7.0833,  7.0833,  7.0833],
          [ 7.0833,  7.0833,  7.0833,  7.0833,  7.0833,  7.0833]],

         [[21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500],
          [21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500],
          [21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500],
          [21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500],
          [21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500]]],


        [[[10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347],
          [10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347],
          [10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347],
          [10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347],
          [10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347]],

         [[24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917],
          [24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917],
          [24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917],
          [24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917],
          [24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917]]],


        [[[20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694],
          [20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694],
          [20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694],
          [20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694],
          [20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694]],

         [[18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889],
          [18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889],
          [18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889],
          [18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889],
          [18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889]]]])

anchor_offset_mesh=torch.cat([mesh2, anchor_offset_mesh],dim=1)
f=anchor_offset_mesh.permute(0,2,3,1).contiguous().view(-1, 4) 
print(f)
# 生成 x,y, w,h 格式的bounding box
tensor([[ 0.0000,  0.0000,  5.0174, 15.0521],
        [ 1.0000,  0.0000,  5.0174, 15.0521],
        [ 2.0000,  0.0000,  5.0174, 15.0521],
        [ 3.0000,  0.0000,  5.0174, 15.0521],
        [ 4.0000,  0.0000,  5.0174, 15.0521],
        [ 5.0000,  0.0000,  5.0174, 15.0521],
        [ 0.0000,  1.0000,  5.0174, 15.0521],
        [ 1.0000,  1.0000,  5.0174, 15.0521],
        [ 2.0000,  1.0000,  5.0174, 15.0521],
        [ 3.0000,  1.0000,  5.0174, 15.0521],
        [ 4.0000,  1.0000,  5.0174, 15.0521],
        [ 5.0000,  1.0000,  5.0174, 15.0521],
        [ 0.0000,  2.0000,  5.0174, 15.0521],
        [ 1.0000,  2.0000,  5.0174, 15.0521],
        [ 2.0000,  2.0000,  5.0174, 15.0521],
        [ 3.0000,  2.0000,  5.0174, 15.0521],
        [ 4.0000,  2.0000,  5.0174, 15.0521],
        [ 5.0000,  2.0000,  5.0174, 15.0521],
        [ 0.0000,  3.0000,  5.0174, 15.0521],
        [ 1.0000,  3.0000,  5.0174, 15.0521],
        [ 2.0000,  3.0000,  5.0174, 15.0521],
        [ 3.0000,  3.0000,  5.0174, 15.0521],
        [ 4.0000,  3.0000,  5.0174, 15.0521],
        [ 5.0000,  3.0000,  5.0174, 15.0521],
        [ 0.0000,  4.0000,  5.0174, 15.0521],
        [ 1.0000,  4.0000,  5.0174, 15.0521],
        [ 2.0000,  4.0000,  5.0174, 15.0521],
        [ 3.0000,  4.0000,  5.0174, 15.0521],
        [ 4.0000,  4.0000,  5.0174, 15.0521],
        [ 5.0000,  4.0000,  5.0174, 15.0521],
        [ 0.0000,  0.0000,  7.0833, 21.2500],
        [ 1.0000,  0.0000,  7.0833, 21.2500],
        [ 2.0000,  0.0000,  7.0833, 21.2500],
        [ 3.0000,  0.0000,  7.0833, 21.2500],
        [ 4.0000,  0.0000,  7.0833, 21.2500],
        [ 5.0000,  0.0000,  7.0833, 21.2500],
        [ 0.0000,  1.0000,  7.0833, 21.2500],
        [ 1.0000,  1.0000,  7.0833, 21.2500],
        [ 2.0000,  1.0000,  7.0833, 21.2500],
        [ 3.0000,  1.0000,  7.0833, 21.2500],
        [ 4.0000,  1.0000,  7.0833, 21.2500],
        [ 5.0000,  1.0000,  7.0833, 21.2500],
        [ 0.0000,  2.0000,  7.0833, 21.2500],
        [ 1.0000,  2.0000,  7.0833, 21.2500],
        [ 2.0000,  2.0000,  7.0833, 21.2500],
        [ 3.0000,  2.0000,  7.0833, 21.2500],
        [ 4.0000,  2.0000,  7.0833, 21.2500],
        [ 5.0000,  2.0000,  7.0833, 21.2500],
        [ 0.0000,  3.0000,  7.0833, 21.2500],
        [ 1.0000,  3.0000,  7.0833, 21.2500],
        [ 2.0000,  3.0000,  7.0833, 21.2500],
        [ 3.0000,  3.0000,  7.0833, 21.2500],
        [ 4.0000,  3.0000,  7.0833, 21.2500],
        [ 5.0000,  3.0000,  7.0833, 21.2500],
        [ 0.0000,  4.0000,  7.0833, 21.2500],
        [ 1.0000,  4.0000,  7.0833, 21.2500],
        [ 2.0000,  4.0000,  7.0833, 21.2500],
        [ 3.0000,  4.0000,  7.0833, 21.2500],
        [ 4.0000,  4.0000,  7.0833, 21.2500],
        [ 5.0000,  4.0000,  7.0833, 21.2500],
        [ 0.0000,  0.0000, 10.0347, 24.7917],
        [ 1.0000,  0.0000, 10.0347, 24.7917],
        [ 2.0000,  0.0000, 10.0347, 24.7917],
        [ 3.0000,  0.0000, 10.0347, 24.7917],
        [ 4.0000,  0.0000, 10.0347, 24.7917],
        [ 5.0000,  0.0000, 10.0347, 24.7917],
        [ 0.0000,  1.0000, 10.0347, 24.7917],
        [ 1.0000,  1.0000, 10.0347, 24.7917],
        [ 2.0000,  1.0000, 10.0347, 24.7917],
        [ 3.0000,  1.0000, 10.0347, 24.7917],
        [ 4.0000,  1.0000, 10.0347, 24.7917],
        [ 5.0000,  1.0000, 10.0347, 24.7917],
        [ 0.0000,  2.0000, 10.0347, 24.7917],
        [ 1.0000,  2.0000, 10.0347, 24.7917],
        [ 2.0000,  2.0000, 10.0347, 24.7917],
        [ 3.0000,  2.0000, 10.0347, 24.7917],
        [ 4.0000,  2.0000, 10.0347, 24.7917],
        [ 5.0000,  2.0000, 10.0347, 24.7917],
        [ 0.0000,  3.0000, 10.0347, 24.7917],
        [ 1.0000,  3.0000, 10.0347, 24.7917],
        [ 2.0000,  3.0000, 10.0347, 24.7917],
        [ 3.0000,  3.0000, 10.0347, 24.7917],
        [ 4.0000,  3.0000, 10.0347, 24.7917],
        [ 5.0000,  3.0000, 10.0347, 24.7917],
        [ 0.0000,  4.0000, 10.0347, 24.7917],
        [ 1.0000,  4.0000, 10.0347, 24.7917],
        [ 2.0000,  4.0000, 10.0347, 24.7917],
        [ 3.0000,  4.0000, 10.0347, 24.7917],
        [ 4.0000,  4.0000, 10.0347, 24.7917],
        [ 5.0000,  4.0000, 10.0347, 24.7917],
        [ 0.0000,  0.0000, 20.0694, 18.8889],
        [ 1.0000,  0.0000, 20.0694, 18.8889],
        [ 2.0000,  0.0000, 20.0694, 18.8889],
        [ 3.0000,  0.0000, 20.0694, 18.8889],
        [ 4.0000,  0.0000, 20.0694, 18.8889],
        [ 5.0000,  0.0000, 20.0694, 18.8889],
        [ 0.0000,  1.0000, 20.0694, 18.8889],
        [ 1.0000,  1.0000, 20.0694, 18.8889],
        [ 2.0000,  1.0000, 20.0694, 18.8889],
        [ 3.0000,  1.0000, 20.0694, 18.8889],
        [ 4.0000,  1.0000, 20.0694, 18.8889],
        [ 5.0000,  1.0000, 20.0694, 18.8889],
        [ 0.0000,  2.0000, 20.0694, 18.8889],
        [ 1.0000,  2.0000, 20.0694, 18.8889],
        [ 2.0000,  2.0000, 20.0694, 18.8889],
        [ 3.0000,  2.0000, 20.0694, 18.8889],
        [ 4.0000,  2.0000, 20.0694, 18.8889],
        [ 5.0000,  2.0000, 20.0694, 18.8889],
        [ 0.0000,  3.0000, 20.0694, 18.8889],
        [ 1.0000,  3.0000, 20.0694, 18.8889],
        [ 2.0000,  3.0000, 20.0694, 18.8889],
        [ 3.0000,  3.0000, 20.0694, 18.8889],
        [ 4.0000,  3.0000, 20.0694, 18.8889],
        [ 5.0000,  3.0000, 20.0694, 18.8889],
        [ 0.0000,  4.0000, 20.0694, 18.8889],
        [ 1.0000,  4.0000, 20.0694, 18.8889],
        [ 2.0000,  4.0000, 20.0694, 18.8889],
        [ 3.0000,  4.0000, 20.0694, 18.8889],
        [ 4.0000,  4.0000, 20.0694, 18.8889],
        [ 5.0000,  4.0000, 20.0694, 18.8889]])

你可能感兴趣的:(深度学习)