ECCV2020
Facebook AI
第一个用Transformer进行目标检测的网络
论文: https://arxiv.org/abs/2005.12872
代码: https://github.com/facebookresearch/detr
搭建参考: 【CODE】Facebook 最新DETR(基于Transformer)目标检测算法实战_哔哩哔哩_bilibili
DETR通过将常见的CNN与Transformer架构相结合,不仅做到了并行化,不需要受到之前结果的影响;
而且把目标检测任务变成了一个Set Prediction任务,一口气预测一个集合,而不是像RNN一样一个一个预测。
DETR使用常规的CNN主干来学习输入图像的2D表示,展平后进行位置编码的补充,随后传递到transformer编码器中。
然后,transformer解码器将少量固定数量的可学习位置嵌入作为输入,称其为object queries。
将解码器的每个输出嵌入传递到预测检测(类和边界框)或“无对象”类的共享前馈网络(FFN)。
整个网络分为几个主要模块:backbone -> encoder -> decoder -> prediction heads
输入图像3*H0*W0,传统的CNN主干会生成较低分辨率的激活图。DETR采用的典型值为C=2048,H=H0/32,W=W0/32。
首先利用1*1卷积降低通道数(C->d),创建一个新的特征图z0∈R d×H×W。 编码器期望一个序列作为输入,因此DETR将z0的空间维度折叠为一个维度,从而生成d×HW特征图。 每个编码器层均具有标准的结构,包括一个multi-head self-attention module和一个前馈网络(FFN)。 DETR用固定的位置编码对其进行补充,该编码被添加到每个自注意力层的输入中。
解码器遵循transformer的标准结构,使用multi-head self- and encoder-decoder attention 处理大小为d的N个embeddings。与原始Transformer的不同之处在于,DETR模型在每个解码器层并行解码N个目标,而原始Transformer则是使用自回归模型,一次预测一个元素的输出序列。
object queries是可学习的位置编码,被添加到每个自注意力层的输入中。 N=100个object queries由解码器转换为一个output embedding。 它们通过FFN独立地解码为边界框坐标和类标签,从而得到N个框的位置和类别分数。
最终预测是由3层的FFN计算得到的,FFN使用ReLU激活,隐藏维度为d。
FFN根据输入图像得到标准化中心坐标、高度和宽度, 然后线性层使用softmax函数预测类标签。
DETR预测了一组固定大小的N=100个边界框,这个数量大于图片中感兴趣的目标。那如何把预测结果和ground truth相对应起来呢?
因此在计算损失时,第一步是将ground truth也扩展成100个检测框,使用一个额外的特殊类标签∅来表示未检测到任何对象,或者认为是背景类别。
然后采用匈牙利算法进行二分图匹配,对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。
计算Ground truth yi和预测出来的第σi个结果之间的匹配损失:对于不是背景的,获得其对应预测是目标类别的概率,然后用框损失减去预测类别的概率。
经过匈牙利算法后,可以得到groud truth和预测目标框之间的一一对应关系。
损失函数和匹配损失不同之处在于损失函数是正值,所以使用log-probability。
对于ci=∅的类别损失,将分类损失除以10来降低作用。目标边界框回归损失是IOU损失和L1损失的加权和,其中IOU损失对于scale不敏感,L1损失对于scale敏感。实际上DETR用的是GLOU损失。
DETR论文提出了两种编码方式:spatial positional encoding和object queries。
spatial positional encoding 既输入到encoder也输入到decoder(包括learned和sin两种)。
object queries在初始时是N个随机向量,将这些随机向量编码与图像特征相结合。
可以理解为这些queries去图像信息中查询,得到检测框和类别预测。
相当于可学习的anchor,从而解决手动设计anchor的问题。
在这些object queries的作用下,预测出来的框如下图。
每个图都是一个object queries在COCO 2017 val set预测出来的框的结果。
图中每个点都是一个框的中心点,绿色表示小框,蓝色表示纵向大框,红色表示横向大框。
可见每个queries都有自己的特点,关注某个特定的区域。
在本文代码中object queries是大小为100*2*256的变量,通过训练确定。
这里的100是超参数,对于每张图片预测100个bboxes。
选择100的原因是数据集中有90个类别,100刚好合适。