【vision transformer】DETR原理及代码详解(一)

DETR: End-to-End Object Detection with Transformers

论文: https://arxiv.org/pdf/2005.12872.pdf

代码: https://github.com/facebookresearch/detr (pytorch)

https://github.com/BR-IDL/PaddleViT/tree/develop/object_detection/DETR(PaddlePaddle)

1. DETR 概述

DETR 是vision transformer 中目标检测的开山之作,是 Facebook 团队于 2020 年提出的基于 Transformer 的端到端目标检测,克服了传统目标检测的anchore机制和非极大值抑制 NMS ,大大简化了目标检测的 pipeline。

相比faster rcnn等做法,detr最大特点是将目标检测问题转化为无序集合预测问题。faster rcnn设置一大堆anchor,然后基于anchor进行分类和回归其实属于代理做法。目标检测任务就是输出无序集合,而faster rcnn等算法通过各种操作,并结合复杂后处理最终才得到无序集合,而detr则相对直接。

【vision transformer】DETR原理及代码详解(一)_第1张图片

图1 DETR 的主要部件

【vision transformer】DETR原理及代码详解(一)_第2张图片

图2 DETR流程

给定一张图片,经CNN backbone 提取深度feature,然后转为特征序列输入到transformer的encode-decode,结果直接输出指定长度为N的无序的预测集合,集合中每个元素包含预测物体的cls类别和bbox坐标。

CNN输出的结果是一个 H*W*C的tensor,代表了图片提取出的feature map。将feature map变成( H ∗ W ) ∗ C的二维矩阵放到transformer中。

模型输出的结果是固定的,也就是说最多检测一张图片中N个目标。其中N表示整个数据集中图片上最多物体的数目,因为整个训练和测试都Batch进行,如果不设置最大输出集合数,无法进行batch训练;如果图片中物体不够N个,那么就采用no object填充,表示该元素是背景。

set prediction

目标检测的新范式,输入一副图像,网络的输出就是最终的预测的集合,也不需要任何后处理能够直接得到预测的集合。对于每一个pred,找到对应的GT,然后每个(Pred,GT)求loss,再进行训练。

bipartite matching loss :

假设我们现在有两个sets,左边的sets是模型预测得到的N 个元素,每个元素里有一个bbox和对这个bbox预测的类别的概率分布,预测的类别可以是空,用ϕ 来表示;右边的sets是我们的ground truth,每个元素里有一个标签的类别和对应的bbox,如果标签的数量不足N 则用ϕ 来补充,,ϕ 可以认为是background。

两边sets的元素数量都是N ,所以我们是可以做一个配对的操作,让左边的元素都能找到右边的一个配对元素,每个左边元素找到的右边元素都是不同的,也就是一一对应。这样的组合可以有N ! 种,所有组合记作σ N 。这个N 即是模型可以预测的最大数量。

【vision transformer】DETR原理及代码详解(一)_第3张图片

我们的目的是在这所有的N ! 种匹配中,找到使得L_{match}最小的那个组合,记作\hat{\sigma }

在分析loss计算前,需要先明确N个无序集合的target构建方式。detr输出是包括batchx100个无序集合,每个集合包括类别和坐标信息,其输出集合包括两个分支:分类分支shape=(b,100,92),bbox坐标分支shape=(b,100,4),对应的target也是包括分类target和bbox坐标target,如果不够100,则采用背景填充,计算loss时候bbox分支仅仅计算有物体位置,背景集合忽略。

问题:输出的bx100个检测结果是无序的,如何和gt bbox计算loss?

detr中利用匈牙利算法(双边匹配算法)先进行最优一对一匹配得到匹配索引,然后对bx100个结果进行重排就和gt bbox对应上,从而计算loss。

【vision transformer】DETR原理及代码详解(一)_第4张图片

优化对象是σ,其是长度为N 的list,σ(i)=i,表示无序gt bbox 集合的哪个元素和输出预测集合中的第i个匹配。即寻找最优匹配,因为在最佳匹配情况下l_match和最小即loss最小。

该函数核心是需要输入A集合和B集合两两元素之间的连接权重,基于该重要性进行内部最优匹配,连接权重大的优先匹配。

  2. DETR详细框架及流程

【vision transformer】DETR原理及代码详解(一)_第5张图片

图3 DETR 整体框架

输入:

【vision transformer】DETR原理及代码详解(一)_第6张图片

  图5 DETR 的输入

图像经backbone后的feature,flatten或reshape 成visual tokens,此处的CNN相当于vision transformer中 的patch embedding。

Encode:

【vision transformer】DETR原理及代码详解(一)_第7张图片

【vision transformer】DETR原理及代码详解(一)_第8张图片

 图6  DETR 的encode 部分

 类似VIT 的多层encoder堆叠,输入Visual Token+ position embedding(其中v值没有加position embedding,而且spatial position encode 在每一次的encode 中都参与计算,为了不断强化patch 的位置),输出为与Visual Token维度一致的特征。

Decode:

【vision transformer】DETR原理及代码详解(一)_第9张图片

【vision transformer】DETR原理及代码详解(一)_第10张图片

   图6  DETR 的decode 部分

 Decode 的目的是为了输出SeqtoSeq的方式,与encode的区别:输入除了encode output,还加入了object queries,类似于mask,重新定义一个可学习的参数或者embedding,不再用encode的输入visual token。处理过程加入了multi_head decode_encode attention,encode输出加入V,K,object queries 加入查询Q值。

object queries类似anchor(非几何上的),是可学习的特征向量。

Encode-decode:

【vision transformer】DETR原理及代码详解(一)_第11张图片

Decode 的注意力机制中加入了encode的输出。 

 DETR 的输出:

【vision transformer】DETR原理及代码详解(一)_第12张图片

 MLP即FC层,bbox四个坐标,output layer channel=4,output class = num_classes+1。

logist:(batch_size,num_queries,num_classes+1)

pred_boxes:(batch_size,num_queries,4)

 DETR的整体流程:

【vision transformer】DETR原理及代码详解(一)_第13张图片

position encoding 可采用可学习的方式,也可采用人为设定的(sin or cos),其持续性的加入层运算。

你可能感兴趣的:(vision,transformer,transformer,computer,vision)