源码解析目标检测的跨界之星DETR(三)、Backbone与位置编码

Date: 2020/07/14

Coder: CW

Foreword:

这一篇开始对 DETR 的模型构建部分进行解析,model主要由两部分组成,其中一部分是backbone,另一部分是Transformer。另外,在DETR的源码实现中,将位置编码模块与backbone集成到一起作为一个module,在backbone输出特征图的同时对其进行位置编码,以便后续Transformer使用。


Outline

I. Build Backbone

II. Build Position Encoding

III. Joiner


Build Backbone

backbone的构建通过bulid_backbone这个方法封装,主要做的就是分别构建位置编码部分与backbone,然后将两者封装到一个nn.Module里,在前向过程中实现两者的功能。

build_backbone

先来看backbone的构建,以下这个类继承BackboneBase这个类,实际的backbone是使用torchvision里实现的resnet。其中 pretrained=is_main_process() 代表仅在主进程中使用预训练权重。

Backbone

而对于 norm_layer=FrozenBatchNorm2d,代表这里使用的归一化层是FrozenBatchNorm2d,这个nn.Module与batch normalization的工作原理类似,只不过将统计量(均值与方差)和可学习的仿射参数固定住,doc string里的描述是:

BatchNorm2d where the batch statistics and the affine parameters are fixed.

在实现的时候,需要将以上4个量注册到buffer,以便阻止梯度反向传播而更新它们,同时又能够记录在模型的state_dict中。

FrozenBatchNorm2d

在BackboneBase中可以看到,若return_interm_layers设置为True,则需要记录每一层(ResNet的layer)的输出。

BackboneBase

注意,IntermediateLayerGetter 这个类是在torchvision中实现的,它继承nn.ModuleDict,接收一个nn.Module和一个dict作为初始化参数,dict的key对应nn.Module的模块,value则是用户自定义的对应各个模块输出的命名,官方给出的例子如下:

Examples::

        >>> m = torchvision.models.resnet18(pretrained=True)

        >>> # extract layer1 and layer3, giving as names `feat1` and feat2`

        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,

        >>>    {'layer1': 'feat1', 'layer3': 'feat2'})

        >>> out = new_m(torch.rand(1, 3, 224, 224))

        >>> print([(k, v.shape) for k, v in out.items()])

        >>>    [('feat1', torch.Size([1, 64, 56, 56])),

        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]

现在回过头来看看BackboneBase的前向过程。self.body就是上述提到的IntermediateLayerGetter,它的输出是一个dict,对应了每层的输出,key是用户自定义的赋予输出特征图的名字。

BackboneBase

注意BackboneBase的前向方法中的输入是NestedTensor这个类的实例,其实质就是将图像张量和对应的mask封装到一起。

NestedTensor

至此,Backbone的构建就完事了。综上可知,若设置return_interm_layers为True,即指定需要返回每层的输出,那么backbone的输出将是 out={'0': f1, '1': f2, '2': f3, '3': f4},否则输出将是 out={'0': f4},,其中代表第i层的输出(比如ResNet的话,就是)。


Build Position Encoding

这部分的实现和Transformer那篇paper中的基本类似,有两种方式来实现位置编码:一种是可学习的;另一种则是使用正、余弦函数来对各位置的奇、偶维度进行编码,不需要额外的参数进行学习,Transformer和DETR默认使用的也是这种。

不同的是,这里处理的对象是2D图像特征,而非1D的序列,因此位置编码需要分别对行、列进行(当然,你可能想到把特征图flatten成的1D序列,但应该是考虑到保持图像结构因此DETR并没有那么做,感兴趣的童鞋可以试试,并且进行实验对比下效果)。

build_position_encoding

先来对可学习的编码方式进行解析。

这里默认需要编码的特征图的行、列不超为50(相当于特征图尺寸不超过50x50),即位置索引在0~50范围内,对每个位置都嵌入到num_pos_feats(默认256)维。

PositionEmbeddingLearned(i)

下面是前向过程,分别对一行和一列中的每个位置进行编码。

PositionEmbeddingLearned(ii)

最后将行、列编码结果拼接起来并扩充第一维,与batch size对应,得到以下变量pos

在这种方式的编码下,所有行同一列的横坐标(x_emb)编码结果是一样的,在dim1中处于pos的前num_pos_feats维;同理,所有列所有列同一行的纵坐标(y_emb)编码结果也是一样的,在dim1中处于pos的后num_pos_feats维。

PositionEmbeddingLearned(iii)

再来看看正、余弦编码的方式。

这种方式是将每个位置的各个维度映射到角度上,因此有个scale参数,若初始化时没有指定,则默认为0~2π。

PositionEmbeddingSine(i)

在该系列上一篇: 源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理 中的数据处理部分,CW提到过mask指示了图像哪些位置是padding而来的,其值为True的部分就是padding的部分,这里取反后得到not_mask,那么值为True的部分就是图像真实有效(而非padding)的部分。

PositionEmbeddingSine(ii)

上图中使用了张量的cumsum()方法在列和行的方向分别进行累加,并且数据类型由布尔型转换为浮点型。不得不说,这个操作十分妙!

在行方向累加,就会得到以下形式(y_embed):

[ [1,1,1,..,1],

  [2,2,2,..,2],

  ...

  [h,h,h,..,h] ]

而在列方向累加,则得到以下形式(x_embed):

[ [1,2,3,..,w],

  [1,2,3,..,w],

  ...

  [1,2,3,..,w] ]

这样,各行(列)都映射到不同的值(当然,pad部分的位置由于mask取反后值是0,因此累加后得到的值会与前面行、列的值重复,但是不要紧,通过下一篇文章关于注意力权重计算的讲解会知道这些位置会被忽略掉而不受影响),并且,最后一行(列)是所有行(列)的总和h(w),还方便进行归一化操作:

PositionEmbeddingSine(iii)

下图部分的代码与正、余弦编码的公式对应:

PositionEmbeddingSine(iv)

使用这种方式编码,就是对各行各列的奇偶维度分别进行正、余弦编码。

对于每个位置(x,y),其所在列对应的编码值排在通道这个维度上的前num_pos_feats维,而其所在行对应的编码值则排在通道这个维度上的后num_pos_feats维。这样,特征图上的各位置(总共个)都对应到不同的维度为的编码值(向量)。

至于这种方式为何能够保证不同位置映射到不同的位置编码可以看下CW这篇文章中的分析(在最后一小节):Transformer 修炼之道(一)、Input Embedding

另外,这种方式相比于可学习的方式还有个“可拓展”的好处:即使在测试时来到一个比以往训练时遇到的图像尺寸都大的图像,也照样能够获得编码值。

Joiner

Joiner就是将backbone和position encoding集成的一个nn.Module里,使得前向过程中更方便地使用两者的功能。

Joinernn.Sequential的子类,通过初始化,使得self[0]是backbone,self[1]是position encoding。前向过程就是对backbone的每层输出都进行位置编码,最终返回backbone的输出及对应的位置编码结果。

Joiner

@最后

就这篇文章的整体内容来说,重点应该是在位置编码的部分,backbone部分毕竟是调用torchvision的内置模型,因此几乎没什么可讲的了,位置编码的技巧可以私下多码几遍,码熟它,从而掌握一个基本技能,便于日后应用。

你可能感兴趣的:(源码解析目标检测的跨界之星DETR(三)、Backbone与位置编码)