【BEV】学习笔记之 DeformableDETR(原理+代码解析)

1、前言

Vision transforer(ViT)是Google团队提出的将transformer应用在图像分类的模型,成为了transformer在CV领域应用的里程碑著作。而DETR成功的transformer引入到目标检测任务中,也为后来的BEV的模型奠定了基础。Deformable DETR的是针对DETR训练慢、小目标检测差的问题而提出来的,同时进一步推进了BEV模型的发展,因此本文将对Deformable DETR进行解析,如有不对的地方,还望大佬们指出。

repo :https://github.com/fundamentalvision/Deformable-DETR

paper:https://arxiv.org/pdf/2010.04159

学习文章和视频 (强烈推荐):transformer、Vision transformer、DETR

欢迎进入BEV感知交流群,一起解决学习过程发现的问题,可以加v:Rex1586662742或者q群:468713665。

2、模型简介

在transformer中,特征点的特征向量Value可以由一个网络学习到,但是这个Value并不能表示全局的建模关系,于是就由另外两个网络为分别为每个特征点学习一个query和key,然后利用当前特征点的query与所有特征点的key做点乘,然后进行softmax,这样可以计算出每个特征点与其他特征点的权重关系,然后利用这个权重关系,将所有特征点的Value进行加权求和,得到每个特征点最终的Value。实时上,每个特征点并非需要与其他每个特征点做self-attention,比如图片上的左上角的特征点与右下角的特征点的关系是十分微弱的,甚至毫无关系。

因此在Deformable DETR 中,每个特征点只与周围的几个特征点(默认为4)进行self-attention,也就是每个特征点的Value是由其周围4个特征点的的Value加权求和得到的。相对于DETR,在Deformable DETR中,引入了多尺度的特征(能够同时兼顾大目标与小目标的识别),因此每个特征点都能够在每个特征层上找到一个自己的采样点,然后在每个采样点周围采样4个偏移点作为self-attention的对象,即利用 4 * 4 = 16 个偏移点特征向量Value来计算当前特征点的Value。这里有个问题,在transformer中,当前特征点的Value加权求和时是将自己的Value包括在内的,而在deformable detr中,是将自己value除外的,这样做的好处是?

现在已经基本明确在 Deformable DETR中, 特征点要与哪些偏移点怎么做self-attention了,那么后续可以分为两个部分:1、如何找到这些采偏移点,2、这些偏移点的权重系数是多少。文章是利用两个网络来实现的,一个网络通过特征点的Vale预测16个偏移点的位置,另一个网络利用特征点的Value预测16个偏移点的权重系数,如下图所示。
【BEV】学习笔记之 DeformableDETR(原理+代码解析)_第1张图片
初看论文时,一直没理解论文中的图所表达的意思,后来有了一定的了解值之后,就根据自己的理解绘制的上方的图片,再回过去看论文中的图片的时候,就变得异常清晰了。图中左边所示为4种尺度的特征层,以最上方特征层中的一个特征点(0.3,0.3)为例,它在每个特征层上都有一个采样点(相对坐标一致),正常来说每个采样点会与周围的四个点(绿色点)进行self-attention,但是这四个点最好的通过网络自己来学习,于是蓝色的点是网络学习到的偏移点,但是偏移点的坐标一般不会为整数,因此,蓝色特征点的Value就会有其附近的四个特征点(黄色)进行双线性差值得到,因此,一个特征点就采样到了16个偏移点,那么这个特征点的特征向量Value就由这16个偏移点的特征向量Value加权求和得到,这里也对应上了论文中的第二张图,下面的代码会有详细的解释。

3、代码解析

Deformable DETR的代码很容易就能运行起来,可以参考 Deformable DETR 实战(训练及预测)

1、predict.py

def detect(...):
    outputs = model(img)
    # -> models/deformable_detr.py

    # 将box恢复到原图尺寸
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

2、models/deformable_detr.py

class DeformableDETR(...):
    def __init__(...):
        ...
        
    def nested_tensor_from_tensor_list(...):
        ...
        
    def forward(...):
        """
        Args:
            samples:[1,3,1066,800] 测试图片
        """
        # features三个不同尺度的特征层,pos三个特征层的位置编码
        features, pos = self.backbone(samples)
        for l, feat in enumerate(features):
            # src 不同尺度的特征层
            # ...下面的mask是为了支持 不同batch内不同尺寸的图片,后续不再说明,一般情况下,mask不起作用。
            src, mask = feat.decompose()  # src:[1,256,134,100]、[1,256,67,50]、[1,256,34,25]
        
        if self.num_feature_levels > len(srcs):
            for l in range(_len_srcs, self.num_feature_levels):
                if ...:
                    ...
                else:
                    src = self.input_proj[l](srcs[-1])  # 对最后一个特征层进行下采样
         			srcs.append(src)  # 4个不同尺度的特征层
					pos.append(pos_l) # 4个特征层的位置编码
        
        if not self.two_stage:
            # [300,256] 为 DETR中的object_query, 另外[300,256]为object_query的位置编码
            query_embeds = self.query_embed.weight #[300,512]
        # hs:[6,1,300,256] 6次decode block中的结果
        # init_reference:[1,300,2] 300个object_query参考坐标
        # inter_references:[6,1,300,2] #经6次过block调整过后的 object_query的坐标
        hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(srcs, masks, pos, query_embeds)
        # -> models/deformable_transformer.py
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)  # 由于之前获得init_reference是经过sigmod的,现在要还原
            # 对每个block的结果预测  tgt(object_quert) 的类别 [1,300,91]
            outputs_class = self.class_embed[lvl](hs[lvl])
            
            # 对每个block的结果预测 tgt(object_quert) 的BBOX,偏移量,[x,y,w,h]   [1,300,4]
            tmp = self.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                ...
            else:
                tmp[..., :2] += reference # 加上参考点
                # 归一化坐标
                outputs_coord = tmp.sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            # 在训练是使用 aux_loss
            ...
        return out

3、models/deformable_transformer.py

class DeformableTransformer(...):
    def __init__(...):
        ...
        
    def forward(...):
        """
        Argc:
        	srcs:4个不同尺度的特征层
        	pos_embeds:固定的位置编码
        	query_embed:解码时使用的query [300,256]
        """
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            src = src.flatten(2).transpose(1, 2) # 该层所有特征点的特征 [1, n, 256] 
            pos_embed = pos_embed.flatten(2).transpose(1, 2) # 该层特征点固定的位置编码 [1,13400,256]
            # self.level_embed 针对每一层的顺序进行可学习的位置编码
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) # 因为一共有4个尺度的特征层,因此还要加上不同尺度的位置编码
            ...
        
        src_flatten = torch.cat(src_flatten, 1) # [1, 17821, 256] 四个特征层上所有点的特征 
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) #[1, 17821, 256] 所有特征点的位置编码
        level_start_index = ...  # [0, 13400, 16750, 17600] 记录每个特征层在17821里面的位置
        # memory: [1,17821,256]
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
        # -> 下方的DeformableTransformerEncoder
        
        if self.two_stage:
            ...
        else:
            query_embed, tgt = torch.split(...)  # query_embed [1,300,256]-> object-query的可学习的位置编码   tgt [1,300,256] ->  object-query
            reference_points = self.reference_points(query_embed).sigmoid()
            init_reference_out = reference_points  # [1,300,2]
        # hs:[6, 1, 300, 256] 6次decode的结果  inter_references:[6,1,300,4,2] 6次decode的object query 在特征层上的参考坐标
        hs, inter_references = self.decoder(tgt, reference_points, memory,spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten)
        #  -> 下方 DeformableTransformerDecoder

        return hs, init_reference_out, inter_references_out, None, None
    
class DeformableTransformerEncoder(...):
    def __init__:
        ...
    
    def forward(...)
        """
        Args:
            src: [1, 17821, 256]  # 所有特征点的特征向量
            spatial_shapes:       # 每个特征层的尺寸
            level_start_index:	  # 用于将src里面的特征点划分到对应的特征层上
        """
        # 
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for _, layer in enumerate(self.layers):
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
            #  -> 下方的 class DeformableTransformerEncoderLayer
            
    
    def get_reference_points(...):
        
        # 遍历每个特征层的尺寸
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            #按照0.5的距离,划分H_,W_
            ref_y, ref_x = ...
            ref_y = ref_y...  # 将ref_y 归一化到 (0,1)之间
            ref_x = ref_x...  # 将ref_x 归一化到 (0,1)之间
            ...
            
        reference_points = ... # [1, 17821, 4 , 2]  # 17821个特征点 在4个特征层上的归一化坐标,(采样点),每个坐标为(x,y) 
        return reference_points
        
class DeformableTransformerEncoderLayer(...):
    def __init__(...):
        ...
    
    # 位置编码
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos
    
    def forward(...):
        """
        Args:
            src:[1, 17821, 256]
            pos:[1, 17821, 256] 位置编码
            reference_points:[1, 17821, 4, 2]
            spatial_shapes:特征层的尺寸
            level_start_index:...
        """
        # Deformable transformer
        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
        #  -> models/ops/modules/ms_deform_attn.py
        
        
class DeformableTransformerDecoder(...):
    def __init__(...):
        ...
    
    def forward(...):
        """
        Args:
            reference_points_input:[1,300,4,2] #预测的300个目标在每个特征层上的坐标
        """
        output = tgt  # tgt -> 每个block的中间结果
        for lid, layer in enumerate(self.layers):
            # 重复六次  out:[1,300,256]
            output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)
            #  ->  DeformableTransformerDecoderLayer
            ...
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)
        
        if self.return_intermediate:
            # [6, 1, 300, 256] 6次decode的结果 [6,1,300,4,2] -> 6次decode的参考点坐标
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)
class DeformableTransformerDecoderLayer(...):
    def __init__(...):
        ...
    
    def forward(...):
        """
        Args:
            tgt:[1, 300, 256] 
            query_pos:[1, 300, 256] 
        """
        # self attention中的 K,Q   
        q = k = self.with_pos_embed(tgt, query_pos)

        # 正常的transformer 计算 q,k,v   tgt=Value
        tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
        
        # 与之前的memory做 deformable transformer
        tgt2 = self.cross_attn(...)

        # tgt:[1,300,256]
        tgt = self.forward_ffn(...)
        return tgt

4、models/ops/modules/ms_deform_attn.py

class MSDeformAttn(...):
    def __init__(...):
        ...
    
    def forward(...):
        """
        Args:
            query:[1, 17821, 256],位置编码后的多尺度图像特征
            reference_points:[1, 17821, 4, 2] # 每个特征点在四个特征层上的采样点
            input_flatten:[1, 17821, 256],没有位置编码的图像特征
            input_spatial_shapes:特征图尺寸
            input_level_start_index:...
        """
        # 在transformer中,为了得到含有全局建模关系的Value,为每个query计算出一个K、Q、V,利用query的K,Q来计算出的其他特征点相对于自己Value的权重,即当前query的Value是由所有特征点的Value加权求和得到。
        # 在Deformable DETR中,每个特征点先在每层特征点上找到自己坐标对应的采样点,然后在采样点周围采取四个偏移点的Value加权求和作为最终的特征,如果偏移点的坐标不为整数,那么这个偏移点的Value就由偏移点周围的4个特征点的Value双线性插值得到,可以看上图。
        # 学习每个特征点的Value,也就是self-attention中 K、Q、V中的V,不过这里不计算K和Q
        value = self.value_proj(input_flatten) # [1, 17821, 256] 
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) # [1,17821,8,32] 8头注意力机制

        # 计算偏移点的相对于ref_point的偏移量
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)  # [1,17821,8,4,4,2]
        
        # [1, 17821, 4, 4]  每个value对应的偏移点(4*4)的权重系数
        attention_weights = F.softmax(...)
        
        if reference_points.shape[-1] == 2:
            # 偏移量+参考点坐标 为最终偏移点的坐标并归一化到(0,1)
            sampling_locations = ... 
        else:
            ...
        # 加权求和 cuda加速 
        output = MSDeformAttnFunction.apply(...)
        # 加权求和 pytorch实现,仅供测试
        out1 = ms_deform_attn_core_pytorch(...)
        return output  
        
def ms_deform_attn_core_pytorch(...):
    # for debug and test only
    """
    Args:
        value:[1,17821,8,32],17821个特征点,每个特征点的value有8个head,每个head的长度为32
        value_spatial_shapes:每层特征图的shape,共4层
        sampling_locations:[1, 17821, 8, 4, 4, 2] # 每个head 在每个特征层上学习到的偏移点。
        attention_weights:[1, 17821, 8, 4, 4] # 17821个特征点,每个head 对应偏移点的value的权重系数。
    """
    value_list = ... # [13400,3350,850,221] # 将每层的所有特征点单独提取出来
    sampling_grids = 2 * sampling_locations - 1 #将采样点的坐标 从(0,1)映射到(-1,-1)之间  F.grid_sample 需要
    # for循环里面的注释是lid_ = 0 的时候
    for lid_, (H_, W_) in enumerate(value_spatial_shapes)
        value_l_ = ... # [8,32,134,100] 第lid_层的value
        sampling_grid_l_ = ... # [8,17821,4,2] 所有特征点在 第lid_层value上的偏移点
        sampling_value_l_ = F.grid_sample(..) #因为偏移点一般不为整数,利用F.grid_sample计算最终的偏移点(双线性插值),如第2节中的图所示
    attention_weights = ... # [8,1,17821,16]  每个特征点对应16个偏移点的权重系数
    # 加权求和
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()  #  [1,17821,256]

4、总结

为了深入理解BEV模型,学习BEV模型的发展是很重要的,向下追根溯源需要理解transformer、vision transformer、DETR。基于DETR的BEV模型后续还有DETE3D、PETR、PETRV2等等,后续希望能够沿着这个主线深入了解各个模型的优缺点。

你可能感兴趣的:(学习,深度学习,计算机视觉)