CenterNet: Keypoint Triplets for Object Detection 论文学习

论文地址:https://arxiv.org/pdf/1904.07850.pdf
代码地址:https://github.com/xingyizhou/CenterNet

Abstract

在目标检测中,由于缺乏对相关剪裁区域的额外监督,基于关键点的方法通常会得到一大堆错误的物体边框。本文提出了一个有效的方法,在每个裁剪区域内以最小的代价去探索它的视觉模式。我们构建了一个单阶段基于关键点的检测器,叫做CornerNet。CornerNet 用每个目标物体的三个关键点来检测,而不是一对关键点,提升识别精度和召回率。因此,本文设计了两个模块,分别是 cascade corner pooling 和 center pooling,前者能丰富左上角和右下角搜集到的信息,后者在中间区域能提供更具辨识度的信息。在MS-COCO数据集上,CenterNet 获得的AP值是 47 47% 47,比所有的单阶段检测器至少高出 4.9 4.9% 4.9。同时,它的前向推理速度更快,CenterNet 的性能和双阶段检测器相比也很具竞争力。

Introduction

深度学习出现之后,目标检测得到了明显的提升。目前最流行的方法都是基于 anchor 的,在要识别物体上放置预先定义好的 anchor boxes,通过ground truth boxes 回归出相应的位置。这些方法通常需要一堆 anchors 来保证预测的边框和 ground truth 有较高的 IoU,anchors 的大小,宽高比都需要提前人为设计好。此外,anchors 经常会和 ground truth 边框不一致,降低边框分类的准确率。

为了解决 anchor 的缺点,人们提出了一个基于关键点的目标检测方法,CornerNet。它用一对角点来表示每个物体,无需 anchor boxes,在单阶段检测器中取得了 state of art 的检测准确率。但是,CornerNet 仍有局限性,就是它缺乏对物体全局信息的参考。也就是说,由于每个物体都是用两个角点表示,算法对识别物体的边界框很敏感,而同时又无法确定哪两个关键点属于同一个物体。因此,如图1所示,经常会产生一些错误的边框,绝大多数都可以很容易地通过辅助信息(如宽高比)去除。

CenterNet: Keypoint Triplets for Object Detection 论文学习_第1张图片

为了解决这个问题,我们让CornerNet 可以识别每个候选区域内的视觉模式,这样它就能自己识别每个边框的正确性。在这篇论文中,我们提出了一个低成本但是很高效的办法,叫做 CenterNet,通过增加一个关键点来探索候选框内中间区域(靠近几何中心的位置)的信息。我们的想法就是,如果一个预测边框和 ground truth 边框有着很高的 IoU,则该边框的中心关键点预测出相同类别的概率要高,反之亦然。所以,在推理时,通过一对关键点产生了一个边框,如果同类别物体的中心关键点落在该候选框的中心区域,那么我们就认为该候选框包含那个物体。如果目标边框是准确的,那么在其中心区域能够检测到目标物体中心点的概率就会很高。 若有则保留该目标框,若无则删除该目标框。如图1,即使用三个关键点来表示目标物体。

为了更好的检测中心关键点和角点,我们提出了两个方法来分别增强中心和角点信息。

  • 第一个方法叫center pooling,用于预测中心关键点的分支。Center pooling 有助于中心关键点取得物体内部辨识度更高的视觉信息,让候选框中心部分的感知更简单。实现方式是,在预测中心关键点的特征图上,取中心关键点横向和纵向上响应和的最大值。
  • 第二个方法就是cascade corner pooling,增加原始 corner pooling 感知候选框内部信息的功能。实现方式是,在预测角点的特征图上,计算物体边框和内部方向上响应和的最大值。实验证明,这样一个双指向的池化方法面对噪声更加稳定,鲁棒性更强,有助于提升精度和召回。

我们在MS-COCO 数据集上评估了CenterNet。在 center pooling 和 cascade corner pooling 都使用的情况下,在测试集上AP值能达到 47 47% 47,超过了现有的单阶段检测器一大截子。使用了52层的 Hourglass 主干网络时,推理时间平均为270毫秒每张图片;使用104层的Hourglass 主干网络时,推理时间为340毫秒每张图片。CenterNet 效率很高,和现有的双阶段检测器相比也不弱。

2. Related Work

目标检测涉及到物体的定位和分类。在深度学习纪元,受深度神经网络推动,目标检测方法大致可分为两类,双阶段方法和单阶段方法。

Two-stage approaches

双阶段方法将目标检测任务分为两步,提取RoI,然后对RoI进行回归和分类。

R-CNN 使用选择搜索方法,在输入图像上定位RoI,然后用深度卷积网络对每个RoI进行独立的分类。SPPNet 和 Fast-RCNN 改进了R-CNN,它们从特征图上提取RoI。Faster-RCNN 引入了RPN实现端到端的训练。RPN 通过回归 anchor boxes 产生RoI。随后,anchor boxes 被大范围应用在目标检测任务上。Mask-RCNN 在 Faster-RCNN 上增加了一个 mask 预测分支,可以同时检测物体和预测 mask。R-FCN 将全连接层替换为位敏得分图(position-sensitive score map) 提升检测效果。Cascade R-CNN 训练一组检测器,它们的阈值逐渐递增,以此解决训练过程中的过拟合问题和测试中的图像质量不匹配问题。我们提出基于关键点的物体检测,目的是避免 anchor boxes 和边框回归带来的问题。

One-stage approaches

单阶段方法移除了RoI提取过程,而直接地在候选 anchor boxes 上进行分类和回归。

YOLO 使用较少的 anchor boxes,将图片分割为 S × S S\times S S×S 的网格,进行回归和分类。YOLOv2 通过使用更多的 anchor boxes 和新的边框回归方法提升了性能。SSD 在输入图像上放置密集的检测框,并使用不同卷积层的特征对 anchor boxes 进行回归和分类。DSSD 往SSD中加入了反卷积模块,融合低级别和高级别特征。RefineDet 对 anchor boxes 的大小和位置进行了两次的优化,继承发扬了单阶段和双阶段的长处。CornerNet 是另一个基于关键点的方法,用一对角点直接在图像上进行预测。尽管CornerNet 的性能很高,它仍有很大提升的地方。

3. Approach

3.1 Baseline and Motivation

这篇论文使用 CornerNet 作为基线。为了检测角点,CornerNet 产生两个热力图:一个左上角的热力图,一个右下角的热力图。热力图代表不同类别关键点的位置,对每个关键点赋一个置信度分数。此外,CornerNet 也对每个角点预测一个 embedding 和一组偏移量。Embeddings 用于判断两个角点是否来自同一个目标物体。偏移量学习如何将角点从热力图重新映射回输入图像上,为了产生物体的边框,我们依据它们的分数从热力图上分别选取 t o p − k top-k topk 个左上角点和右下角点。然后,我们计算这一对角点的 embedding 向量的距离,以此来判断这一对角点是否属于同一个物体。如果距离小于某阈值,则会生成一个物体边框。该边框会得到一个置信度分数,等于这一对角点的平均分数。

CenterNet: Keypoint Triplets for Object Detection 论文学习_第2张图片

在表1中,我们提供了CornerNet 的深入的分析。我们在MS-COCO数据集上计算了 CornerNet 的 false discovery(FD) rate,即错误的检测边框的比例。结果显示在 IoU 阈值较低时,错误检测边框占了很大的比例,比如当 IoU 为0.05时,FD rate 是 32.7 32.7% 32.7。也就是平均下来,每一百个物体边框,有32.7个边框和 ground truth 边框的 IoU 是低于0.05的。小的错误边框就更多了,FD rate 是 60.3 60.3% 60.3。一个可能原因是,CornerNet 无法深入边框内部一窥究竟。为了让 CornerNet 感知边框内的视觉信息,一个方案就是将 CornerNet 改为双阶段检测器,使用 RoI 池化来深入了解边框内的视觉信息。但是,这种操作带来的计算成本很高。

在这篇论文,我们提出了一个非常有效的替代方案, CenterNet,可以发掘每个边框内的视觉信息。为了检测物体,我们的方法使用了一个三元组关键点,而非一对关键点。这样做后,我们的方法仍然是一个单阶段检测器,但是部分继承了 RoI 池化的功能。此方法仅关注中心位置信息,计算成本是很小的。同时通过 center pooling 和 cascade corner pooling,我们在关键点检测过程中进一步加入了物体内部的视觉信息。

3.2 Object Detection as Keypoint Triplets

CenterNet: Keypoint Triplets for Object Detection 论文学习_第3张图片

CenterNet 的整体结构如图2所示。我们用一个中心关键点和一对角点来表示每个物体。我们在 CornerNet 的基础上加入一个中心关键点的热力图,同时预测中心关键点的偏移。然后,基于 CornerNet 提出的方法产生 t o p − k top-k topk 个候选框。但是,为了剔除错误的边框,利用检测到的中心点位置,对其按如下过程进行排序操作:

  • 根据它们的分数,选择 t o p − k top-k topk个中心关键点;
  • 使用相应的偏移量将中心关键点重新映射回输入图像中;
  • 为每个边框定义一个中心区域,确保该中心区域存在中心关键点。同时确保该中心关键点的类别和边框的类别一致。
  • 如果在中心区域检测到中心关键点,我们就保留这个边框。用左上角,右下角和中心关键点分数的平均值更新边框的分数,并保存该边框。如果在该中心区域没有检测到中心关键点,则移除此边框。

边框内中心区域的大小会影响到检测的结果。例如,较小的中心区域会导致小边框的召回率低,较大的中心区域会导致大边框的精度低。因而,我们提出了一个可以自动适应边框大小的中心区域。该中心区域对小边框产生较大的中心区域,而对大边框产生较小的中心区域。假设我们要决定是否保留某边框 i i i t l x tl_x tlx t l y tl_y tly 代表边框 i i i的左上角坐标, b r x , b r y br_x, br_y brx,bry代表边框 i i i右下角的坐标。定义一个中心区域 j j j,让 c t l x , c t l y ctl_x, ctl_y ctlx,ctly 表示该中心区域 j j j左上角的坐标,让 c b r x , c b r y cbr_x, cbr_y cbrx,cbry 表示该中心区域 j j j 右下角的坐标。那么 t l x , t l y , b r x , b r y , c t l x , c t l y , c b r x , c b r y tl_x, tl_y, br_x, br_y, ctl_x, ctl_y, cbr_x, cbr_y tlx,tly,brx,bry,ctlx,ctly,cbrx,cbry满足如下关系:

{ c t l x = ( n + 1 ) t l x + ( n − 1 ) b r x 2 n c t l y = ( n + 1 ) t l y + ( n − 1 ) b r y 2 n c b r x = ( n − 1 ) t l x + ( n + 1 ) b r x 2 n c b r y = ( n − 1 ) t l y + ( n + 1 ) b r y 2 n \left\{ \begin{aligned} ctl_x & = & \frac{(n+1)tl_x + (n-1)br_x}{2n} \\ ctl_y & = & \frac{(n+1)tl_y + (n-1)br_y}{2n} \\ cbr_x & = & \frac{(n-1)tl_x + (n+1)br_x}{2n} \\ cbr_y & = & \frac{(n-1)tl_y + (n+1)br_y}{2n} \end{aligned} \right. ctlxctlycbrxcbry====2n(n+1)tlx+(n1)brx2n(n+1)tly+(n1)bry2n(n1)tlx+(n+1)brx2n(n1)tly+(n+1)bry

n n n是个奇数,决定中心区域 j j j的大小。在这篇论文中,当边框小于150时, n = 3 n=3 n=3,否则 n = 5 n=5 n=5。图3就是 n = 3 n=3 n=3 n = 5 n=5 n=5时两种中心区域。根据上面的等式,我们可以得到自适应的中心区域,然后在里面检测中心区域是否包含中心关键点。

CenterNet: Keypoint Triplets for Object Detection 论文学习_第4张图片

3.3 Enriching Center and Corner Information

Center Pooling

CenterNet: Keypoint Triplets for Object Detection 论文学习_第5张图片

物体几何中心的信息并不一定具有很高的辨识度,比如人体的头部具有很强的视觉模式,但是几何中心通常在人躯体的中间部位。为了解决这个问题,我们提出了 center pooling 来捕捉更加丰富,更具辨识度的视觉信息。图4(a)展示了 center pooling 的原理。Center pooling 的详细原理如下:首先主干网络输出一个特征图,为了判断特征图上是否存在一个中心关键点,我们需要找到它水平和垂直方向上的最大值,并将其相加。这样子,center pooling 就可以提升中心关键点的检测。

Cascade corner pooling

角点通常位于物体的边界,缺乏局部的外观特征。CornerNet 使用 corner pooling 解决了这个问题。Corner pooling 的原理如图4(b) 所示。Corner pooling 主要是寻找边界上的最大值,以此来确定角点。但是这样存在一个问题,就是角点对边框很敏感。为了解决这个问题,我们需要让角点看到物体内部的视觉模式。本文做了改进,Cascade corner pooling 的原理如图4© 所示。首先沿着边界寻找边界上的最大值,然后沿着最大值的位置往边框里面看,找到内部的最大响应值,比如上面的边,就往垂直于它的正下方看去。最后,将这两个最大值相加处理。这样,角点既可获得边界信息,也可获得物体内部的视觉信息。

通过在不同的方向上组合不同的 corner pooling,我们可以轻松实现 Center pooling 和 cascade corner pooling。二者模式如图5 所示。图5(a) 展示了 center pooling 模块的结构。Center pooling 为了获得水平方向上的最大值,依次顺序连接 left pooling 和 right pooling。图5(b) 展示了 cascade top corner pooling 模块。和传统 CornerNet 中的 top corner pooling相比,在 top corner pooling 前增加了一个 left corner pooling。

CenterNet: Keypoint Triplets for Object Detection 论文学习_第6张图片

3.4 训练和测试

训练

我们的模型用PyTorch 实现,网络从0开始训练。图像分辨率是 511 × 511 511\times 511 511×511,产生的热力图大小是 128 × 128 128\times 128 128×128。我们使用了数据增强策略让模型更鲁棒。同样使用了 Adam 方法来训练损失函数:

L = L d e t c o + L d e t c e + α L p u l l c o + β L p u s h c o + γ ( L o f f c o + L o f f c e ) . L = L^{co}_{det} + L^{ce}_{det} + \alpha L^{co}_{pull} + \beta L^{co}_{push} + \gamma (L^{co}_{off} + L^{ce}_{off}). L=Ldetco+Ldetce+αLpullco+βLpushco+γ(Loffco+Loffce).

L d e t c o , L d e t c e L^{co}_{det}, L^{ce}_{det} Ldetco,Ldetce 分别表示 focal loss,分别用于检测角点和中心关键点。 L p u l l c o L^{co}_{pull} Lpullco 是角点的 “pull” 损失,用于最小化同一物体的 embedding 向量之间的距离。 L p u s h c o L^{co}_{push} Lpushco 是角点的"push"损失,用于最大化不同物体的 embedding 向量之间的距离。 L o f f c o L^{co}_{off} Loffco L o f f c e L^{ce}_{off} Loffce l 1 l_1 l1损失,用于训练网络来分别预测角点和中心关键点的偏移量。 α , β , γ \alpha,\beta, \gamma α,β,γ 表示相应损失的权重,分别设为 0.1 , 0.1 , 1 0.1, 0.1, 1 0.1,0.1,1 L d e t , L p u l l , L p u s h , L o f f L_{det}, L_{pull}, L_{push}, L_{off} Ldet,Lpull,Lpush,Loff 都在 CornerNet 中定义了,请阅读该论文了解更多信息。我们在8块 Tesla V100 GPUs 上训练,batch size 设为48。Iterations 的最大个数设为48万。在前45万个iterations中,学习率设为 2.5 × 1 0 − 4 2.5\times 10^{-4} 2.5×104。最后3万个iterations,学习率为 2.5 × 1 0 − 5 2.5\times 10^{-5} 2.5×105

测试

接着CornerNet论文,对于单一尺度测试,我们输入原图和经过水平翻转的图片,分辨率与原始分辨率一样。对于多尺度测试,我们输入原图和水平翻转的图片,分辨率分别为 0.6 , 1.2 , 1.5 , 1.8 0.6, 1.2, 1.5, 1.8 0.6,1.2,1.5,1.8。我们从热力图上选取前70个中心关键点,前70个左上角点,前70个右下角点,以此检测边框。在水平翻转的图片上,我们翻转检测到的边框,将它们和原来的边框进行混合。我们也用了 Soft-NMS 来去除多余的边框。最终,根据得分,选择前100个边框作为最终的检测结果。

4. Experiments

Pls read paper for more details.

你可能感兴趣的:(深度学习,目标检测,图像识别)