文章: https://arxiv.org/pdf/1808.01244.pdf
源码链接: https://github.com/princeton-vl/CornerNet.git, 基于Pytorch实现
_decode函数位于: CornerNet/models/py_utils/kp_utils.py中
以往的目标检测框架Faster RCNN, YOLO, SSD之类的都需要在网络中生成anchor boxes, 然后最后对matched boxes执行坐标回归(Smooth L1 Loss)和目标类别预测(softmax loss)
然而CornerNet文章中说了这种基于anchor boxes做法的缺点:
1. 需要产生大量的anchor boxes, 并且对这些boxes做预测, 但是只有一小部分的boxes才匹配ground-truth, 造成了imbalance between positive and negative anchor boxes, 并且slows down training
2. 使用anchor boxes需要设定很多超参数, 比如多少个boxes, 多大, 纵横比多少(aspect ratio). 这使得网络设计变得更复杂.
而ConerNet抛弃了anchor boxes, 转而预测输入图片中的目标框的角点(corners). 网络最后输出两种角点预测图, 分别是左上角和右下角(top-left corner和bottom-right corner)角点的评分图. 如果某一对角点属于一个目标框, 则根据其在预测图中的位置映射回原图, 得到预测的目标框.
文章中CornerNet的关键模块就是最后的Prediction Module, 这个模块最后有3个输出, 分别是Heatmaps, Embeddings和Offsets.
本文尝试通过源码的_decode函数来理解这3个输出在测试时是怎么用的
一. 首先找出代码中模型的位置
在train.py中, nnet来自: CornerNet/nnet/py_factory.py中的class NetworkFactory(object):
在py_factory.py中, NetworkFactory又通过importlib.import_module(module_file)导入了self.model和self.loss,
(module_file使用的system_configs.snapshot_name来自train.py中的configs["system"]["snapshot_name"] = args.cfg_file)
NetworkFactory中的self.model和self.loss, 这二者来自CornerNet/models/CornerNet.py中的class model(kp), 这个model继承自CornerNet/models/py_utils/kp.py中的class kp(nn.Module), 这个loss也是来自kp.py中的class AELoss(nn.Module)
所以model主要框架都在这个class kp(nn.Module)里, 只传入1张图片的list时, 模型执行_test函数. 所以在测试的时候(看CornerNet/test/coco.py中的def kp_decode函数), 输入被封装为[images](只有images这个元素)
二. 关于nstack(整个模型其实在实现中堆叠了nstack次)
class kp(nn.Module)定义时, 模型堆叠了nstack次, 具体地说nstack被设置为2
(注意, 这里说的这个堆叠了nstack次跟文章说的
"Our hourglass network consists of two hourglass"不一样,
这里说的相当于[ input -> ... -> [two hourglass] -> out # 第1次
input -> ... -> [two hourglass] -> out # 第2次]
)
模型在前向时, 也计算了nstack个的输出
(
看_train或者_test函数
首先对输入预处理: inter = self.pre(image)
第1次:
kp = kp_(inter)
cnv = cnv_(kp)
...
outs += [] #加入到outs中
然后
if ind < self.nstack - 1:
inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
inter = self.relu(inter)
inter = self.inters[ind](inter)
第2次:
kp = kp_(inter)
cnv = cnv_(kp)
...
outs += [] #加入到outs中
注意第1次末尾对inter做了处理, 然后第2次再用
其中
tl_heat, br_heat就是top-left和bottom-right的Heatmaps
tl_tag, br_tag就是top-left和bottom-right的Embeddings
tl_regr, br_regr就是top-left和bottom-right的Offsets
)
为啥要计算2个的输出呢??????看代码好像_train中使用到了这两个输出, 而_test只用到了第二个输出
三. _decode函数解析
在def _test函数中 (CornerNet/models/py_utils/kp.py), 模型也是计算了2个的输出, 但是从最后的
return self._decode(*outs[-6:], **kwargs)
可以看到, 只用了第二次的输出作为预测
下面给出_decode函数的解析:
def _decode(
tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr,
K=100, kernel=1, ae_threshold=1, num_dets=1000
):
# tl_heat, br_heat: [bs, num_classes, height, width] tl点和br点的特征图(sigmoid后就是评分图), 这两个特征图的通道就表示类别, 通道数就是类别数
# tl_tag, br_tag: [bs, 1, height, width] tl点和br点的embedding(每个点对应1维的值)
# tl_regr, br_regt: [bs, 2, height, width] tl点和br点的offset(每个点对应2维的值, 分别表示x和y的offset)
print("\n")
print("in _decode")
print("tl_heat shape:{}, br_heat.shape:{}".format(tl_heat.shape, br_heat.shape))
print("tl_tag shape:{}, br_tag.shape:{}".format(tl_tag.shape, br_tag.shape))
print("tl_regr shape:{}, br_regr.shape:{}".format(tl_regr.shape, br_regr.shape))
batch, cat, height, width = tl_heat.size()
tl_heat = torch.sigmoid(tl_heat)
br_heat = torch.sigmoid(br_heat)
# perform nms on heatmaps
tl_heat = _nms(tl_heat, kernel=kernel) # max pooling ???
br_heat = _nms(br_heat, kernel=kernel)
###############在评分图上找出topk个tl点和br点, 返回点的评分, 点的索引, 类别索引, 以及坐标x y的索引
tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K) # [bs, K]
br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)
tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K) # 从列向量横向扩展, 扩展是为了方便tl和br点两两组合, 下面也是一样
tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K) # [bs, K, K]
br_ys = br_ys.view(batch, 1, K).expand(batch, K, K) # 从行向量纵向扩展
br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)
if tl_regr is not None and br_regr is not None:
###############通过点的索引, 在regr中gather到topk点的offset(偏置), 并加到坐标中
# 得到topk tl像素点的offsets, 每个像素点的offset是2维, 表示x和y的offset
tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds) # [bs, K, 2]
tl_regr = tl_regr.view(batch, K, 1, 2)
# 得到topk br像素点的offsets
br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
br_regr = br_regr.view(batch, 1, K, 2)
tl_xs = tl_xs + tl_regr[..., 0]
tl_ys = tl_ys + tl_regr[..., 1]
br_xs = br_xs + br_regr[..., 0]
br_ys = br_ys + br_regr[..., 1]
# all possible boxes based on top k corners (ignoring class)
###############通过expand为[bs, K, K]大小后, 将topk的左上角tl点 和 右下角br点 两两组合, 得到所有可能的box(一个样本一共有k*k个)
bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3) # [bs, K, K, 4]
###############计算tl点和br点的embedding的绝对值
tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds) # 得到topk tl像素点的embedding
tl_tag = tl_tag.view(batch, K, 1)
br_tag = _tranpose_and_gather_feat(br_tag, br_inds) # 得到topk br像素点的embedding
br_tag = br_tag.view(batch, 1, K)
# 将topk的tl点和br点两两相减, 取得绝对值
dists = torch.abs(tl_tag - br_tag) # [bs, K, K]
###############计算box的置信度
tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K)
br_scores = br_scores.view(batch, 1, K).expand(batch, K, K)
# 将topk的tl点的置信度 和 br点的置信度相加取平均, 作为所有可能box的置信度?
# 此时score包含对k*k个box的置信度
scores = (tl_scores + br_scores) / 2 # [bs, K, K]
###############剔除不属于同一个类别的tl点和br点
# reject boxes based on classes
tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)
# topk点中, 找出bl点和br点处于不同通道的 点的 索引
# 每个通道表示一个类别!!!如果一对tl点和br点不在同个通道, 表示他们不是属于同类物体, 就拒绝这样的点
# 这样就默认tl点和br点的topk的通道索引必须一样!!!
cls_inds = (tl_clses != br_clses)
# reject boxes based on distances
###############tl和br点的绝对值距离超过阈值, 就拒绝这样的点
dist_inds = (dists > ae_threshold)
# reject boxes based on widths and heights
###############topk点中, br点坐标不在tl点右下角, 就拒绝这样的点
width_inds = (br_xs < tl_xs)
height_inds = (br_ys < tl_ys)
scores[cls_inds] = -1
scores[dist_inds] = -1
scores[width_inds] = -1
scores[height_inds] = -1
scores = scores.view(batch, -1)
scores, inds = torch.topk(scores, num_dets) # 再选择前num_dets个置信度
scores = scores.unsqueeze(2) # [bs, num_dets, 1]
bboxes = bboxes.view(batch, -1, 4)
bboxes = _gather_feat(bboxes, inds) # 前num_dets个box
clses = tl_clses.contiguous().view(batch, -1, 1) #所有可能box(k*k个)的通道索引(类别索引)
clses = _gather_feat(clses, inds).float() # 前num_dets个box的通道索引(类别索引)
tl_scores = tl_scores.contiguous().view(batch, -1, 1)
tl_scores = _gather_feat(tl_scores, inds).float() # 前num_dets个box的tl点评分
br_scores = br_scores.contiguous().view(batch, -1, 1)
br_scores = _gather_feat(br_scores, inds).float() # 前num_dets个box的br点评分
detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2)
return detections
以上就是_decode函数的解析