Inference（推理阶段：从角点heatmap解码出检测框）
1. 对左上角点和右下角点的heatmap取sigmoid，将原始输出压缩到(0, 1)之间作为得分
# Squash the raw top-left / bottom-right heatmap outputs into (0, 1) scores.
tl_heat = torch.sigmoid(tl_heat)
br_heat = torch.sigmoid(br_heat)
2. 然后做NMS：用kernel×kernel的最大池化实现，只保留等于邻域最大值的点（局部极大值），其余位置置0
# Max-pool NMS on both heatmaps: only entries equal to the maximum of their
# kernel x kernel neighbourhood survive; every other entry becomes 0.
# NOTE(review): `kernel` comes from the enclosing scope, which is not shown.
tl_heat = _nms(tl_heat, kernel=kernel)
br_heat = _nms(br_heat, kernel=kernel)
def _nms(heat, kernel=1):
pad = (kernel - 1) // 2
hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
keep = (hmax == heat).float() return heat * keep
3. 在所有类别和位置上分别选出得分最高的topK个左上角点和右下角点（这里K=100）
# Take the 100 highest-scoring top-left and bottom-right corners over all
# classes and positions. Each call returns (scores, flat spatial indices,
# class ids, ys, xs), each holding K entries per batch element.
# NOTE(review): K is passed as the literal 100 here, while later steps reshape
# with a variable `K` from the enclosing scope — confirm the two agree.
tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=100)
br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=100)
def _topk(scores, K=20):
batch, cat, height, width = scores.size()
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
topk_clses = (topk_inds / (height * width)).int()
topk_inds = topk_inds % (height * width)
topk_ys = (topk_inds / width).int().float()
topk_xs = (topk_inds % width).int().float()
return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs
4. 把回归得到的x/y偏移量加到角点坐标上
# _topk returns (batch, K) coordinates. View the top-left corners as
# (batch, K, 1) and the bottom-right corners as (batch, 1, K), then expand
# both to (batch, K, K), so entry [b, i, j] pairs top-left corner i with
# bottom-right corner j. Without this, the offsets below broadcast to
# mismatched shapes, and the later torch.stack(..., dim=3) of the four
# coordinates cannot work.
tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K)
tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K)
br_ys = br_ys.view(batch, 1, K).expand(batch, K, K)
br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)

if tl_regr is not None and br_regr is not None:
    # NOTE(review): `_tranpose_and_gather_feat` is not shown. The
    # .view(..., 2) below implies it yields two values per selected corner,
    # used as the x offset ([..., 0]) and the y offset ([..., 1]).
    tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
    tl_regr = tl_regr.view(batch, K, 1, 2)
    br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
    br_regr = br_regr.view(batch, 1, K, 2)

    # Add the regressed offsets to the integer grid coordinates. The (K, 1)
    # and (1, K) offset layouts broadcast along the matching paired axis.
    tl_xs = tl_xs + tl_regr[..., 0]
    tl_ys = tl_ys + tl_regr[..., 1]
    br_xs = br_xs + br_regr[..., 0]
    br_ys = br_ys + br_regr[..., 1]
5. 将K个左上角点与K个右下角点两两配对，得到100*100=10000个proposal，并计算每对角点embedding之间的距离
# Build one candidate box per (top-left, bottom-right) pair by stacking the
# four coordinates on a new last axis.
# NOTE(review): this assumes the coordinates were already expanded to
# (batch, K, K); torch.stack(..., dim=3) needs four equally-shaped tensors
# with at least 3 dimensions.
bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)
# Gather one embedding value (tag) per selected corner. Viewing the top-left
# tags as (batch, K, 1) and the bottom-right tags as (batch, 1, K) makes the
# subtraction broadcast to a (batch, K, K) matrix of pairwise tag distances.
tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
tl_tag = tl_tag.view(batch, K, 1)
br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
br_tag = br_tag.view(batch, 1, K)
dists = torch.abs(tl_tag - br_tag)
6. 用类别过滤：左上角点与右下角点类别不一致的proposal视为错误并剔除
# Pairwise class check: expand both class-id vectors to (batch, K, K).
# cls_inds is True where the two corners of a pair have different classes.
# Presumably these proposals are rejected later (the masking step is not shown).
tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)
cls_inds = (tl_clses != br_clses)
7. 用embedding距离过滤：距离大于ae_threshold的proposal被剔除
# True where a pair's tag (embedding) distance exceeds `ae_threshold`, which is
# defined in the enclosing scope (not shown). Such pairs are filtered out.
dist_inds = (dists > ae_threshold)
8. 用坐标过滤：右下角点的x（或y）坐标小于左上角点的proposal不构成合法框，被剔除
# True where the bottom-right x (resp. y) is smaller than the top-left x
# (resp. y), i.e. the pair would form a box with negative width (resp. height).
width_inds = (br_xs < tl_xs)
height_inds = (br_ys < tl_ys)
用上述条件剔除无效的proposal后，最终得到box