最近在看DETR的源码,断断续续看了一星期左右,把主要的模型代码理清了。一直在考虑以什么样的形式写一写DETR的源码解析。考虑的一种形式是像之前写的YOLOv5那样的按文件逐行写,一种是想把源码按功能模块串起来。考虑了很久还是决定按第二种方式,一是因为这种方式可能会更省时间,另外就是也方便我整体再理解一下吧。
我觉得看代码就是要看到能把整个模型分功能拆开,最后再把所有模块串起来,这样才能达到事半功倍。
另外一点我觉得很重要的是:拿到一个开源项目代码,要有马上配置环境能够正常运行Debug,并且通过解析train.py马上找到主要模型相关的内容,然后着重关注模型方面的解析,像一些日志、计算mAP、画图等等代码,完全可以不看,可以省很多时间,所以以后我讲解源码都会把无关的代码完全剥离,不再讲解,全部精力关注模型、改进、损失等内容。
这一节主要讲一下DETR的Backbone部分,包括CNN和位置编码两个模块的代码。主要涉及models/backbone.py和models/position_encoding.py两个文件。
Github注释版源码:HuKai97/detr-annotations
整个Backbone主要包括CNN特征提取和位置编码两个部分。代码还是比较简单的,下面开始解析源码。
首先是调用models/Backbone.py中的build_backbone函数创建Backbone:
def build_backbone(args):
# 搭建backbone
# 位置编码 PositionEmbeddingSine()
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0 # 是否需要训练backbone True
return_interm_layers = args.masks # 是否需要返回中间层结果 目标检测False 分割True
# 生成backbone resnet50
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
# 将backbone输出与位置编码相加 0: backbone 1: PositionEmbeddingSine()
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels # 512
return model
这里首先调用build_position_encoding函数生成正余弦位置编码position_embedding:[bs,256,H/32, W/32],其中256前128是y方向位置编码,后128是x方向位置编码;再调用Backbone类生成ResNet50对输入数据进行特征提取得到特征图[bs,2048,H/32, W/32]。最后Joiner将两者合并存储起来,方便后续使用。
创建ResNet50,先调用Backbone类:
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
# 直接掉包 调用torchvision.models中的backbone
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
# resnet50 2048
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
这个类是继承自BackboneBase类的,而且CNN直接调用的就是torchvision.models中的模型,所以直接看BackboneBase类:
class BackboneBase(nn.Module):
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
super().__init__()
for name, parameter in backbone.named_parameters():
# layer0 layer1不需要训练 因为前面层提取的信息其实很有限 都是差不多的 不需要训练
if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
parameter.requires_grad_(False)
# False 检测任务不需要返回中间层
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': "0"}
# 检测任务直接返回layer4即可 执行torchvision.models._utils.IntermediateLayerGetter这个函数可以直接返回对应层的输出结果
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor_list: NestedTensor):
"""
tensor_list: pad预处理之后的图像信息
tensor_list.tensors: [bs, 3, 608, 810]预处理后的图片数据 对于小图片而言多余部分用0填充
tensor_list.mask: [bs, 608, 810] 用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)
"""
# 取出预处理后的图片数据 [bs, 3, 608, 810] 输入模型中 输出layer4的输出结果 dict '0'=[bs, 2048, 19, 26]
xs = self.body(tensor_list.tensors)
# 保存输出数据
out: Dict[str, NestedTensor] = {}
for name, x in xs.items():
m = tensor_list.mask # 取出图片的mask [bs, 608, 810] 知道图片哪些区域是有效的 哪些位置是pad之后的无效的
assert m is not None
# 通过插值函数知道卷积后的特征的mask 知道卷积后的特征哪些是有效的 哪些是无效的
# 因为之前图片输入网络是整个图片都卷积计算的 生成的新特征其中有很多区域都是无效的
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]
out[name] = NestedTensor(x, mask)
# out['0'] = NestedTensor: tensors[bs, 2048, 19, 26] + mask[bs, 19, 26]
return out
这个类还是在调用torchvision.models中的模型,然后再把预处理后的图片数据[bs, 3, 608, 810]和mask数据[bs, 608, 810]输入到模型中(这个图片数据是经过pad填充的数据,而mask数据就是记录这些图片哪些像素位置是pad的,为True,没用pad的真实有效数据就为False)。经过前向传播,再调用IntermediateLayerGetter函数把对应层特征图提取出来,得到原图32倍下采样的特征图[bs, 2048, 19, 26],以及这张特征图对应的mask[bs, 19, 26]。
Positional Encoding 就是位置编码。这里主要是调用models/position_encoding.py中的build_position_encoding函数创建位置编码:
def build_position_encoding(args):
"""
创建位置编码
args: 一系列参数 args.hidden_dim: transformer中隐藏层的维度 args.position_embedding: 位置编码类型 正余弦sine or 可学习learned
"""
# N_steps = 128 = 256 // 2 backbone输出[bs,256,25,34] 256维度的特征
# 而传统的位置编码应该也是256维度的, 但是detr用的是一个x方向和y方向的位置编码concat的位置编码方式 这里和ViT有所不同
# 二维位置编码 前128维代表x方向位置编码 后128维代表y方向位置编码
N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
# [bs,256,19,26] dim=1时 前128个是y方向位置编码 后128个是x方向位置编码
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif args.position_embedding in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding
可以看到,源码是实现了两种位置编码,一种是正余弦绝对位置编码,不需要额外的参数学习,另一种是可学习绝对位置编码。原论文用的是正余弦绝对位置编码,而且代码也是默认使用这个的,所以这里主要介绍PositionEmbeddingSine类:
class PositionEmbeddingSine(nn.Module):
"""
Absolute pos embedding, Sine. 没用可学习参数 不可学习 定义好了就固定了
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats # 128维度 x/y = d_model/2
self.temperature = temperature # 常数 正余弦位置编码公式里面的10000
self.normalize = normalize # 是否对向量进行max规范化 True
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
# 这里之所以规范化到2*pi 因为位置编码函数的周期是[2pi, 20000pi]
scale = 2 * math.pi # 规范化参数 2*pi
self.scale = scale
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors # [bs, 2048, 19, 26] 预处理后的 经过backbone 32倍下采样之后的数据 对于小图片而言多余部分用0填充
mask = tensor_list.mask # [bs, 19, 26] 用于记录矩阵中哪些地方是填充的(原图部分值为False,填充部分值为True)
assert mask is not None
not_mask = ~mask # True的位置才是真实有效的位置
# 考虑到图像本身是2维的 所以这里使用的是2维的正余弦位置编码
# 这样各行/列都映射到不同的值 当然有效位置是正常值 无效位置会有重复值 但是后续计算注意力权重会忽略这部分的
# 而且最后一个数字就是有效位置的总和,方便max规范化
# 计算此时y方向上的坐标 [bs, 19, 26]
y_embed = not_mask.cumsum(1, dtype=torch.float32)
# 计算此时x方向的坐标 [bs, 19, 26]
x_embed = not_mask.cumsum(2, dtype=torch.float32)
# 最大值规范化 除以最大值 再乘以2*pi 最终把坐标规范化到0-2pi之间
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) # 0 1 2 .. 127
# 2i/2i+1: 2 * (dim_t // 2) self.temperature=10000 self.num_pos_feats = d/2
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) # 分母
pos_x = x_embed[:, :, :, None] / dim_t # 正余弦括号里面的公式
pos_y = y_embed[:, :, :, None] / dim_t # 正余弦括号里面的公式
# x方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
# y方向位置编码: [bs,19,26,64][bs,19,26,64] -> [bs,19,26,64,2] -> [bs,19,26,128]
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
# concat: [bs,19,26,128][bs,19,26,128] -> [bs,19,26,256] -> [bs,256,19,26]
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# [bs,256,19,26] dim=1时 前128个是y方向位置编码 后128个是x方向位置编码
return pos
当然作为学习,也可以看看第二种绝对位置编码方式:可学习位置编码:
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
可以发现整个类其实就是初始化了相应shape的位置编码参数,让后通过可学习的方式学习这些位置编码参数
"""
def __init__(self, num_pos_feats=256):
super().__init__()
# nn.Embedding 相当于 nn.Parameter 其实就是初始化函数
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:] # 特征图h w
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i) # 初始化x方向位置编码
y_emb = self.row_embed(j) # 初始化y方向位置编码
# concat x y 方向位置编码
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos
可以发现整个类其实就是初始化了相应shape的位置编码参数,然后通过可学习的方式自己学习这些位置编码参数,代码比较简答。
官方源码: https://github.com/facebookresearch/detr
b站源码讲解: 铁打的流水线工人
知乎【布尔佛洛哥哥】: DETR 源码解读
CSDN【在努力的松鼠】源码讲解: DETR源码笔记(一)
CSDN【在努力的松鼠】源码讲解: DETR源码笔记(二)
CSDN: Transformer中的position encoding(位置编码一)
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(一)、概述与模型推断】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(三)、Backbone与位置编码】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(四)、Detection with Transformer】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(五)、loss函数与匈牙利匹配算法】
知乎CV不会灰飞烟灭-【源码解析目标检测的跨界之星DETR(六)、模型输出与预测生成】