看代码解读CenterNet :Objects as Points

文章目录

    • 1. idea
    • 2. 网络结果
      • 2.1 hm分支
      • 2.2 wh分支
      • 2.3 reg分支
    • 3. 数据+loss
    • 4. 推理
    • 结果
    • reference

摘要:
centernet是单阶段的目标检测网络,采用高斯图来表示目标,速度很快,比YOLOV3更快,效果方面没有对比过。网上也没有具体的对比。

这篇通过读代码的方式来解读centernet,先来张图:
看代码解读CenterNet :Objects as Points_第1张图片

1. idea

文章主要用高斯分布来表示目标,就是一个目标用高斯分布来覆盖,目标中心点的值越大。

2. 网络结果

2.1 hm分支

看代码解读CenterNet :Objects as Points_第2张图片
整个网络主要有bone net特征提取网络和输出部分组成,网络结果如上图所示,特征提取网络就不细讲了。用高斯分布来表示目标,网络第三个分支/hm输出部分网络如下:

nn.Sequential(    nn.Conv2d(64, 256, kernel_size=3, padding=1, bias=True),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(256, classes, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True))

其实就是一个conv2d(64,256),relu(),conv2d(256,1),最后的输出为n_category×128×128,一个类别一个通道,其中每个点的值表示:是目标的概率有多大。上图是只有一个类别的情况。

2.2 wh分支

网络定义为:

nn.Sequential(    nn.Conv2d(64, 256, kernel_size=3, padding=1, bias=True),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(256, 2, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True))

分支最后输出:2×128×128,所有类别用共同的预测宽度w和高度h。

2.3 reg分支

网络定义为:

nn.Sequential(    nn.Conv2d(64, 256, kernel_size=3, padding=1, bias=True),
                  nn.ReLU(inplace=True),
                  nn.Conv2d(256, 2, kernel_size=final_kernel, stride=1, padding=final_kernel // 2, bias=True))

分支输出:2×128×128,每个点的两个值表示,当前index为目标时hm输出位置的偏差,所有类别用共同的w,h预测值。

总结: 整个网络还是很好理解的。


3. 数据+loss

  1. 分类误差
    采用的focalloss,target是一个n_cat×128×128的矩阵,每一个channel表示一个类别,128×128的每个值是用高斯分布覆盖目标处理出来的。
  2. 尺寸误差
    target是:(2×n_category)×128×128矩阵,用L1_loss。
  3. 中心点修正误差
    target是:(2×n_category)×128×128矩阵,用L1_loss。

4. 推理

  1. 网络前向传播
output = self.model(images)[-1]
hm = output['hm'].sigmoid_()
wh = output['wh']
reg = output['reg']

就是正常的前向,hm的输出要经过sigmoid。

  1. 对hm用max_pool2d(3*3)进行滤波
hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
  1. 根据hm的值筛选出前100
scores, inds, clses, ys, xs = _topk(heat, K=K)
  1. 对结果进行一些后处理,比如中点+reg等

结果

看代码解读CenterNet :Objects as Points_第3张图片

reference

  1. 【论文笔记】CenterNet:Objects as Points
  2. centernet源码

你可能感兴趣的:(目标检测)