CornerNet源码中对_decode函数(最后的输出处理)的理解

文章: https://arxiv.org/pdf/1808.01244.pdf
源码链接: https://github.com/princeton-vl/CornerNet.git, 基于Pytorch实现
_decode函数位于: CornerNet/models/py_utils/kp_utils.py中

以往的目标检测框架Faster RCNN, YOLO, SSD之类的都需要在网络中生成anchor boxes, 然后最后对matched boxes执行坐标回归(Smooth L1 Loss)和目标类别预测(softmax loss)
然而CornerNet文章中说了这种基于anchor boxes做法的缺点:
1. 需要产生大量的anchor boxes, 并且对这些boxes做预测, 但是只有一小部分的boxes才匹配ground-truth, 造成了imbalance between positive and negative anchor boxes, 并且slows down training
2. 使用anchor boxes需要设定很多超参数, 比如多少个boxes, 多大, 纵横比多少(aspect ratio). 这使得网络设计变得更复杂.

ConerNet抛弃了anchor boxes, 转而预测输入图片中的目标框的角点(corners). 网络最后输出两种角点预测图, 分别是左上角和右下角(top-left corner和bottom-right corner)角点的评分图. 如果某一对角点属于一个目标框, 则根据其在预测图中的位置映射回原图, 得到预测的目标框.

CornerNet源码中对_decode函数(最后的输出处理)的理解_第1张图片



文章中CornerNet的关键模块就是最后的Prediction Module, 这个模块最后有3个输出, 分别是Heatmaps, EmbeddingsOffsets.
本文尝试通过源码的_decode函数来理解这3个输出在测试时是怎么用的


一. 首先找出代码中模型的位置
在train.py中, nnet来自: CornerNet/nnet/py_factory.py中的class NetworkFactory(object):
在py_factory.py中, NetworkFactory又通过importlib.import_module(module_file)导入了self.model和self.loss, 
(module_file使用的system_configs.snapshot_name来自train.py中的configs["system"]["snapshot_name"] = args.cfg_file)
NetworkFactory中的self.model和self.loss, 这二者来自CornerNet/models/CornerNet.py中的class model(kp), 这个model继承自CornerNet/models/py_utils/kp.py中的class kp(nn.Module), 这个loss也是来自kp.py中的class AELoss(nn.Module)


所以model主要框架都在这个class kp(nn.Module)里, 只传入1张图片的list时, 模型执行_test函数. 所以在测试的时候(看CornerNet/test/coco.py中的def kp_decode函数), 输入被封装为[images](只有images这个元素)


二. 关于nstack(整个模型其实在实现中堆叠了nstack次)
class kp(nn.Module)定义时, 模型堆叠了nstack次, 具体地说nstack被设置为2
(注意, 这里说的这个堆叠了nstack次跟文章说的 
"Our hourglass network consists of two hourglass"不一样, 

这里说的相当于[ input -> ... -> [two hourglass] -> out # 第1次
                           input -> ... -> [two hourglass] -> out # 第2次]
)
模型在前向时, 也计算了nstack个的输出
(
看_train或者_test函数
首先对输入预处理: inter = self.pre(image)
第1次:
kp  = kp_(inter)
cnv = cnv_(kp)
...
outs += [] #加入到outs中
然后
if ind < self.nstack - 1:
    inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
    inter = self.relu(inter)
    inter = self.inters[ind](inter)

第2次:
kp  = kp_(inter)
cnv = cnv_(kp)
...
outs += [] #加入到outs中
注意第1次末尾对inter做了处理, 然后第2次再用

其中
tl_heat, br_heat就是top-left和bottom-right的Heatmaps
tl_tag,  br_tag就是top-left和bottom-right的Embeddings
tl_regr, br_regr就是top-left和bottom-right的Offsets
)
为啥要计算2个的输出呢??????看代码好像_train中使用到了这两个输出, 而_test只用到了第二个输出


三. _decode函数解析
在def _test函数中 (CornerNet/models/py_utils/kp.py), 模型也是计算了2个的输出, 但是从最后的
    return self._decode(*outs[-6:], **kwargs)
可以看到, 只用了第二次的输出作为预测
下面给出_decode函数的解析:

def _decode(
    tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr, 
    K=100, kernel=1, ae_threshold=1, num_dets=1000
):
    # tl_heat, br_heat: [bs, num_classes, height, width] tl点和br点的特征图(sigmoid后就是评分图), 这两个特征图的通道就表示类别, 通道数就是类别数
    # tl_tag,  br_tag:  [bs, 1, height, width] tl点和br点的embedding(每个点对应1维的值)
    # tl_regr, br_regt: [bs, 2, height, width] tl点和br点的offset(每个点对应2维的值, 分别表示x和y的offset)
    print("\n")
    print("in _decode")
    print("tl_heat shape:{}, br_heat.shape:{}".format(tl_heat.shape, br_heat.shape))
    print("tl_tag shape:{}, br_tag.shape:{}".format(tl_tag.shape, br_tag.shape))
    print("tl_regr shape:{}, br_regr.shape:{}".format(tl_regr.shape, br_regr.shape))
    batch, cat, height, width = tl_heat.size()

    tl_heat = torch.sigmoid(tl_heat)
    br_heat = torch.sigmoid(br_heat)

    # perform nms on heatmaps
    tl_heat = _nms(tl_heat, kernel=kernel) # max pooling ???
    br_heat = _nms(br_heat, kernel=kernel)

    ###############在评分图上找出topk个tl点和br点, 返回点的评分, 点的索引, 类别索引, 以及坐标x y的索引
    tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K) # [bs, K]
    br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)

    tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K) # 从列向量横向扩展, 扩展是为了方便tl和br点两两组合, 下面也是一样
    tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K) # [bs, K, K]
    br_ys = br_ys.view(batch, 1, K).expand(batch, K, K) # 从行向量纵向扩展
    br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)

    if tl_regr is not None and br_regr is not None:
        ###############通过点的索引, 在regr中gather到topk点的offset(偏置), 并加到坐标中
        # 得到topk tl像素点的offsets, 每个像素点的offset是2维, 表示x和y的offset
        tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds) # [bs, K, 2]
        tl_regr = tl_regr.view(batch, K, 1, 2)
        # 得到topk br像素点的offsets
        br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
        br_regr = br_regr.view(batch, 1, K, 2)

        tl_xs = tl_xs + tl_regr[..., 0]
        tl_ys = tl_ys + tl_regr[..., 1]
        br_xs = br_xs + br_regr[..., 0]
        br_ys = br_ys + br_regr[..., 1]

    # all possible boxes based on top k corners (ignoring class)
    ###############通过expand为[bs, K, K]大小后, 将topk的左上角tl点 和 右下角br点 两两组合, 得到所有可能的box(一个样本一共有k*k个)
    bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3) # [bs, K, K, 4]
    ###############计算tl点和br点的embedding的绝对值
    tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds) # 得到topk tl像素点的embedding
    tl_tag = tl_tag.view(batch, K, 1)
    br_tag = _tranpose_and_gather_feat(br_tag, br_inds) # 得到topk br像素点的embedding
    br_tag = br_tag.view(batch, 1, K)
    # 将topk的tl点和br点两两相减, 取得绝对值
    dists  = torch.abs(tl_tag - br_tag) # [bs, K, K]
    ###############计算box的置信度
    tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K)
    br_scores = br_scores.view(batch, 1, K).expand(batch, K, K)
    # 将topk的tl点的置信度 和 br点的置信度相加取平均, 作为所有可能box的置信度?
    # 此时score包含对k*k个box的置信度
    scores    = (tl_scores + br_scores) / 2 # [bs, K, K]
    ###############剔除不属于同一个类别的tl点和br点
    # reject boxes based on classes
    tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
    br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)
    # topk点中, 找出bl点和br点处于不同通道的 点的 索引
    # 每个通道表示一个类别!!!如果一对tl点和br点不在同个通道, 表示他们不是属于同类物体, 就拒绝这样的点
    # 这样就默认tl点和br点的topk的通道索引必须一样!!!
    cls_inds = (tl_clses != br_clses)

    # reject boxes based on distances
    ###############tl和br点的绝对值距离超过阈值, 就拒绝这样的点
    dist_inds = (dists > ae_threshold)

    # reject boxes based on widths and heights
    ###############topk点中, br点坐标不在tl点右下角, 就拒绝这样的点
    width_inds  = (br_xs < tl_xs)
    height_inds = (br_ys < tl_ys)

    scores[cls_inds]    = -1
    scores[dist_inds]   = -1
    scores[width_inds]  = -1
    scores[height_inds] = -1

    scores = scores.view(batch, -1)
    scores, inds = torch.topk(scores, num_dets) # 再选择前num_dets个置信度
    scores = scores.unsqueeze(2) # [bs, num_dets, 1]

    bboxes = bboxes.view(batch, -1, 4)
    bboxes = _gather_feat(bboxes, inds) # 前num_dets个box

    clses  = tl_clses.contiguous().view(batch, -1, 1) #所有可能box(k*k个)的通道索引(类别索引)
    clses  = _gather_feat(clses, inds).float() # 前num_dets个box的通道索引(类别索引)

    tl_scores = tl_scores.contiguous().view(batch, -1, 1)
    tl_scores = _gather_feat(tl_scores, inds).float() # 前num_dets个box的tl点评分
    br_scores = br_scores.contiguous().view(batch, -1, 1)
    br_scores = _gather_feat(br_scores, inds).float() # 前num_dets个box的br点评分

    detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2)
    return detections


以上就是_decode函数的解析

你可能感兴趣的:(python,pytorch,深度学习,目标检测)