【代码实现】DETR原文解读及代码实现细节

1 模型总览

【代码实现】DETR原文解读及代码实现细节_第1张图片

宏观上来说,DETR主要包含三部分:以卷积神经网络为主的骨干网(CNN Backbone)、以TRM(Transformer)为主的特征抽取及交互器以及以FFN为主的分类和回归头,如DETR中build()函数所示。DETR最出彩的地方在于,它摒弃了非端到端的处理过程,如NMS、anchor generation等,以集合预测的方式来端到端建模目标检测过程,并且将Transformer引入到目标检测中,打开新领域的大门)。

def build(args):

    backbone = build_backbone(args)
    transformer = build_transformer(args)

    model = DETR(
        backbone,# 骨干网
        transformer,# 重点部分
        num_classes=81,
        num_queries=100,# object query数量,作用相当于spatial embedding
        aux_loss=args.aux_loss)
    matcher = build_matcher(args)# 二分图匹配
    weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef
    losses = ['labels', 'boxes', 'cardinality']
    postprocessors = {'bbox': PostProcess()}

    return model, criterion, postprocessors

2 DETR的基本流程

  • CNN backbone 提取图像的 feature
  • Transformer Encoder 通过 self-attention 建模全局关系对 feature 进行增强
  • Transformer Decoder 的输入是 object queries(spatial embedding) 和 Transformer encoder 的输出(content embedding),主要包含 self-attention 和 cross-attention 的过程。Self-attention 主要是对每个 query 之间做交互,让每个 query 能看到其他 query 在查询什么东西,从而不重复,类似与 NMS 的作用;Cross- attention 主要是将 object query 当做查询,encoder feature 当做 key,为了查询和 query 有关的区域。
  • 对 Decoder 输出的查询好了的 query,使用 FFN 提取出目标框的位置和类别信息

顺着上边的基本流程,从代码入手一点点理解原文的思想,下面开始!

3 backbone

首先是构建backbone模块的函数

def build_backbone(args):

    position_embedding = build_position_encoding(args)
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model

3.1 根据特征图生成位置编码

def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding

对于输入特征x,假定其尺寸为BxCxHxW,位置编码需要在H W两个维度上进行位置编码,所以一般会将hidden_dim(C, 通道)切分为两部分,一部分代表H另一部分代表W,最后在通道维度上进行拼接。DETR中位置编码主要是sine和learning position embedding。我自己仿照TRM写了一个可学习位置编码的实现,可以运行试试

# 验证learnable pos embedding机制
x = torch.randn((8, 3, 32, 32))
h,w=x.shape[-2:]
row_embed,col_embed=nn.Embedding(50,256),nn.Embedding(50,256)

i,j=torch.arange(w,device=x.device),torch.arange(h,device=x.device)
x_emb,y_emb=col_embed(i),row_embed(j)
x_cat=x_emb.unsqueeze(0).repeat(h,1,1)
y_cat=y_emb.unsqueeze(1).repeat(1,w,1)

pos=torch.cat([x_cat,y_cat],dim=-1)
pos_learn=pos.permute(2,0,1).unsqueeze(0).repeat(x.shape[0],1,1,1)# shape:(8,512,32,32)

构建backbone

这部分不太难,相关注释已经写在代码块中

class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:# 是否返回中间层,在多尺度融合操作时会用到
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:# 一般情况下只返回最后一层的输出
            return_layers = {'layer4': "0"}
        # IntermediateLayerGetter作用类似于Sequential,将多个神经层组合并可以指定返回中间输出
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)# 将数据传入网络,实例化网络得到输出,xs即为经过resnet四部分后的输出
        out: Dict[str, NestedTensor] = {}# 定义输出格式
        for name, x in xs.items():# 如果返回中间层,out可以按照name存储,返回最后一层则只有layer4
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]# 利用插值法产生不同尺度下的mask
            out[name] = NestedTensor(x, mask)
        return out

将上述二者对应组合

class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)# [0]代表backbone的输出
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))# [1]代表position_embedding

        return out, pos# 返回抽取后的特征及对应的位置编码

Transformer

【代码实现】DETR原文解读及代码实现细节_第2张图片
TRM部分其实跟Attention is all you need的模型结构完全相同,不同的部分只是decoder部分输入。在原始transformer中,decoder的输入是对应目标序列融合位置编码后的embedding,而在本文中,则使用初始化为全0的tgt作为目标序列,然后再融合query_embed。这里非常容易混淆的一点是:tgt全零序列才是content embedding, 代码中的query_embed是代表目标框集合位置的spatial embedding
由于encoder部分主要是特征提取,对边界定位影响不大,所以我们考虑decoder的cross-attention部分,其输入主要包括三部分:query key value

  • queries:每个 query 都是 decoder 第一层 self-attention 的输出( content query )+ object query( spatial query ),这里的 object query 就是 DETR 中提出的概念,每个 object query 都是候选框的信息,经过 FFN 后能输出位置和类别信息(本文 object query 个数 N 为 100)
  • keys:每个 key 都是 encoder 的输出特征( content key ) + 位置编码( spatial key )构成
  • values:只有来自 encoder 的输出
class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,return_intermediate=return_intermediate_dec)

        self.d_model = d_model
        self.nhead = nhead

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)# flatten(k)表示将[k:n-1]拉平为一个维度
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)# sine
        # num_queries x hidden_dim to num_queries x N x hidden_dim
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        # NxHxW to NxHW
        mask = mask.flatten(1)

        # decoder embedding,初始化为全0
        tgt = torch.zeros_like(query_embed)
        # encoder特征抽取,得到memory,shape同tgt
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # decoder特征交互,得到hs
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

DETR

class DETR(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):

        super().__init__()
        self.num_queries = num_queries# object query nums
        self.transformer = transformer
        hidden_dim = transformer.d_model# 隐层维度
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)# 分类头,最后的类别为:类别数+背景
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)# 检测头,使用三层全连接层进行映射,最后投影到xywh
        self.query_embed = nn.Embedding(num_queries, hidden_dim)# object query
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)# 将得到的feature map通道归一
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: NestedTensor):

        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        # 得到的feature可能是C3-C5几层,DETR只拿最后一层输入TRM
        src, mask = features[-1].decompose()
        assert mask is not None
        # self.transformer()[0]表示取dncoder的输出,序列1表示encoder输出
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

你可能感兴趣的:(文献阅读,DETR,深度学习,pytorch,神经网络,transformer)