Reading the Deformable DETR Code

Preface

These are my notes from reading the Deformable DETR source code in mmdet. If you spot any mistakes or have questions, corrections are welcome.

The flow of deformable attention

(Figure: the deformable attention module, from the Deformable DETR paper)
First, z_q is the object query. It goes through a linear layer to predict the sampling offsets, and these offsets (the three groups shown in the figure) are added to the reference point to give the sampling locations. The object query also goes through another linear layer plus a softmax to produce the attention weights. This is the key point: deformable attention never computes attention weights by dotting K with V (or Q with K); the weights are learned directly from the object query. The attention weights are then multiplied with the features sampled at those locations to obtain the aggregated value, and a final linear layer gives the output.
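
In the paper's notation, deformable attention for one query can be written as:

$$
\mathrm{DeformAttn}(z_q, p_q, x) = \sum_{m=1}^{M} W_m \Big[ \sum_{k=1}^{K} A_{mqk} \cdot W'_m\, x(p_q + \Delta p_{mqk}) \Big]
$$

where M is the number of attention heads, K the number of sampling points, and both the offsets Δp_mqk and the weights A_mqk (normalized so that Σ_k A_mqk = 1) are predicted from z_q by linear layers — exactly what the code later in this post does.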

提取feature map

One improvement of Deformable DETR over DETR is the use of multi-scale feature maps, which can also be seen in the config file:

 backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),   # take 3 levels of ResNet feature maps (C3, C4, C5)
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='ChannelMapper',
        in_channels=[512, 1024, 2048],
        kernel_size=1,
        out_channels=256,         # unify the output channels of the three levels to 256
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),

At the code level, as in DETR, execution first enters SingleStageDetector.forward_train to extract the feature maps:

super(SingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)

Here x holds the four feature maps returned by extract_feat, i.e. backbone plus neck: the ChannelMapper maps the three ResNet levels to 256 channels and, because num_outs=4, adds one extra, further-downsampled level.
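As a rough, hypothetical example of the shapes involved (assuming a padded input of 800×1344 and strides 8/16/32/64; the exact sizes depend on the input):

for i, feat in enumerate(x):
    print(i, feat.shape)
# 0 torch.Size([1, 256, 100, 168])
# 1 torch.Size([1, 256, 50, 84])
# 2 torch.Size([1, 256, 25, 42])
# 3 torch.Size([1, 256, 13, 21])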
Execution then enters DETRHead.forward_train() (DeformableDETRHead inherits from DETRHead), and from there the following line calls into the forward of the deformable DETR head:

outs = self(x, img_metas)

deformable detr head

The overall logic of the deformable DETR head is almost the same as detr_head; the difference is that it works on multi-scale feature maps.

Generating the mask matrices

        batch_size = mlvl_feats[0].size(0)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        img_masks = mlvl_feats[0].new_ones(
            (batch_size, input_img_h, input_img_w))
        # for each image in the batch, build a mask at the padded input resolution: positions of the real image are set to 0, positions left at 1 mark the padding
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            img_masks[img_id, :img_h, :img_w] = 0
		
        mlvl_masks = []
        mlvl_positional_encodings = []
        # downsample img_masks so that it matches the spatial size of each feature level
        for feat in mlvl_feats:
            mlvl_masks.append(
                # indexing with None adds a dimension: [bs, h, w] -> [1, bs, h, w]
                F.interpolate(img_masks[None],
                              size=feat.shape[-2:]).to(torch.bool).squeeze(0))
                # generate the positional encoding for this level; the mask just appended is the last element, hence the -1 index
            mlvl_positional_encodings.append(
                self.positional_encoding(mlvl_masks[-1]))

mlvl_feats is the list of these multi-level feature maps (batch_size is 1 in my run).
One thing worth noting is why img_masks[None] is used to add a dimension before calling F.interpolate: F.interpolate expects its input to be shaped batch × channel × [optional depth] × [optional height] × width, and the first two dimensions are never resampled.
Reference: F.interpolate — array resampling operations
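
A minimal standalone example of this (hypothetical sizes, default nearest interpolation):

import torch
import torch.nn.functional as F

img_masks = torch.zeros(2, 800, 800)   # [bs, H, W], hypothetical sizes
# add a leading dim so the input is [1, bs, H, W]: dim 0 is treated as batch,
# dim 1 as channels, and only the trailing H and W are resampled
feat_mask = F.interpolate(img_masks[None], size=(25, 25)).to(torch.bool).squeeze(0)
print(feat_mask.shape)                 # torch.Size([2, 25, 25])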

Entering the transformer

In the forward of the deformable DETR head, the following code enters the transformer:

query_embeds = None
if not self.as_two_stage:
    query_embeds = self.query_embedding.weight

hs, init_reference, inter_references, \
enc_outputs_class, enc_outputs_coord = self.transformer(
             mlvl_feats,
             mlvl_masks,
             query_embeds,   #[300,512]  [num_query,embed_dims * 2]
             mlvl_positional_encodings,
             reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
             cls_branches=self.cls_branches if self.as_two_stage else None  # noqa:E501
            )

The code then jumps to DeformableDetrTransformer.forward, which first does some preparation before running the encoder:

        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        # flatten the feature map, mask and positional embedding of every level
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            feat = feat.flatten(2).transpose(1, 2) # [bs,h*w,c]
            mask = mask.flatten(1)  # [bs,h*w]
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # [bs,h*w,c]
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1)  # [bs, sum of h*w over all levels, c]
        mask_flatten = torch.cat(mask_flatten, 1)  # [bs, sum of h*w over all levels]
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # [bs, sum of h*w over all levels, c]
        # convert to a tensor
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        # record the starting index of each level in the flattened sequence
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        # valid (non-padded) width/height ratio of every level  [bs, num_levels(4), 2(w and h)]
        valid_ratios = torch.stack(
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)
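
The valid_ratios come from get_valid_ratio, which simply counts how many rows and columns of each level's mask are not padding. For reference, it looks roughly like this (per mmdet; details may vary by version):

    def get_valid_ratio(self, mask):
        """Ratio of the valid (non-padded) region of the mask, per image."""
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)  # number of non-padded rows (checked on the first column)
        valid_W = torch.sum(~mask[:, 0, :], 1)  # number of non-padded columns (checked on the first row)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        return torch.stack([valid_ratio_w, valid_ratio_h], -1)  # [bs, 2], (w_ratio, h_ratio)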

Getting the reference points

The reference points are obtained with the function below; the resulting values are normalized to the 0-1 range.

    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in decoder.

        Args:
            spatial_shapes (Tensor): The shape of all
                feature maps, has shape (num_level, 2).
            valid_ratios (Tensor): The ratios of valid
                points on the feature map, has shape
                (bs, num_levels, 2)
            device (obj:`device`): The device where
                reference_points should be.

        Returns:
            Tensor: reference points used in decoder, has \
                shape (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            #  TODO  check this 0.5
            # x/y coordinates of every reference point; the 0.5 offset places each initial point at the center of its pixel
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            # normalize the coordinates by each level's valid height/width
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        # broadcast the points to every level and scale them by each level's valid ratios
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
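
A toy example of what the meshgrid produces, assuming a 2×3 level whose mask is entirely valid (valid_ratio = 1):

import torch

H, W = 2, 3
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, H - 0.5, H),   # [0.5, 1.5] -> pixel centers along y
    torch.linspace(0.5, W - 0.5, W),   # [0.5, 1.5, 2.5] -> pixel centers along x
    indexing='ij')                     # the mmdet snippet above relies on the old default 'ij'
print(ref_x / W)   # tensor([[0.1667, 0.5000, 0.8333], [0.1667, 0.5000, 0.8333]])
print(ref_y / H)   # tensor([[0.2500, 0.2500, 0.2500], [0.7500, 0.7500, 0.7500]])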

encoder

  memory = self.encoder(
    query=feat_flatten, # the query is the flattened multi-scale feature map [sum of H*W, bs, 256]
    key=None,     # in self-attention K and V are derived from Q, so None is passed
    value=None,
    query_pos=lvl_pos_embed_flatten, # positional encoding of the query [sum of H*W, bs, 256]
    query_key_padding_mask=mask_flatten, # padding mask [bs, sum of H*W]
    spatial_shapes=spatial_shapes, # h and w of every feature level [num_levels, 2]
    reference_points=reference_points, # [bs, sum of H*W, num_levels, 2]
    level_start_index=level_start_index, # index of the first element of each level in the flattened sequence [num_levels]
    valid_ratios=valid_ratios, # valid width/height ratio of each level's mask [bs, num_levels, 2]
    **kwargs)
# memory: the encoder output, i.e., the multi-scale feature map after self-attention [sum of H*W, bs, 256]

Inside the encoder, every layer runs its operations in the order given by operation_order in the config:

encoder=dict(
    type='DetrTransformerEncoder',
    num_layers=6,
    transformerlayers=dict(
        type='BaseTransformerLayer',
        attn_cfgs=dict(
            type='MultiScaleDeformableAttention', embed_dims=256),
        feedforward_channels=1024,
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'ffn', 'norm'))),

Here the self_attn op is a MultiScaleDeformableAttention module. Its forward, defined in mmcv/ops/multi_scale_deform_attn.py, is shown below:

        if value is None:
            value = query

        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query ,embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)

        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        # value starts as None and is set to query; a linear projection then produces the actual value  [bs, sum of H*W, 256]
        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        # [bs, sum of H*W, 256] -> [bs, sum of H*W, 8, 32]
        value = value.view(bs, num_value, self.num_heads, -1)
'''
self.sampling_offsets:
Linear(in_features=256, out_features=256, bias=True)   # 256 = num_heads(8) * num_levels(4) * num_points(4) * 2
self.attention_weights:
Linear(in_features=256, out_features=128, bias=True)   # 128 = num_heads(8) * num_levels(4) * num_points(4)
'''

        # sampling_offsets: [bs, num_query, 8 (heads), 4 (levels), 4 (points), 2]
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
        # attention_weights: [bs, num_query, 8, 16]  (e.g. [1, 10458, 8, 16] in my run)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        # softmax over the num_levels * num_points dimension, so that for each head the 16
        # sampling weights of a query sum to 1; this replaces the usual softmax(QK^T),
        # since the weights are predicted directly from the query by a linear layer
        attention_weights = attention_weights.softmax(-1)
        # [bs, num_query, 8, 16] -> [bs, num_query, 8, 4, 4]
        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            # normalize sampling_offsets by each level's (w, h), then add them to reference_points
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
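        # A worked example with hypothetical numbers: for a level with (h, w) = (25, 42),
        # offset_normalizer for that level is (42, 25) = (w, h), so a predicted offset of
        # (+2.1, -1.0) in that level's pixel units becomes (+0.05, -0.04) in the normalized
        # [0, 1] coordinates used by the reference points.
        # In the 4-d case (x, y, w, h reference boxes, used with box refine / two-stage),
        # the offsets are instead scaled by half the box size divided by num_points.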
                
        # call the CUDA kernel for deformable attention; fall back to the pure-PyTorch version otherwise
        if torch.cuda.is_available() and value.is_cuda:
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)

        output = self.output_proj(output)

        if not self.batch_first:
            # (num_query, bs ,embed_dims)
            output = output.permute(1, 0, 2)
        # identity is the original input query (residual connection)
        return self.dropout(output) + identity
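
When no CUDA device is available, the same computation is done by multi_scale_deformable_attn_pytorch. Below is a simplified, self-contained sketch of what that fallback does (not a line-for-line copy of the mmcv implementation; the shape conventions follow the comments above):

import torch
import torch.nn.functional as F

def deform_attn_pytorch_sketch(value, spatial_shapes, sampling_locations, attention_weights):
    """Simplified sketch of the pure-PyTorch deformable attention fallback.

    value:              [bs, num_value, num_heads, head_dim]
    spatial_shapes:     [num_levels, 2], (h, w) of every level
    sampling_locations: [bs, num_query, num_heads, num_levels, num_points, 2], in [0, 1]
    attention_weights:  [bs, num_query, num_heads, num_levels, num_points]
    """
    bs, _, num_heads, head_dim = value.shape
    _, num_query, _, num_levels, num_points, _ = sampling_locations.shape
    h_w = [(int(h), int(w)) for h, w in spatial_shapes]
    # split the flattened value back into one chunk per feature level
    value_list = value.split([h * w for h, w in h_w], dim=1)
    # grid_sample expects coordinates in [-1, 1]
    sampling_grids = 2 * sampling_locations - 1
    sampled_per_level = []
    for lvl, (h, w) in enumerate(h_w):
        # [bs, h*w, num_heads, head_dim] -> [bs*num_heads, head_dim, h, w]
        value_l = value_list[lvl].permute(0, 2, 3, 1).reshape(bs * num_heads, head_dim, h, w)
        # [bs, num_query, num_heads, num_points, 2] -> [bs*num_heads, num_query, num_points, 2]
        grid_l = sampling_grids[:, :, :, lvl].permute(0, 2, 1, 3, 4).reshape(
            bs * num_heads, num_query, num_points, 2)
        # bilinear sampling of num_points locations for every query and head
        sampled_per_level.append(
            F.grid_sample(value_l, grid_l, mode='bilinear',
                          padding_mode='zeros', align_corners=False))
    # weight the sampled features and sum over levels and points
    attn = attention_weights.permute(0, 2, 1, 3, 4).reshape(
        bs * num_heads, 1, num_query, num_levels * num_points)
    out = (torch.stack(sampled_per_level, dim=-2).flatten(-2) * attn).sum(-1)
    # [bs*num_heads, head_dim, num_query] -> [bs, num_query, num_heads*head_dim]
    return out.view(bs, num_heads * head_dim, num_query).transpose(1, 2)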

After multi_scale_deformable_attn, the layer runs norm, ffn, norm, which completes one encoder layer. This is repeated 6 times, after which execution returns to DeformableDetrTransformer.forward. The return value memory is the encoder output, i.e., the multi-scale feature map after multi-scale deformable attention, with shape [sum of H*W, bs, 256].

decoder

Back in DeformableDetrTransformer.forward, the (non-two-stage) decoder preparation splits the query embedding into a positional part and a content part, and predicts the initial reference points from query_pos before calling the decoder:
        query_pos, query = torch.split(query_embed, c, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) #[bs,300,256]
        query = query.unsqueeze(0).expand(bs, -1, -1)#[bs,300,256]
        # a linear layer + sigmoid on query_pos directly gives the initial reference point coordinates in (0, 1)
        reference_points = self.reference_points(query_pos).sigmoid()
        init_reference_out = reference_points
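        # For reference (layer definitions per mmdet, names may differ by version):
        #   self.query_embedding = nn.Embedding(num_query, embed_dims * 2)   # in the head, hence the [300, 512] weight
        #   self.reference_points = nn.Linear(embed_dims, 2)                 # in DeformableDetrTransformer
        # so each query gets a learned positional part and a content part, and the
        # positional part is mapped to an initial 2-d reference point in (0, 1).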

        # decoder
        query = query.permute(1, 0, 2) #[300(num_query),bs,256]
        memory = memory.permute(1, 0, 2)  # [sum of H*W, bs, 256]
        query_pos = query_pos.permute(1, 0, 2)#[300(num_query),bs,256]
        inter_states, inter_references = self.decoder(
            query=query, #[300(num_query),bs,256]
            key=None,
            value=memory,  # the feature map output by the encoder (memory)
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches,  # None unless with_box_refine is enabled
            **kwargs)

Inside self.decoder, execution jumps to DeformableDetrTransformerDecoder.forward, defined in mmdetection/mmdet/models/utils/transformer.py:

        output = query
        intermediate = []  # stores the query output of every decoder layer
        intermediate_reference_points = []  # stores the reference_points of every decoder layer
        for lid, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                reference_points_input = reference_points[:, :, None] * \
                    torch.cat([valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * \
                    valid_ratios[:, None]
            output = layer(
                output,   # query
                *args,
                reference_points=reference_points_input,
                **kwargs) 
# kwargs here contains ['key', 'value', 'query_pos', 'key_padding_mask', 'spatial_shapes', 'level_start_index']
# key is None, value is the memory produced by the encoder
            output = output.permute(1, 0, 2)
            # reg_branches is None by default; it is only passed when with_box_refine=True
            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(
                        reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[
                        ..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            # store this layer's query and reference_points; the query is updated every layer, while the reference_points stay the same here (no box refine)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:  # True for Deformable DETR
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points

The decoder finally returns two values: the queries and the reference_points of all six decoder layers. The query differs from layer to layer, while the reference_points are identical across layers here (because reg_branches is None when box refinement is disabled).
Finally, the whole transformer returns three values: inter_states, init_reference_out and inter_references_out.
inter_states: [num_dec_layers, bs, num_query, embed_dims], the query output of every decoder layer
init_reference_out: [bs, num_query, 2], the initial reference points
inter_references_out: [num_dec_layers, bs, num_query, 2], the reference points of every decoder layer

The prediction part

After the transformer, execution returns to the deformable DETR head:

		hs = hs.permute(0, 2, 1, 3)
        outputs_classes = []
        outputs_coords = []
		
        # run the classification and regression branches on every decoder layer's output
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)  # inverse sigmoid (logit)
            outputs_class = self.cls_branches[lvl](hs[lvl])
            # tmp predicted here is an offset relative to the reference point
            tmp = self.reg_branches[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference   # add the predicted offset to the reference point
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)

        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)
        if self.as_two_stage:
            return outputs_classes, outputs_coords, \
                enc_outputs_class, \
                enc_outputs_coord.sigmoid()
        else:
            return outputs_classes, outputs_coords, \
                None, None
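
The inverse_sigmoid used above is essentially the logit function; a simplified version (mmdet's own implementation clamps slightly differently) looks like this, and it explains why offsets can be added in an unbounded space while the final sigmoid keeps the box coordinates in (0, 1):

import torch

def inverse_sigmoid(x, eps=1e-5):
    # logit: maps values in (0, 1) back to (-inf, +inf)
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))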

What follows is the loss computation, which should be the same as in DETR. I already covered it in my DETR source-code reading, so I will not repeat it here; if you are interested, see my other post: DETR源码阅读.

Some details:

In the encoder there is only self_attn, and Q, K and V are all the (flattened multi-scale) feature map.
In the decoder, for self_attn, Q, K and V are all the object queries ([num_query, bs, 256]).
For cross_attn, Q is the object query and V is the feature map; K is None here, because deformable attention does not obtain its attention weights from Q·K — the weights are learned directly from the object query.
