In CornerNet, the functions related to the loss and to decoding actually live in kp.py and kp_utils.py.
The decoding function is shown below:
```python
import torch

# _nms, _topk, _gather_feat and _tranpose_and_gather_feat (spelled this way
# in the repository) are helper functions from CornerNet's kp_utils.py.


def _decode(
    tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr,
    K=100, kernel=1, ae_threshold=1, num_dets=1000
):
    batch, cat, height, width = tl_heat.size()

    tl_heat = torch.sigmoid(tl_heat)
    br_heat = torch.sigmoid(br_heat)

    # perform nms on heatmaps
    # This is essentially max pooling over the heatmaps: a location keeps its
    # score only if it is the maximum of its kernel x kernel neighbourhood.
    tl_heat = _nms(tl_heat, kernel=kernel)
    br_heat = _nms(br_heat, kernel=kernel)

    tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K)
    br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)

    # tl_ys and tl_xs originally have shape [batch, K]. Top-left values are
    # expanded along the last axis and bottom-right values along the middle
    # axis, so entry (i, j) of each [batch, K, K] grid pairs top-left corner i
    # with bottom-right corner j.
    tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K)
    tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K)
    br_ys = br_ys.view(batch, 1, K).expand(batch, K, K)
    br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)

    if tl_regr is not None and br_regr is not None:
        # gather the predicted offsets at the selected corners and add them
        tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
        tl_regr = tl_regr.view(batch, K, 1, 2)
        br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
        br_regr = br_regr.view(batch, 1, K, 2)

        tl_xs = tl_xs + tl_regr[..., 0]
        tl_ys = tl_ys + tl_regr[..., 1]
        br_xs = br_xs + br_regr[..., 0]
        br_ys = br_ys + br_regr[..., 1]

    # all possible boxes based on top k corners (ignoring class)
    bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)

    tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
    tl_tag = tl_tag.view(batch, K, 1)
    br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
    br_tag = br_tag.view(batch, 1, K)
    # [batch, K, 1] - [batch, 1, K] is implicitly broadcast to [batch, K, K]
    # before subtracting, so dists also has shape [batch, K, K].
    dists = torch.abs(tl_tag - br_tag)

    tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K)
    br_scores = br_scores.view(batch, 1, K).expand(batch, K, K)
    scores = (tl_scores + br_scores) / 2

    # reject boxes based on classes
    tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
    br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)
    cls_inds = (tl_clses != br_clses)

    # reject boxes based on distances
    dist_inds = (dists > ae_threshold)

    # reject boxes based on widths and heights
    width_inds = (br_xs < tl_xs)
    height_inds = (br_ys < tl_ys)

    scores[cls_inds] = -1
    scores[dist_inds] = -1
    scores[width_inds] = -1
    scores[height_inds] = -1

    scores = scores.view(batch, -1)
    scores, inds = torch.topk(scores, num_dets)
    scores = scores.unsqueeze(2)

    # The K x K (100 x 100 by default) corner pairs form 10,000 boxes. Pairs
    # that fail the class, embedding-distance or corner-position checks are
    # not deleted; their score is set to -1, so they sink to the bottom of the
    # topk ranking, and that ranking is carried by inds. The filtering is then
    # applied to the boxes by bboxes = _gather_feat(bboxes, inds).
    bboxes = bboxes.view(batch, -1, 4)
    bboxes = _gather_feat(bboxes, inds)

    clses = tl_clses.contiguous().view(batch, -1, 1)
    clses = _gather_feat(clses, inds).float()

    tl_scores = tl_scores.contiguous().view(batch, -1, 1)
    tl_scores = _gather_feat(tl_scores, inds).float()
    br_scores = br_scores.contiguous().view(batch, -1, 1)
    br_scores = _gather_feat(br_scores, inds).float()

    detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2)
    return detections
```
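`_decode` relies on `_nms` and `_topk` from kp_utils.py, which are not shown here. The following is a minimal NumPy sketch, assuming the behaviour described in the comments above (max-pooling suppression, then top-K over all classes and positions at once). The names `nms_heatmap` and `topk_corners` are hypothetical, and this is not the repository's code.

```python
import numpy as np


def nms_heatmap(heat, kernel=3):
    """Zero out every score that is not the maximum of its kernel x kernel window.

    heat: float array of shape [batch, C, H, W]; kernel must be odd.
    Sketch of max-pooling-based suppression (stride 1, same-size output).
    """
    assert kernel % 2 == 1, "an odd kernel keeps the output the same size"
    pad = (kernel - 1) // 2
    height, width = heat.shape[2:]
    padded = np.pad(heat, ((0, 0), (0, 0), (pad, pad), (pad, pad)),
                    mode="constant", constant_values=-np.inf)
    hmax = np.full_like(heat, -np.inf)
    for dy in range(kernel):          # sliding-window max, one offset at a time
        for dx in range(kernel):
            hmax = np.maximum(hmax, padded[:, :, dy:dy + height, dx:dx + width])
    return heat * (hmax == heat)      # keep local maxima only


def topk_corners(heat, K):
    """Top-K scores over all classes and positions of a [batch, C, H, W] map."""
    batch, cat, height, width = heat.shape
    flat = heat.reshape(batch, -1)                          # [batch, C*H*W]
    inds = np.argsort(-flat, axis=1, kind="stable")[:, :K]  # K largest first
    scores = np.take_along_axis(flat, inds, axis=1)         # [batch, K]
    clses = inds // (height * width)                        # class channel
    pix = inds % (height * width)                           # index inside one H x W map
    ys, xs = pix // width, pix % width
    return scores, pix, clses, ys, xs


heat = np.array([[[[0.1, 0.2, 0.1, 0.0, 0.1],
                   [0.2, 0.9, 0.1, 0.3, 0.7],
                   [0.1, 0.3, 0.2, 0.1, 0.2]]]])
scores, pix, clses, ys, xs = topk_corners(nms_heatmap(heat, kernel=3), K=2)
# scores [[0.9, 0.7]], ys [[1, 1]], xs [[1, 4]]: only the two local peaks survive
```

Note that with `kernel=1`, the default in `_decode`'s signature, the window is a single pixel, so every location survives and the suppression step has no effect; a larger odd kernel such as 3 is needed for it to remove neighbours.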
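I. Theory Review
CornerNet has two branches, one for top-left corners and one for bottom-right corners. Each branch produces three outputs: a heatmap giving the probability that each location is a corner, an offset that refines the corner location, and an embedding used to decide whether two corners belong to the same object.
Top-left branch:
heat map: tl_heat [batch, C, H, W]
offset: tl_regr [batch, 2, H, W]
embedding: tl_tag [batch, 1, H, W]
Bottom-right branch:
heat map: br_heat [batch, C, H, W]
offset: br_regr [batch, 2, H, W]
embedding: br_tag [batch, 1, H, W]
Here C is the number of object classes. In total there are six outputs: tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr.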
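The offset and embedding outputs are dense [batch, D, H, W] maps, but decoding only needs their values at the K selected corners; `_tranpose_and_gather_feat` performs that lookup in `_decode`. Below is a NumPy sketch, assuming it simply picks the D-channel vector at each flattened H×W position (the name `transpose_and_gather` is hypothetical), followed by the offset refinement `tl_xs + tl_regr[..., 0]`, `tl_ys + tl_regr[..., 1]` used in the decoding code.

```python
import numpy as np


def transpose_and_gather(feat, inds):
    """Pick the D-dim feature vector at K flat positions.

    feat: [batch, D, H, W] (offsets have D=2, embeddings D=1)
    inds: [batch, K] integer positions in the flattened H*W map
    returns: [batch, K, D]
    """
    batch, dim, height, width = feat.shape
    flat = feat.reshape(batch, dim, height * width).transpose(0, 2, 1)  # [batch, H*W, D]
    return np.take_along_axis(flat, inds[:, :, None], axis=1)


# Toy example: a 3 x 5 offset map with non-zero offsets at two corners.
regr = np.zeros((1, 2, 3, 5))
regr[0, :, 1, 1] = [0.25, 0.5]   # (x offset, y offset) at y=1, x=1
regr[0, :, 1, 4] = [0.75, 0.1]   # at y=1, x=4
inds = np.array([[6, 9]])        # flat indices y * W + x of the two corners
xs = np.array([[1.0, 4.0]])      # integer corner coordinates from the heatmap
ys = np.array([[1.0, 1.0]])

offsets = transpose_and_gather(regr, inds)   # [1, 2, 2]
xs = xs + offsets[..., 0]   # channel 0 refines x
ys = ys + offsets[..., 1]   # channel 1 refines y
# xs is now [[1.25, 4.75]] and ys is [[1.5, 1.1]]
```

The same gather applied to an embedding map (D = 1) yields the [batch, K, 1] tags that `_decode` then reshapes for pairwise comparison.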
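Decoding procedure:
1. Select the top 100 points (K = 100) from each of the top-left and bottom-right heatmaps, ranking jointly over all classes and positions.
2. Pair every top-left point with every bottom-right point, producing 100 × 100 = 10,000 candidate boxes.
3. Discard most of these candidates using three checks: the two corners must be of the same class, their embeddings must be close (distance at most ae_threshold), and their positions must be consistent (the top-left corner's x and y must not exceed the bottom-right corner's). Rejected pairs receive a score of -1, and the 1000 (num_dets) highest-scoring candidates are output.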
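Steps 2 and 3 can be sketched for a single image in NumPy. This is a simplified illustration under stated assumptions (no batch axis, offsets already applied, hypothetical function name `pair_corners`), but it uses the same [K, 1] against [1, K] broadcasting and the same "set rejected scores to -1, then take the top num_dets" logic as `_decode`.

```python
import numpy as np


def pair_corners(tl_xs, tl_ys, tl_tag, tl_cls, tl_score,
                 br_xs, br_ys, br_tag, br_cls, br_score,
                 ae_threshold=1.0, num_dets=1000):
    """Pair K top-left with K bottom-right corners of ONE image (no batch axis).

    Every input is a 1-D array of length K. Entry (i, j) of each K x K grid
    pairs top-left corner i with bottom-right corner j.
    """
    # [K, 1] op [1, K] broadcasts to [K, K]
    scores = (tl_score[:, None] + br_score[None, :]) / 2
    reject = tl_cls[:, None] != br_cls[None, :]                          # different classes
    reject |= np.abs(tl_tag[:, None] - br_tag[None, :]) > ae_threshold   # embeddings too far apart
    reject |= br_xs[None, :] < tl_xs[:, None]                            # negative width
    reject |= br_ys[None, :] < tl_ys[:, None]                            # negative height
    scores = np.where(reject, -1.0, scores)                              # mark, do not delete

    boxes = np.stack(np.broadcast_arrays(tl_xs[:, None], tl_ys[:, None],
                                         br_xs[None, :], br_ys[None, :]),
                     axis=-1)                                            # [K, K, 4]
    flat_scores = scores.ravel()
    order = np.argsort(-flat_scores, kind="stable")[:num_dets]           # highest scores first
    return boxes.reshape(-1, 4)[order], flat_scores[order]


# Toy example with K = 2
boxes, scores = pair_corners(
    tl_xs=np.array([1, 10]), tl_ys=np.array([1, 10]),
    tl_tag=np.array([0.1, 5.0]), tl_cls=np.array([0, 0]), tl_score=np.array([0.9, 0.6]),
    br_xs=np.array([5, 12]), br_ys=np.array([6, 14]),
    br_tag=np.array([0.2, 5.3]), br_cls=np.array([0, 0]), br_score=np.array([0.8, 0.7]),
    num_dets=3,
)
# boxes: [[1, 1, 5, 6], [10, 10, 12, 14], [1, 1, 12, 14]]; scores about [0.85, 0.65, -1.0]
```

The third returned row is a rejected pair (embedding distance 5.2 > 1). Because rejection only overwrites the score with -1, such pairs still appear at the tail of the output whenever fewer than num_dets pairs pass the checks, so a score filter is needed downstream.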
Note: the final 1000 candidate boxes should still go through NMS, but the author did not include the NMS step in the decoding function.
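Since `_decode` does not suppress duplicate boxes, a step like the following could be applied to its output. This is plain greedy IoU-based NMS, a common choice shown only for illustration, not necessarily the variant the CornerNet authors use; the name `greedy_nms` is hypothetical. Given the concatenation order `bboxes, scores, tl_scores, br_scores, clses`, column 4 of `detections` holds the pair score and the last column the class, so one would typically first drop rows with score -1 and run the suppression separately per class.

```python
import numpy as np


def greedy_nms(boxes, scores, iou_threshold=0.5):
    """Plain greedy IoU-based NMS.

    boxes: [N, 4] as (x1, y1, x2, y2); scores: [N].
    Returns the indices of the kept boxes, highest score first.
    """
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)
    order = np.argsort(-scores, kind="stable")
    keep = []
    while order.size > 0:
        i, rest = order[0], order[1:]
        keep.append(int(i))
        # overlap between the current best box and every remaining box
        iw = np.clip(np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]), 0, None)
        ih = np.clip(np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]), 0, None)
        inter = iw * ih
        union = np.maximum(areas[i] + areas[rest] - inter, 1e-12)
        order = rest[inter / union <= iou_threshold]   # drop boxes that overlap too much
    return keep


boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])
keep = greedy_nms(boxes, scores)   # [0, 2]: box 1 overlaps box 0 with IoU 0.81
```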