本文首先简要介绍Detr论文原理,之后在介绍mmdetection中的detr实现。
论文地址:chrome-extension://ibllepbpahcoppkjjllbabhnigcbffpi/https://arxiv.53yu.com/pdf/2005.12872.pdf,
整体流程:在给定一张输入图像后,1)特征向量提取:首先经过ResNet提取图像的最后一层特征图F。注意此处仅仅用了一层特征图,是因为后续计算复杂度原因,另外,由于仅用最后一层特征图,故对小目标检测不友好,这也是后续deformable detr改进的原因。 2)添加位置编码信息:经F拉平成一维张量并添加上位置编码信息得到I。3)Transformer中encoder部分4)Transformer中decoder部分,学习位置嵌入object queries。5)FFN部分:6)后续匈牙利匹配+损失计算。
Detr的内部逻辑如下:在mmdet/models/detector/single_stage.py。即首先提取图像特征向量,之后经过DetrHead来计算最终的损失。
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None):
super(SingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img) # 提取图像特征向量
# 经过DetrHead得到loss
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_bboxes_ignore)
return losses
mmdet中提取图像特征向量的config配置文件如下,可以发现用ResNet50并只提取了最后一层特征层,即out_indices=(3,)。关于内部原理参见我的博文:mmdet之backbone介绍。
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(3, ), # detr仅要resnet50的最后一层特征图,并不需要FPN
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
本部分代码来自mmdet/models/dense_heads/detr_head.py。
mmdet中生成位置编码信息借助的是mask矩阵,所谓的mask就是为了统一批次大小而对图像进行了pad,被填充的部分在后续计算多头注意力时应该舍弃,故需要一个mask矩阵遮挡住,具体形状为[batch, h,w]这里先贴下生成mask的过程:
batch_size = x.size(0)
input_img_h, input_img_w = img_metas[0]['batch_input_shape']# 一个批次图像大小
masks = x.new_ones((batch_size, input_img_h, input_img_w)) # [b,838,768]
for img_id in range(batch_size):
img_h, img_w, _ = img_metas[img_id]['img_shape'] # 创建了一个mask,非0代表无效区域, 0 代表有效区域
masks[img_id, :img_h, :img_w] = 0 # 将pad部分置为1,非pad部分置为0.
我这里简单贴下mask示意图:
在有了mask基础上[batch,256,h,w],注意此时的hw是原图大小的;而输入图像的经过resnet50下采样后hw已经变了,所以还需进一步将mask下采样成和图像特征向量一样的shape。代码如下:
# interpolate masks to have the same spatial shape with x
masks = F.interpolate(
masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1) # masks和x的shape一样:[b,27,24]
后续便可以生成位置编码部分(mmdet/models/utils/position_encoding.py),该函数给masks的每个像素位置生成了一个256维的唯一的位置向量。我这简单写了个测试脚本:
import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import build_positional_encoding
from mmdet.models.utils.positional_encoding import POSITIONAL_ENCODING # 加载注册器
positional_encoding = dict(type='SinePositionalEncoding', num_feats=128, normalize=True)
self = build_positional_encoding(positional_encoding)
self.eval()
mask = torch.tensor([[[0,0,1],[0,0,1],[1,1,1]]], dtype= torch.uint8) # [1,3,3]
out = self(mask) # [b,256,h,w]
感兴趣可以看下mmdet关于位置编码这部分实现逻辑(只是做了简单注释):
def forward(self, mask):
"""Forward function for `SinePositionalEncoding`.
Args:
mask (Tensor): ByteTensor mask. Non-zero values representing
ignored positions, while zero values means valid positions
for this image. Shape [bs, h, w].
Returns:
pos (Tensor): Returned position embedding with shape
[bs, num_feats*2, h, w].
"""
# For convenience of exporting to ONNX, it's required to convert
# `masks` from bool to int.
mask = mask.to(torch.int)
not_mask = 1 - mask # 取反将1的位置视为图像区域
y_embed = not_mask.cumsum(1, dtype=torch.float32) # 累加1得到y方向坐标 [h]
x_embed = not_mask.cumsum(2, dtype=torch.float32) # 累加1得到x方向坐标 [w]
# 归一化过程就是除以坐标中的max,而最后一行/列就是累加的最大的向量
if self.normalize:
y_embed = (y_embed + self.offset) / \
(y_embed[:, -1:, :] + self.eps) * self.scale # 取最后一行
x_embed = (x_embed + self.offset) / \
(x_embed[:, :, -1:] + self.eps) * self.scale # 取最后一列
# 创建一个[128]的特征向量
dim_t = torch.arange(
self.num_feats, dtype=torch.float32, device=mask.device)
dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats) # 归一化的 [0,0,1,1,2,2,...,64,64]乘温度系数
# 行列坐标分别除以dim_t得到每个点的128维的行列特征向量
pos_x = x_embed[:, :, :, None] / dim_t # [b,h,w,128]
pos_y = y_embed[:, :, :, None] / dim_t # [b,h,w,128]
# use `view` instead of `flatten` for dynamically exporting to ONNX
B, H, W = mask.size()
# 分别采样奇数和偶数位置并执行sin和cos,并拼接[b,h,w,128]
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
dim=4).view(B, H, W, -1)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
dim=4).view(B, H, W, -1)
# 最后将横纵坐标拼接得到每个点唯一的256维度的位置向量
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) # [b,256,h,w]
return pos
在得到图像特征向量x=[b,c,h,w]、masks[b,h,w]矩阵以及位置编码pos_embed[b,256,h,w]后,便可送入Transformer,关键是厘清encoder和decoder的QKV分别指啥,看代码:
memory = self.encoder(
query=x, # [hw,b,c]
key=None,
value=None,
query_pos=pos_embed, # [hw,b,c]
query_key_padding_mask=mask) # [b,hw]
target = torch.zeros_like(query_embed) # decoder初始化全0
# out_dec: [num_layers, num_query, bs, dim]
out_dec = self.decoder(
query=target, # 全0的target, 后续在MultiHeadAttn中执行了
key=memory, # query = query + query_pos又加回去了。
value=memory,
key_pos=pos_embed,
query_pos=query_embed, # [num_query, bs, dim]
key_padding_mask=mask)
out_dec = out_dec.transpose(1, 2)
其中encoder中q就是x,kv分别为None,query_pos代表位置编码,而query_key_padding_mask就是mask。decoder的q是全0的target,后续decoder会迭代更新q,而kv则 是memory,即encoder的输出;key_pos依旧是k的位置信息;query_embed即论文中Object query,可学习位置信息;key_padding_mask依然是mask。
先看下encoder初始化部分,内部循环调用了6次BaseTransformerLayer,因此只需讲解一层EncoderLayer即可。
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6, # 经过6层Layer
transformerlayers=dict( # 每层layer内部使用多头注意力
type='BaseTransformerLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1)
],
feedforward_channels=2048, # FFN中间层的维度
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))), # 定义运算流程
在来看下BaseTransformerLayer的forward部分,该部分可以损失detr的核心部分了,因为本质上mmdet内部只是封装了pytorch现有的nn.MultiHeadAtten函数。所以,需要理解nn.MultiHeadAttn中两种mask参数的含义,限于篇幅原因,这里可参考nn.Transformer来理解这两个mask。 不过简单理解就是:attn_mask在detr中没用到,仅用key_padding_mask。attn_mask是为了遮挡未来文本信息用的,而图像可以看到全部的信息,因此不需要用attn_mask。
def forward(self,
query,
key=None,
value=None,
query_pos=None,
key_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
**kwargs):
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
attn_masks (List[Tensor] | None): 2D Tensor used in
calculation of corresponding attention. The length of
it should equal to the number of `attention` in
`operation_order`. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_queries]. Only used in `self_attn` layer.
Defaults to None.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
norm_index = 0
attn_index = 0
ffn_index = 0
identity = query
if attn_masks is None:
attn_masks = [None for _ in range(self.num_attn)]
elif isinstance(attn_masks, torch.Tensor):
attn_masks = [
copy.deepcopy(attn_masks) for _ in range(self.num_attn)
]
warnings.warn(f'Use same attn_mask in all attentions in '
f'{self.__class__.__name__} ')
else:
assert len(attn_masks) == self.num_attn, f'The length of ' \
f'attn_masks {len(attn_masks)} must be equal ' \
f'to the number of attention in ' \
f'operation_order {self.num_attn}'
for layer in self.operation_order: # 遍历config文件的顺序
if layer == 'self_attn':
temp_key = temp_value = query
query = self.attentions[attn_index]( # 内部调用nn.MultiHeadAttn
query,
temp_key,
temp_value,
identity if self.pre_norm else None,
query_pos=query_pos, # 若有位置编码信息则和query相加
key_pos=query_pos, # 若有位置编码信息则和key相加
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
**kwargs)
attn_index += 1
identity = query
elif layer == 'norm':
query = self.norms[norm_index](query) # 层归一化
norm_index += 1
elif layer == 'cross_attn': # decoder用到
query = self.attentions[attn_index](
query,
key,
value,
identity if self.pre_norm else None,
query_pos=query_pos, # 若有位置编码信息则和query相加
key_pos=key_pos, # 若有位置编码信息则和key相加
attn_mask=attn_masks[attn_index],
key_padding_mask=key_padding_mask,
**kwargs)
attn_index += 1
identity = query
elif layer == 'ffn': # 残差连接加全连接层
query = self.ffns[ffn_index](
query, identity if self.pre_norm else None)
ffn_index += 1
return query
decoder部分和encoder流程类似,只是多了交叉注意力。
decoder=dict(
type='DetrTransformerDecoder',
return_intermediate=True,
num_layers=6,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
feedforward_channels=2048,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm','ffn', 'norm')),
))
我这里简单贴下nn.MultiHeadAttn内部流程:
attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # 计算Q*K
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] # 判断一个tensor的shape是否等于某个尺寸,将其转成list。
# 利用attn_mask将未来的词遮挡住
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
# 借助key_padding_mask将pad部分遮挡住
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) # [2,8,5,5]
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
上述代码流程比较简单,就是首先计算Q中每个元素和K的相似度,要依次用两种mask来遮挡住,为后续的softmax做准备。以cross attn为例,attn_output_weights是计算了每个真实单词和原始句子每个单词的相似性权重,所以要用和src_key_padding_mask一样的memory_key_padding_mask在行的维度上进行遮挡,故二者pad_mask是一致的。
由于后续在detr上改进的论文对匈牙利算法以及loss计算改动不大,因此这部分代码就不讲解了。 感觉写的已经够乱了,哭脸。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。