This post is a record of my reading of the Deformable DETR source code in mmdet. If there are any mistakes or questions, corrections are welcome.
First, z_q is the object query. A linear layer predicts the sampling offsets, and these offsets (one per head, level and sampling point) are added to the reference point to obtain the sampling locations. The object query also goes through another linear layer followed by a softmax to produce the attention weights. In other words, deformable attention does not need to compute the attention weights from a Q·K dot product at all, because the weights are learned directly from the object query. The attention weights are multiplied with the features sampled at those locations to aggregate the value, and a final linear layer produces the output.
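To make this concrete, here is a minimal single-level, single-head sketch of the idea (all names here are illustrative, not the mmcv implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDeformAttn(nn.Module):
    """Single-scale, single-head deformable attention, for illustration only."""
    def __init__(self, dim=256, num_points=4):
        super().__init__()
        self.num_points = num_points
        self.sampling_offsets = nn.Linear(dim, num_points * 2)   # offsets predicted from the query
        self.attention_weights = nn.Linear(dim, num_points)      # weights predicted from the query, no Q.K
        self.value_proj = nn.Linear(dim, dim)
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, query, value, reference_points, spatial_shape):
        # query: [bs, nq, dim]; value: [bs, H*W, dim]
        # reference_points: [bs, nq, 2] normalized to [0, 1]; spatial_shape: (H, W)
        bs, nq, dim = query.shape
        H, W = spatial_shape
        v = self.value_proj(value).transpose(1, 2).reshape(bs, dim, H, W)
        offsets = self.sampling_offsets(query).view(bs, nq, self.num_points, 2)
        weights = self.attention_weights(query).softmax(-1)       # [bs, nq, num_points]
        # sampling locations = reference point + predicted offsets (normalized by W, H)
        locs = reference_points[:, :, None] + offsets / query.new_tensor([W, H])
        grid = 2 * locs - 1                                        # grid_sample expects [-1, 1]
        sampled = F.grid_sample(v, grid, align_corners=False)      # [bs, dim, nq, num_points]
        out = (sampled * weights[:, None]).sum(-1).transpose(1, 2) # weighted sum -> [bs, nq, dim]
        return self.output_proj(out)

The real MultiScaleDeformableAttention extends this over num_heads heads and num_levels feature levels, but the flow is the same: offsets and weights both come from the query.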
One improvement of Deformable DETR over DETR is the use of multi-scale feature maps, which can also be seen from the config file:
backbone=dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(1, 2, 3),  # take 3 feature maps from ResNet (C3-C5)
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=False),
    norm_eval=True,
    style='pytorch',
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
    type='ChannelMapper',
    in_channels=[512, 1024, 2048],
    kernel_size=1,
    out_channels=256,  # unify the channels of the three feature maps to 256
    act_cfg=None,
    norm_cfg=dict(type='GN', num_groups=32),
    num_outs=4),
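With out_indices=(1, 2, 3) the backbone outputs C3-C5 (strides 8/16/32), and because num_outs=4 is larger than the number of inputs, ChannelMapper appends one extra stride-64 level (a 3x3 stride-2 conv applied on C5, as far as I can tell from the neck code). As a rough, hypothetical example of the resulting sizes for an 800x800 padded input:

# Hypothetical padded input of 800 x 800 (illustration only)
h = w = 800
c3 = (h // 8, w // 8)     # (100, 100), stride 8
c4 = (h // 16, w // 16)   # (50, 50),   stride 16
c5 = (h // 32, w // 32)   # (25, 25),   stride 32
# extra level produced by ChannelMapper: 3x3 conv, stride 2, padding 1 on C5
extra = ((c5[0] + 2 - 3) // 2 + 1, (c5[1] + 2 - 3) // 2 + 1)   # (13, 13)
print(sum(a * b for a, b in (c3, c4, c5, extra)))              # 13294 tokens enter the encoder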
At the code level, just like DETR, execution first enters forward_train of SingleStageDetector to extract the feature maps:
super(SingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                      gt_labels, gt_bboxes_ignore)
Here x holds the multi-scale feature maps: the backbone outputs C3-C5 and the ChannelMapper neck (num_outs=4) appends one extra level, so x contains four 256-channel feature maps.
Execution then moves into forward_train() of DETRHead (DeformableDETRHead inherits from DETRHead), where the following line dispatches into the forward of the deformable DETR head:
outs = self(x, img_metas)
The overall logic of the deformable DETR head is almost identical to detr_head; the difference is that it works on multi-scale feature maps.
batch_size = mlvl_feats[0].size(0)
input_img_h, input_img_w = img_metas[0]['batch_input_shape']
img_masks = mlvl_feats[0].new_ones(
    (batch_size, input_img_h, input_img_w))
# For every image in the batch, build a mask of the padded input: positions
# belonging to the original image are set to 0, positions left at 1 mark the padding
for img_id in range(batch_size):
    img_h, img_w, _ = img_metas[img_id]['img_shape']
    img_masks[img_id, :img_h, :img_w] = 0
mlvl_masks = []
mlvl_positional_encodings = []
# downsample the full-resolution mask so that it matches the size of each feature map
for feat in mlvl_feats:
    mlvl_masks.append(
        # indexing with None adds a dimension: img_masks goes from [bs, h, w] to [1, bs, h, w]
        F.interpolate(img_masks[None],
                      size=feat.shape[-2:]).to(torch.bool).squeeze(0))
    # generate the positional encoding; each new mask is appended at the end,
    # so indexing with -1 always picks the mask that was just created
    mlvl_positional_encodings.append(
        self.positional_encoding(mlvl_masks[-1]))
In my run batch_size is 1, and mlvl_feats contains the four levels of feature maps.
One detail worth noting is why img_masks[None] is used to add a dimension before calling F.interpolate: F.interpolate requires its input to be shaped as batch × channel × [optional depth] × [optional height] × width, and the first two dimensions have a special meaning and are never resampled.
Reference: F.interpolate — array sampling operation.
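A small, self-contained example of that mask downsampling (the sizes are made up):

import torch
import torch.nn.functional as F

img_masks = torch.zeros(2, 750, 800)   # [bs, H, W], hypothetical padded input size
feat = torch.randn(2, 256, 94, 100)    # one of the feature maps
# F.interpolate expects (N, C, ...) and never resizes the first two dims,
# so a leading dummy axis turns [bs, H, W] into [1, bs, H, W]
mask = F.interpolate(img_masks[None], size=feat.shape[-2:]).to(torch.bool).squeeze(0)
print(mask.shape)                      # torch.Size([2, 94, 100])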
In the forward of the deformable DETR head, the following code enters the transformer:
query_embeds = None
if not self.as_two_stage:
    query_embeds = self.query_embedding.weight
hs, init_reference, inter_references, \
    enc_outputs_class, enc_outputs_coord = self.transformer(
        mlvl_feats,
        mlvl_masks,
        query_embeds,  # [300, 512] = [num_query, embed_dims * 2]
        mlvl_positional_encodings,
        reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
        cls_branches=self.cls_branches if self.as_two_stage else None  # noqa:E501
    )
The code then jumps into the forward of DeformableDetrTransformer, which first does some preparation before the transformer proper:
feat_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
# flatten the feature map, mask and positional encoding of every level
for lvl, (feat, mask, pos_embed) in enumerate(
        zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
    bs, c, h, w = feat.shape
    spatial_shape = (h, w)
    spatial_shapes.append(spatial_shape)
    feat = feat.flatten(2).transpose(1, 2)  # [bs, h*w, c]
    mask = mask.flatten(1)  # [bs, h*w]
    pos_embed = pos_embed.flatten(2).transpose(1, 2)  # [bs, h*w, c]
    lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
    lvl_pos_embed_flatten.append(lvl_pos_embed)
    feat_flatten.append(feat)
    mask_flatten.append(mask)
feat_flatten = torch.cat(feat_flatten, 1)  # [bs, sum of h*w over the 4 levels, c]
mask_flatten = torch.cat(mask_flatten, 1)  # [bs, sum of h*w over the 4 levels]
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # [bs, sum of h*w over the 4 levels, c]
# convert to a tensor
spatial_shapes = torch.as_tensor(
    spatial_shapes, dtype=torch.long, device=feat_flatten.device)
# record the start index of each feature level inside the flattened sequence
level_start_index = torch.cat((spatial_shapes.new_zeros(
    (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
# valid (non-padded) width/height ratio of each feature map, [bs, 4 (num_levels), 2 (ratio_w, ratio_h)]
valid_ratios = torch.stack(
    [self.get_valid_ratio(m) for m in mlvl_masks], 1)
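For intuition, here is a tiny hypothetical example (the spatial sizes are made up) showing what level_start_index looks like for four levels:

import torch

spatial_shapes = torch.tensor([[100, 134], [50, 67], [25, 34], [13, 17]])
# number of tokens contributed by each level
print(spatial_shapes.prod(1))              # tensor([13400,  3350,   850,   221])
# start index of each level inside the flattened sequence
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
print(level_start_index)                   # tensor([    0, 13400, 16750, 17600])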
The reference points are obtained by the function below; the resulting reference points are values on the 0-1 scale:
def get_reference_points(spatial_shapes, valid_ratios, device):
    """Get the reference points used in decoder.

    Args:
        spatial_shapes (Tensor): The shape of all
            feature maps, has shape (num_level, 2).
        valid_ratios (Tensor): The ratios of valid
            points on the feature map, has shape
            (bs, num_levels, 2)
        device (obj:`device`): The device where
            reference_points should be.

    Returns:
        Tensor: reference points used in decoder, has \
            shape (bs, num_keys, num_levels, 2).
    """
    reference_points_list = []
    for lvl, (H, W) in enumerate(spatial_shapes):
        # TODO check this 0.5
        # build the center coordinates of every pixel; the +/- 0.5 puts each
        # initial point at the center of its pixel
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(
                0.5, H - 0.5, H, dtype=torch.float32, device=device),
            torch.linspace(
                0.5, W - 0.5, W, dtype=torch.float32, device=device))
        # normalize the coordinates by the valid height/width
        ref_y = ref_y.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 1] * H)
        ref_x = ref_x.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 0] * W)
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    # rescale the reference points to the valid (non-padded) region of every level
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points
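As a quick sanity check, calling the function above standalone on a single 2x2 level with no padding (valid ratios all 1) returns the four normalized pixel centers:

import torch

spatial_shapes = torch.tensor([[2, 2]])   # one level with H = W = 2
valid_ratios = torch.ones(1, 1, 2)        # bs = 1, no padding, so all ratios are 1
ref = get_reference_points(spatial_shapes, valid_ratios, device='cpu')
print(ref.shape)        # torch.Size([1, 4, 1, 2]) -> (bs, num_keys, num_levels, 2)
print(ref[0, :, 0])
# tensor([[0.2500, 0.2500],
#         [0.7500, 0.2500],
#         [0.2500, 0.7500],
#         [0.7500, 0.7500]])

With the reference points, flattened features, masks and positional encodings prepared, everything is passed into the encoder: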
memory = self.encoder(
    query=feat_flatten,  # input query: the flattened multi-scale feature maps, [sum of H*W, bs, 256]
    key=None,  # in self-attention k and v are derived from q, so both are passed as None
    value=None,
    query_pos=lvl_pos_embed_flatten,  # positional encoding of the query, [sum of H*W, bs, 256]
    query_key_padding_mask=mask_flatten,  # padding mask, [bs, sum of H*W]
    spatial_shapes=spatial_shapes,  # h and w of every feature level, [num_levels, 2]
    reference_points=reference_points,  # [bs, sum of H*W, num_levels, 2]
    level_start_index=level_start_index,  # start index of each flattened level, [num_levels]
    valid_ratios=valid_ratios,  # valid width/height ratio of each level's mask, [bs, num_levels, 2]
    **kwargs)
# memory: the encoder output, i.e. the multi-scale feature maps after self-attention, [sum of H*W, bs, 256]
Inside the encoder, the layers run in the order given by the config:
encoder=dict(
    type='DetrTransformerEncoder',
    num_layers=6,
    transformerlayers=dict(
        type='BaseTransformerLayer',
        attn_cfgs=dict(
            type='MultiScaleDeformableAttention', embed_dims=256),
        feedforward_channels=1024,
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
The self-attention here is replaced by MultiScaleDeformableAttention, whose forward (in mmcv/ops/multi_scale_deform_attn.py) looks like this:
if value is None:
    value = query
if identity is None:
    identity = query
if query_pos is not None:
    query = query + query_pos
if not self.batch_first:
    # change to (bs, num_query, embed_dims)
    query = query.permute(1, 0, 2)
    value = value.permute(1, 0, 2)

bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
# value is derived from query: it starts as None, is assigned query, and a linear
# layer then produces the actual value, [bs, sum of H*W, 256]
value = self.value_proj(value)
if key_padding_mask is not None:
    value = value.masked_fill(key_padding_mask[..., None], 0.0)
# [bs, sum of H*W, 256] --> [bs, sum of H*W, 8, 32]
value = value.view(bs, num_value, self.num_heads, -1)
'''
self.sampling_offsets:
Linear(in_features=256, out_features=256, bias=True)   # 8 heads * 4 levels * 4 points * 2 = 256
self.attention_weights:
Linear(in_features=256, out_features=128, bias=True)   # 8 heads * 4 levels * 4 points = 128
'''
# sampling_offsets: [bs, sum of H*W, 8, 4, 4, 2]
sampling_offsets = self.sampling_offsets(query).view(
    bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
# attention_weights: [bs, num_query, 8, 16] (e.g. [1, 10458, 8, 16] in my run)
attention_weights = self.attention_weights(query).view(
    bs, num_query, self.num_heads, self.num_levels * self.num_points)
# a linear projection followed by softmax gives the attention weights of each query;
# the softmax normalizes the num_levels * num_points = 16 weights of each head so they sum to 1
attention_weights = attention_weights.softmax(-1)
# [bs, sum of H*W, 8, 16] --> [bs, sum of H*W, 8, 4, 4]
attention_weights = attention_weights.view(bs, num_query,
                                           self.num_heads,
                                           self.num_levels,
                                           self.num_points)
if reference_points.shape[-1] == 2:
    # sampling_offsets is predicted in pixel units, so dividing by
    # offset_normalizer = (W, H) converts it to the normalized 0-1 scale
    # before it is added to reference_points
    offset_normalizer = torch.stack(
        [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    sampling_locations = reference_points[:, :, None, :, None, :] \
        + sampling_offsets \
        / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
    sampling_locations = reference_points[:, :, None, :, None, :2] \
        + sampling_offsets / self.num_points \
        * reference_points[:, :, None, :, None, 2:] \
        * 0.5
else:
    raise ValueError(
        f'Last dim of reference_points must be'
        f' 2 or 4, but get {reference_points.shape[-1]} instead.')
# call the CUDA operator to perform deformable attention
if torch.cuda.is_available() and value.is_cuda:
    output = MultiScaleDeformableAttnFunction.apply(
        value, spatial_shapes, level_start_index, sampling_locations,
        attention_weights, self.im2col_step)
else:
    output = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
if not self.batch_first:
    # (num_query, bs, embed_dims)
    output = output.permute(1, 0, 2)
# identity is the original query, used as a residual connection
return self.dropout(output) + identity
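The CUDA operator (and its pure-PyTorch fallback multi_scale_deformable_attn_pytorch) essentially does bilinear sampling with F.grid_sample followed by a weighted sum. Below is a simplified sketch of that fallback written from my understanding; shapes follow the variables above, and details may differ from the actual mmcv implementation:

import torch
import torch.nn.functional as F

def msda_pytorch_sketch(value, spatial_shapes, sampling_locations, attention_weights):
    # value: [bs, sum of H*W, num_heads, head_dim]
    # sampling_locations: [bs, num_query, num_heads, num_levels, num_points, 2] in [0, 1]
    # attention_weights:  [bs, num_query, num_heads, num_levels, num_points]
    bs, _, num_heads, head_dim = value.shape
    _, num_query, _, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([int(H) * int(W) for H, W in spatial_shapes], dim=1)
    grids = 2 * sampling_locations - 1          # grid_sample expects coordinates in [-1, 1]
    sampled = []
    for lvl, (H, W) in enumerate(spatial_shapes):
        # [bs, H*W, heads, d] -> [bs*heads, d, H, W]
        v = value_list[lvl].flatten(2).transpose(1, 2).reshape(
            bs * num_heads, head_dim, int(H), int(W))
        # [bs, nq, heads, points, 2] -> [bs*heads, nq, points, 2]
        g = grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
        # bilinear sampling of num_points locations per query: [bs*heads, d, nq, points]
        sampled.append(F.grid_sample(v, g, mode='bilinear',
                                     padding_mode='zeros', align_corners=False))
    # [bs, nq, heads, levels, points] -> [bs*heads, 1, nq, levels*points]
    w = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_query, num_levels * num_points)
    out = (torch.stack(sampled, dim=-2).flatten(-2) * w).sum(-1)   # weighted sum over all samples
    return out.view(bs, num_heads * head_dim, num_query).transpose(1, 2)  # [bs, nq, embed_dims]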
After the multi_scale_deformable_attn, the layer applies norm, ffn and norm, which completes one encoder layer; this is repeated 6 times. Execution then returns to the forward of DeformableDetrTransformer, where the return value memory is the encoder output, i.e. the multi-scale feature maps after multi-scale deformable attention, with shape [sum of H*W, bs, 256].
Before the decoder is called, DeformableDetrTransformer splits query_embed into query_pos and query and predicts the initial reference points from query_pos:
query_pos, query = torch.split(query_embed, c, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)  # [bs, 300, 256]
query = query.unsqueeze(0).expand(bs, -1, -1)  # [bs, 300, 256]
# a linear layer followed by sigmoid turns query_pos into the initial reference point coordinates
reference_points = self.reference_points(query_pos).sigmoid()
init_reference_out = reference_points

# decoder
query = query.permute(1, 0, 2)  # [300 (num_query), bs, 256]
memory = memory.permute(1, 0, 2)  # [sum of H*W, bs, 256]
query_pos = query_pos.permute(1, 0, 2)  # [300 (num_query), bs, 256]
inter_states, inter_references = self.decoder(
    query=query,  # [300 (num_query), bs, 256]
    key=None,
    value=memory,  # the feature maps output by the encoder
    query_pos=query_pos,
    key_padding_mask=mask_flatten,
    reference_points=reference_points,
    spatial_shapes=spatial_shapes,
    level_start_index=level_start_index,
    valid_ratios=valid_ratios,
    reg_branches=reg_branches,  # None unless box refinement is enabled
    **kwargs)
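For context, the two learnable pieces used above are created in the init of the head and the transformer; the following is a rough sketch consistent with the shapes noted above (check the actual _init_layers for the exact names and conditions):

import torch.nn as nn

embed_dims, num_query = 256, 300
# in DeformableDETRHead (single-stage branch): one embedding per query, with
# 2 * embed_dims so it can be split into (query_pos, query) as in the code above
query_embedding = nn.Embedding(num_query, embed_dims * 2)   # weight: [300, 512]
# in DeformableDetrTransformer: a linear layer mapping query_pos to the initial
# (x, y) reference point, squashed to the 0-1 range by the sigmoid
reference_points = nn.Linear(embed_dims, 2)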
Inside self.decoder, the code jumps to the forward of DeformableDetrTransformerDecoder, in mmdetection/mmdet/models/utils/transformer.py:
output = query
intermediate = []  # stores the query output by each decoder layer
intermediate_reference_points = []  # stores the reference_points of each decoder layer
for lid, layer in enumerate(self.layers):
    if reference_points.shape[-1] == 4:
        reference_points_input = reference_points[:, :, None] * \
            torch.cat([valid_ratios, valid_ratios], -1)[:, None]
    else:
        assert reference_points.shape[-1] == 2
        reference_points_input = reference_points[:, :, None] * \
            valid_ratios[:, None]
    output = layer(
        output,  # query
        *args,
        reference_points=reference_points_input,
        **kwargs)
    # kwargs contains ['key', 'value', 'query_pos', 'key_padding_mask', 'spatial_shapes', 'level_start_index']
    # key is None, value is the memory obtained from the encoder
    output = output.permute(1, 0, 2)
    # reg_branches is None by default (no iterative bounding box refinement)
    if reg_branches is not None:
        tmp = reg_branches[lid](output)
        if reference_points.shape[-1] == 4:
            new_reference_points = tmp + inverse_sigmoid(
                reference_points)
            new_reference_points = new_reference_points.sigmoid()
        else:
            assert reference_points.shape[-1] == 2
            new_reference_points = tmp
            new_reference_points[..., :2] = tmp[
                ..., :2] + inverse_sigmoid(reference_points)
            new_reference_points = new_reference_points.sigmoid()
        reference_points = new_reference_points.detach()

    output = output.permute(1, 0, 2)
    # store the intermediate query and reference_points; the query is updated in every layer,
    # while the reference_points stay the same across layers when refinement is disabled
    if self.return_intermediate:
        intermediate.append(output)
        intermediate_reference_points.append(reference_points)

if self.return_intermediate:  # True
    return torch.stack(intermediate), torch.stack(
        intermediate_reference_points)
return output, reference_points
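The inverse_sigmoid used above (and again in the head below) is just the logit function with clamping for numerical stability, so that offsets can be added in unbounded space before a sigmoid maps the result back to 0-1. A sketch consistent with mmdet's helper:

import torch

def inverse_sigmoid(x, eps=1e-5):
    """Inverse of sigmoid: log(x / (1 - x)), with clamping to avoid inf."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)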
The decoder returns two values: the queries and the reference_points of all six decoder layers. The query differs from layer to layer, while the reference_points are identical across layers unless box refinement is enabled.
The whole transformer then returns inter_states, init_reference_out and inter_references_out (plus enc_outputs_class and enc_outputs_coord, which are None when as_two_stage is False):
inter_states: [num_dec_layers, num_query, bs, embed_dims], the query output by each decoder layer
init_reference_out: [bs, num_query, 2], the initial reference points
inter_references_out: [num_dec_layers, bs, num_query, 2], the reference points of each layer
After the transformer part, execution returns to the deformable DETR head:
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
# make predictions from the output of every decoder layer
for lvl in range(hs.shape[0]):
    if lvl == 0:
        reference = init_reference
    else:
        reference = inter_references[lvl - 1]
    reference = inverse_sigmoid(reference)  # map back to logit space
    outputs_class = self.cls_branches[lvl](hs[lvl])
    # tmp predicted here is an offset relative to the reference point
    tmp = self.reg_branches[lvl](hs[lvl])
    if reference.shape[-1] == 4:
        tmp += reference
    else:
        assert reference.shape[-1] == 2
        tmp[..., :2] += reference  # add the reference point to the predicted offset
    outputs_coord = tmp.sigmoid()
    outputs_classes.append(outputs_class)
    outputs_coords.append(outputs_coord)

outputs_classes = torch.stack(outputs_classes)
outputs_coords = torch.stack(outputs_coords)
if self.as_two_stage:
    return outputs_classes, outputs_coords, \
        enc_outputs_class, \
        enc_outputs_coord.sigmoid()
else:
    return outputs_classes, outputs_coords, \
        None, None
What follows is the loss computation, which should be the same as in DETR. I already covered it in my DETR source-code walkthrough, so I will not repeat it here; if you are interested, see my other post: DETR source code reading.
To summarize:
In the encoder there is only self-attention, and Q, K and V all come from the feature maps.
In the decoder's self-attention, Q, K and V are all the object queries ([num_query, bs, 256]).
In the decoder's cross-attention, Q is the object query and V is the feature map; K is None, because deformable attention does not obtain the attention weights from a Q·K dot product but learns them directly from the object query.