centernet代码阅读笔记

Objects as Points

文章: https://arxiv.org/abs/1904.07850
代码: centernet

原理请看扔掉anchor!真正的CenterNet——Objects as Points论文解读
尊重原创,请读原文

1. 网络结构

1.1 主干网络

1.2 输出部分

  1. net.reg,???
Sequential(
  (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace)
  (2): Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
)

输出为:batch×2×128×128
2. net.wh,目标检测中的偏移量

Sequential(
  (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace)
  (2): Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
)

输出为:batch×2×128×128
3. net.hm,热力图的输出,就是目标的中心位置

Sequential(
  (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace)
  (2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
)

输出为:batch×1×128×128

1.3 求loss

  1. 输入
loss, loss_stats = self.loss(outputs, batch)          # 输入
网络输出
outputs = {'hm', 'reg', 'wh'}
[batch,80,128,128][batch,2,128,128][batch,2,128,128],也就是每个坐标点产生C+4个数据,分别是类别以及、长宽、以及偏置。

目标值:
batch = dict_keys(['input', 'hm', 'reg', 'ind', 'wh', 'reg_mask'])
hm:batch*1*128*128,热力图的目标值——热力图损失,只有一个类,本人有修改源数据
reg:batch*50*2
ind:batch*50,目标中心点在128×128特征图中的索引
wh:batch*50*2,目标矩形框的宽高——目标尺寸损失
reg_mask:batch*50,有目标的位置的mask

50:应该是一个限制数,最多一张图片中50个目标,少于50个则补0

  1. 热力图损失——目标中心定位
    主要是目标定位用,这是一个二分类问题,项目主要是用的focal loss
hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
  1. WH损失——目标大小的回归
    主要是做目标大小的wh的回归,用L1loss
    流程:提取ind位置的wh,与目标wh做L1loss
wh_loss += self.crit_reg(
            output['wh'], batch['reg_mask'],
            batch['ind'], batch['wh']) / opt.num_stacks
  1. reg损失——目标中心的偏置
    参考自扔掉anchor!真正的CenterNet——Objects as Points论文解读
    尊重原创,请读原文
    目标中心的偏置损失
    因为上文中对图像进行了R=4R=4的下采样,这样的特征图重新映射到原始图像上的时候会带来精度误差,因此对于每一个中心点,额外采用了一个local offset去补偿它。所有类c的中心点共享同一个offset prediction,这个偏置值(offset)用L1 loss来训练:
    但是在推断过程中,我们首先读入图像[640,320],然后变形成[512,512],然后下采样4倍成[128,128]。最终预测使用的图像大小是[128,128],而每个预测出来的热点中心(headmap center),假设我们预测出与实际标记的中心点[98.97667,2.3566666]对应的点是[98,2],坐标是(x,y),对应的类别是c,等同于这个点上hm =1,有物体存在,但是我们标记出的点是[98,2],直接映射为[512,512]的形式肯定会有精度损失,为了解决这个就引入了Loff 偏置损失。
    和上面的wh一样采用同样的L1loss
off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                             batch['ind'], batch['reg']) / opt.num_stacks

推理阶段

具体看
扔掉anchor!真正的CenterNet——Objects as Points论文解读

你可能感兴趣的:(目标检测)