【DETR源码解析】二、Backbone模块

目录

  • 前言
  • 一、Backbone整体结构
  • 一、CNN-Backbone
  • 二、Positional Encoding
  • Reference

前言

最近在看DETR的源码,断断续续看了一星期左右,把主要的模型代码理清了。一直在考虑以什么样的形式写一写DETR的源码解析。考虑的一种形式是像之前写的YOLOv5那样的按文件逐行写,一种是想把源码按功能模块串起来。考虑了很久还是决定按第二种方式,一是因为这种方式可能会更省时间,另外就是也方便我整体再理解一下吧。

我觉得看代码就是要看到能把整个模型分功能拆开,最后再把所有模块串起来,这样才能达到事半功倍。

另外一点我觉得很重要的是:拿到一个开源项目代码,要有马上配置环境能够正常运行Debug,并且通过解析train.py马上找到主要模型相关的内容,然后着重关注模型方面的解析,像一些日志、计算mAP、画图等等代码,完全可以不看,可以省很多时间,所以以后我讲解源码都会把无关的代码完全剥离,不再讲解,全部精力关注模型、改进、损失等内容。

这一节主要讲一下DETR的Backbone部分,包括CNN和位置编码两个模块的代码。主要涉及models/backbone.py和models/position_encoding.py两个文件。

Github注释版源码:HuKai97/detr-annotations

一、Backbone整体结构

整个Backbone主要包括CNN特征提取和位置编码两个部分。代码还是比较简单的,下面开始解析源码。

首先是调用models/Backbone.py中的build_backbone函数创建Backbone:

def build_backbone(args):
    # 搭建backbone
    # 位置编码  PositionEmbeddingSine()
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0   # 是否需要训练backbone  True
    return_interm_layers = args.masks       # 是否需要返回中间层结果 目标检测False  分割True
    # 生成backbone  resnet50
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    # 将backbone输出与位置编码相加   0: backbone   1: PositionEmbeddingSine()
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels   # 512
    return model

这里首先调用build_position_encoding函数生成正余弦位置编码position_embedding:[bs,256,H/32, W/32],其中256前128是y方向位置编码,后128是x方向位置编码;再调用Backbone类生成ResNet50对输入数据进行特征提取得到特征图[bs,2048,H/32, W/32]。最后Joiner将两者合并存储起来,方便后续使用。

一、CNN-Backbone

创建ResNet50,先调用Backbone类:

class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        # 直接掉包 调用torchvision.models中的backbone
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        # resnet50  2048
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

这个类是继承自BackboneBase类的,而且CNN直接调用的就是torchvision.models中的模型,所以直接看BackboneBase类:

class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            # layer0 layer1不需要训练 因为前面层提取的信息其实很有限 都是差不多的 不需要训练
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        # False 检测任务不需要返回中间层
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        # 检测任务直接返回layer4即可  执行torchvision.models._utils.IntermediateLayerGetter这个函数可以直接返回对应层的输出结果
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        """
        tensor_list: pad预处理之后的图像信息
        tensor_list.tensors: [bs, 3, 608, 810]预处理后的图片数据 对于小图片而言多余部分用0填充
        tensor_list.mask: [bs, 608, 810] 用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)
        """
        # 取出预处理后的图片数据 [bs, 3, 608, 810] 输入模型中  输出layer4的输出结果 dict '0'=[bs, 2048, 19, 26]
        xs = self.body(tensor_list.tensors)
        # 保存输出数据
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask  # 取出图片的mask [bs, 608, 810] 知道图片哪些区域是有效的 哪些位置是pad之后的无效的
            assert m is not None
            # 通过插值函数知道卷积后的特征的mask  知道卷积后的特征哪些是有效的  哪些是无效的
            # 因为之前图片输入网络是整个图片都卷积计算的 生成的新特征其中有很多区域都是无效的
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            # out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]
            out[name] = NestedTensor(x, mask)
        # out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]
        return out

这个类还是在调用torchvision.models中的模型,然后再把预处理后的图片数据[bs, 3, 608, 810]和mask数据[bs, 608, 810]输入到模型中(这个图片数据是经过pad填充的数据,而mask数据就是记录这些图片哪些像素位置是pad的,为True,没用pad的真实有效数据就为False)。经过前向传播,再调用IntermediateLayerGetter函数把对应层特征图提取出来,得到原图32倍下采样的特征图[bs, 2048, 19, 26],以及这张特征图对应的mask[bs, 19, 26]。

二、Positional Encoding

Positional Encoding 就是位置编码。这里主要是调用models/position_encoding.py中的build_position_encoding函数创建位置编码:

def build_position_encoding(args):
    """
    创建位置编码
    args: 一系列参数  args.hidden_dim: transformer中隐藏层的维度   args.position_embedding: 位置编码类型 正余弦sine or 可学习learned
    """
    # N_steps = 128 = 256 // 2  backbone输出[bs,256,25,34]  256维度的特征
    # 而传统的位置编码应该也是256维度的, 但是detr用的是一个x方向和y方向的位置编码concat的位置编码方式  这里和ViT有所不同
    # 二维位置编码   前128维代表x方向位置编码  后128维代表y方向位置编码
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        # [bs,256,19,26]  dim=1时  前128个是y方向位置编码  后128个是x方向位置编码
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding

可以看到,源码是实现了两种位置编码,一种是正余弦绝对位置编码,不需要额外的参数学习,另一种是可学习绝对位置编码。原论文用的是正余弦绝对位置编码,而且代码也是默认使用这个的,所以这里主要介绍PositionEmbeddingSine类:

class PositionEmbeddingSine(nn.Module):
    """
    Absolute pos embedding, Sine.  没用可学习参数  不可学习  定义好了就固定了
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats    # 128维度  x/y  = d_model/2
        self.temperature = temperature        # 常数 正余弦位置编码公式里面的10000
        self.normalize = normalize            # 是否对向量进行max规范化   True
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            # 这里之所以规范化到2*pi  因为位置编码函数的周期是[2pi, 20000pi]
            scale = 2 * math.pi  # 规范化参数 2*pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors   # [bs, 2048, 19, 26]  预处理后的 经过backbone 32倍下采样之后的数据  对于小图片而言多余部分用0填充
        mask = tensor_list.mask   # [bs, 19, 26]  用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)
        assert mask is not None
        not_mask = ~mask   # True的位置才是真实有效的位置

        # 考虑到图像本身是2维的 所以这里使用的是2维的正余弦位置编码
        # 这样各行/列都映射到不同的值 当然有效位置是正常值 无效位置会有重复值 但是后续计算注意力权重会忽略这部分的
        # 而且最后一个数字就是有效位置的总和,方便max规范化
        # 计算此时y方向上的坐标  [bs, 19, 26]
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        # 计算此时x方向的坐标    [bs, 19, 26]
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        # 最大值规范化 除以最大值 再乘以2*pi 最终把坐标规范化到0-2pi之间
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)   # 0 1 2 .. 127
        # 2i/2i+1: 2 * (dim_t // 2)  self.temperature=10000   self.num_pos_feats = d/2
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)   # 分母

        pos_x = x_embed[:, :, :, None] / dim_t   # 正余弦括号里面的公式
        pos_y = y_embed[:, :, :, None] / dim_t   # 正余弦括号里面的公式
        # x方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # y方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # concat: [bs,19,26,128][bs,19,26,128] -> [bs,19,26,256] -> [bs,256,19,26]
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

        # [bs,256,19,26]  dim=1时  前128个是y方向位置编码  后128个是x方向位置编码
        return pos

对照公式:
【DETR源码解析】二、Backbone模块_第1张图片
我的几个关键点的理解:

  1. 这里是通过mask来构建位置编码的,mask中记录了特征图中每个像素位置是否是pad的,只有在为False的位置,才是有效的位置,才需要构建位置编码;
  2. 关于最大值规范化 :因为正余弦编码方式,思想就是将各个位置的通过公式映射到 0~2Π 这个范围内(也可以是4Π,6Π,8Π…,因为它是一个周期函数,不过我们一般默认为2Π),所以这里在带入公式之前需要对x_embed、y_embed先进行规范化;
  3. 关于位置编码方式:这里之所以是把x和y分别进行位置编码(二维位置编码),而不是像transformer那样的一维位置编码。主要考虑的是transformer是应用在语言模型中的,天然就是一维的,所以一维可能更适合,而DETR是应用在图像任务中的一个目标检测框架,在图像任务中,当然二维位置编码效果可能会更好点;
  4. 这样,对于每个位置(x,y),其所在列对应的编码值就在通道维度的前128维,其所在行的编码值就在通道这个维度的后128维。这样这个特征图上各个位置就都对应到不同的维度的编码值了。

当然作为学习,也可以看看第二种绝对位置编码方式:可学习位置编码:

class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    可以发现整个类其实就是初始化了相应shape的位置编码参数,让后通过可学习的方式学习这些位置编码参数
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        # nn.Embedding  相当于 nn.Parameter  其实就是初始化函数
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]   # 特征图h w
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)   # 初始化x方向位置编码
        y_emb = self.row_embed(j)   # 初始化y方向位置编码
        # concat x y 方向位置编码
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos

可以发现整个类其实就是初始化了相应shape的位置编码参数,然后通过可学习的方式自己学习这些位置编码参数,代码比较简答。

Reference

官方源码: https://github.com/facebookresearch/detr

b站源码讲解: 铁打的流水线工人

知乎【布尔佛洛哥哥】: DETR 源码解读

CSDN【在努力的松鼠】源码讲解: DETR源码笔记(一)

CSDN【在努力的松鼠】源码讲解: DETR源码笔记(二)

CSDN: Transformer中的position encoding(位置编码一)

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(一)、概述与模型推断】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(三)、Backbone与位置编码】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(四)、Detection with Transformer】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(五)、loss函数与匈牙利匹配算法】

知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(六)、模型输出与预测生成】

你可能感兴趣的:(#,Transformer,Based,Cls&Det,detr,transformer)