源码解析目标检测的跨界之星DETR(一)、概述与模型推断

Date: 2020/06/27

Author: CW

前言:

阅读了 DETR 的论文后,近期梳理了相关代码,本系列会结合源码对 DETR 进行解析,包含模型效果的简单演示、训练的 pipeline、backbone、编码、解码、loss的设计与计算、后处理、评估验证的 pipeline。CW认为,认真阅读完本系列的每篇文后,将 DETR 的实现理解透彻是OK的,但要真正地吃透,还需要朋友你亲自实践并且深入思考。

本文作为系列的开篇之作,就简单一些吧,一上来就是复杂的源码分析难免有将客人拒之门外的赶脚。因此,CW在本文中会对这个模型做个简单的概述,然后基于官方给出的 notebook demo对模型推断部分的代码进行解析(注意,这个demo中模型的实现并不是 DETR 真正的实现方式,仅是个简化版)。

DETR: End-to-end Object Detection with Transformers

Code


Outline

I. 概述

II. 模型推断


概述

DETRDEtection TRansformer, 是 Facebook AI 研究院提出的 CV 模型,主要用于目标检测,也可以用于分割任务。该模型使用 Transformer 替代了复杂的目标检测传统套路,比如 two-stage 或 one-stage、anchor-based 或 anchor-free、nms 后处理等;也没有使用一些骚里骚气的技巧,比如在使用多尺度特征融合、使用一些特殊类型的卷积(如分组卷积、可变性卷积、动态生成卷积等)来抽取特征、对特征图作不同类型的映射以将分类与回归任务解耦、甚至是数据增强,整个过程就是使用CNN提取特征后编码解码得到预测输出

可以说,整体工作很solid,虽然效果未至于 SOTA,但将炼丹者们通常认为是属于 NLP 领域的 Transformer 拿来跨界到 CV 领域使用,并且能work,这是具有重大意义的,其中的思想也值得我们学习。这种突破传统与开创时代的工作往往是深得人心的,比如 Faster R-CNN 和 YOLO,你可以看到之后的许多工作都是在它们的基础上做改进的。

概括地说,DETR 将目标检测任务看作集合预测问题,对于一张图片,固定预测一定数量的物体(原作是100个,在代码中可更改),模型根据这些物体对象与图片中全局上下文的关系直接并行输出预测集,也就是 Transformer 一次性解码出图片中所有物体的预测结果,这种并行特性使得 DETR 非常高效。

DETR 框架

模型推断

这个demo会基于预训练权重实现一个DETR的简化版,然后对一张图片作预测,最后展示出预测效果。

首先导入需要的相关库:

导入相关库

然后,实现一个简化版的模型:

模型定义(i) 初始化

模型主要由 backbone、transformer 以及 最后形成预测输出的线性层构成,另外,还需要一个卷积层将backbone输出的特征图维度映射到transformer输入所需的维度。

了解 Transformer 的朋友们应该知道,其本身是不了解输入序列中各部分的位置关系的,因此通常需要加入位置编码,此处也一样:

模型定义(i) 初始化

上图中,行列编码的第一个维度都是50,代表这里默认backbone输出的特征图尺寸不超过50x50。

模型的初始化方法就到此结束了,是那么得丝滑..额不对,是那么得简洁明了,接下来看看模型的前向过程:


模型定义(ii). 前向过程

上图中的部分是将图片输入到backbone提取特征,然后对输出特征图维度进行转换,并且构造位置编码张量。这里位置编码张量的实现是对特征图的行、列分别进行编码后拼接起来,同时进行维度转换以适应编码器的输入。

下面就是将以上部分输入到 Transformer 进行编码与解码,最后将解码的结果输入到线性层形成最终的预测结果:


模型定义(ii). 前向过程

注意下,上图中对 Transformer 的输出维度顺序做了调整,因此最后得到的h的维度是(batch, 100, hidden_dim)。

整个前向过程也就这样了,是不是感觉让你撸起代码来毫无压力,嘿嘿!

下面是对输入图片和输出bbox的处理:


对输入输出的处理

对于输出bbox,先将其由中心点坐标和宽高转换为矩形框左上角和右下角坐标的形式,同时,由于回归的是归一化后的值,因此需要根据图像尺寸转换为绝对坐标值。

现在,我们定义一个方法来封装整个推断过程,从而获取预测结果:


推断过程封装

这里有个点提一下,torch1.5版本中,对于tensor.max()的返回是torch.return_types.max(values=tensor(xxx), indices=tensor(xxx)),但是torch1.0中,这个方法的返回是一个tuple。

选用COCO数据集的类别,总共80类,但索引是1到90。

COCO数据集类别及用于可视化的颜色参数

上图中的COLORS用于画出bbox的矩形框颜色。

现在我们可以实例化一个模型,由于COCO的类别索引是1到90,因此我们的num_classes参数需要设置为91:

实例化模型与预训练权重加载

OK,一切准备就绪,我们现在来对一张图片进行检测:

对一张图进行检测

可以看到,模型在这张图中检测到了5个物体,最后我们对这个结果进行可视化:


可视化结果

最终效果如下图所示:

模型检测效果

#最后

CW认为,通过本文,应该可以让大家对 DETR 有个基本了解,当然同时可能也会产生出许多不解,不着急,更多的细节实现与原理会在后面的篇章中解析,待我酝酿酝酿,才香~

你可能感兴趣的:(源码解析目标检测的跨界之星DETR(一)、概述与模型推断)