文章: https://arxiv.org/abs/1904.07850
代码: centernet
原理请看扔掉anchor!真正的CenterNet——Objects as Points论文解读
尊重原创,请读原文
Sequential(
(0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
)
输出为:batch×2×128×128
2. net.wh,目标检测中的偏移量
Sequential(
(0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
)
输出为:batch×2×128×128
3. net.hm,热力图的输出,就是目标的中心位置
Sequential(
(0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
)
输出为:batch×1×128×128
loss, loss_stats = self.loss(outputs, batch) # 输入
网络输出
outputs = {'hm', 'reg', 'wh'}
[batch,80,128,128]、[batch,2,128,128]、[batch,2,128,128],也就是每个坐标点产生C+4个数据,分别是类别以及、长宽、以及偏置。
目标值:
batch = dict_keys(['input', 'hm', 'reg', 'ind', 'wh', 'reg_mask'])
hm:batch*1*128*128,热力图的目标值——热力图损失,只有一个类,本人有修改源数据
reg:batch*50*2
ind:batch*50,目标中心点在128×128特征图中的索引
wh:batch*50*2,目标矩形框的宽高——目标尺寸损失
reg_mask:batch*50,有目标的位置的mask
50:应该是一个限制数,最多一张图片中50个目标,少于50个则补0
hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks
wh_loss += self.crit_reg(
output['wh'], batch['reg_mask'],
batch['ind'], batch['wh']) / opt.num_stacks
off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
batch['ind'], batch['reg']) / opt.num_stacks
具体看
扔掉anchor!真正的CenterNet——Objects as Points论文解读