在上一篇中介绍了BEVFormer的大体流程,地址为:https://zhuanlan.zhihu.com/p/593998659,由于本项目中涉及到许多变量重复且变量名重复使用,导致在代码阅读中会有一定的难度,本文注重在给关键的变量进行注释,下文中的内容仅仅我个人的理解,如果有错误的地方,烦请各位大佬说明并进行改正。
本人也是初学者,欢迎正在学习或者想学习BEV模型的朋友加入交流群一起讨论、学习论文或者代码实现中的问题 ,可以加 Rex1586662742,q群:468713665
本文依旧是按照forward的过程对变量进行说明。
1、tools/test.py
outputs = custom_multi_gpu_test(model, data_loader, args.tmpdir,args.gpu_collect)
# 进入到projects/mmdet3d_plugin/bevformer/apis/test.py
2、projects/mmdet3d_plugin/bevformer/apis/test.py
def custom_multi_gpu_test(...):
...
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
# 进入到 projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
...
3、projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
def forward(...):
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
# 进入到 self.forward_test 中
def forward_test(...):
...
# forward
new_prev_bev, bbox_results = self.simple_test(...)
...
def simple_test(...):
# self.extract_feat 主要包括两个步骤 img_backbone、img_neck,通过卷积提取特征
# 网络为resnet + FPN
# 如果是base模型,img_feats 为四个不同尺度的特征层
# 如果是small、tiny,img_feats 为一个尺度的特征层
img_feats = self.extract_feat(img=img, img_metas=img_metas)
# Temproral Self-Attention + Spatial Cross-Attention
new_prev_bev, bbox_pts = self.simple_test_pts(
img_feats, img_metas, prev_bev, rescale=rescale)
def simple_test_pts(...):
# 对特征层进行编解码
outs = self.pts_bbox_head(x, img_metas, prev_bev=prev_bev)
# 进入到 projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py
4、projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py
class BEVFormerHead(DETRHead):
def __init__(...):
if not self.as_two_stage:
# 可学习的位置编码
self.bev_embedding = nn.Embedding(self.bev_h * self.bev_w, self.embed_dims)
self.query_embedding = nn.Embedding(self.num_query,self.embed_dims * 2)
def forward(...):
'''
mlvl_feats: (tuple[Tensor]) FPN网络输出的多尺度特征
prev_bev: 上一时刻的 bev_features
all_cls_scores: 所有的类别得分信息
all_bbox_preds: 所有预测框信息
'''
# 特征编码 (900,512) (900,256) concate (900 + 256)
object_query_embeds = self.query_embedding.weight.to(dtype)
# [2500,256] bev特征图的大小,最终bev的大小为 50*50,每个点的channel维度为256。(base模型的特征图大小为200 * 200)
bev_queries = self.bev_embedding.weight.to(dtype)
# [1,50,50] 每个特征点对应一个mask点
bev_mask = torch.zeros((bs, self.bev_h, self.bev_w), device=bev_queries.device).to(dtype)
# [1, 256, 50, 50] 可学习的位置编码
bev_pos = self.positional_encoding(bev_mask).to(dtype)
if only_bev:
...
else:
# mlvl_feats ,多尺度特征
# bev_queries ,200*200,256
# object_query_embeds = 900 * 512 # 检测头使用的部分
outputs = ...
outputs = self.transformer(...)
# 进入到 projects/mmdet3d_plugin/bevformer/modules/transformer.py
for lvl in range(hs.shape[0]):
# 类别
outputs_class = self.cls_branches[lvl](hs[lvl])
# 回归框信息
tmp = self.reg_branches[lvl](hs[lvl])
5、projects/mmdet3d_plugin/bevformer/modules/transformer.py
class PerceptionTransformer(...):
def __init__(...):
...
def forward(...):
# 获得bev特征 temporal_self_attention + spatial_cross_attention
bev_embed = self.get_bev_features(...)
def get_bev_features(...):
# 车身底盘信号:速度、加速度等
# 当前帧的bev特征与历史特征进行 时间、空间上的对齐
delta_x = ...
# BEV特征中 每一格 在真实世界中对应的长度
grid_length_x = 0.512
grid_length_x = 0.512
# 上帧和当前帧的偏移量
shift_x = ...
shift_y = ...
if prev_bev is not None:
...
if self.rotate_prev_bev:
# 车身旋转角度
rotation_angle = ...
# can信号映射到 256维度
can_bus = self.can_bus_mlp(can_bus)[None, :, :]
# bev特征加上can_bus特征
bev_queries = bev_queries + can_bus * self.use_can_bus
# sca 有关
for lvl, feat in enumerate(mlvl_feats):
# 特征编码
if self.use_cams_embeds:
feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(feat.dtype)
# 每一个维度的起始点
level_start_index = ...
# 获得bev特征 block * 6
bev_embed = self.encoder(...)
# 进入到projects/mmdet3d_plugin/bevformer/modules/encoder.py
...
# decoder
inter_states, inter_references = self.decoder(...)
# 进入到 projects/mmdet3d_plugin/bevformer/modules/decoder.py 中
return bev_embed, inter_states, init_reference_out, inter_references_out
# 返回到projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py
6、projects/mmdet3d_plugin/bevformer/modules/encoder.py
class BEVFormerEncoder(...):
def __init__(self):
...
def get_reference_points(...):
'''
获得参考点用于 SCA以及TSA
H:bev_h
W:bev_w
Z:pillar的高度
num_points_in_pillar:4,在每个pillar里面采样四个点
'''
# SCA
if dim == '3d':
# (4, 50, 50) 为每一个bev_query特征点在0~Z上均匀采样4个点,并归一化
zs = ...
# 均匀采样的x坐标
xs = ...
# 均匀采样的y坐标
ys = ...
# (1, 4, 2500, 3)
ref_3d =
# TSA
elif dim == '2d':
# bev特征点坐标
ref_2d = ...
def point_sampling(...)
'''
pc_range: bev特征表征的真实的物理空间大小
img_metas: 数据集 list [(4*4)] * 6
'''
# 4×4 为 雷达坐标系转图像坐标系的齐次矩阵
# 采用lidar 的坐标系
lidar2img = ...
# 参考坐标转化的尺度转化为真实尺度
# [x, y, z, 1]
reference_points = ...
# (4,4) * [x,y,z,1] -> (zc * u , zc * v, zc, 1) 像素空间
reference_points_cam = torch.matmul(lidar2img.to(torch.float32), reference_points.to(torch.float32)).squeeze(-1)
# 通过阈值判断,对bev_query的每个坐标进行 #判断,高于阈值的为True,否则为False,用于减少计算量
# zc 大于 eps 的 为true
bev_mask = (reference_points_cam[..., 2:3] > eps)
# 0~1之间
reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
# 确保所有点在正确范围内
bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
& (reference_points_cam[..., 1:2] < 1.0)
& (reference_points_cam[..., 0:1] < 1.0)
& (reference_points_cam[..., 0:1] > 0.0))
...
# 先进入到这个forward
@auto_fp16()
def forward(...):
'''
bev_query: (2500, 1, 256)
key: (6, 375, 1, 256) 6个相机图片的特征
value: 与key一致
bev_pos:(2500, 1, 256) 为每个bev特征点进行可学习的编码
spatial_shapes: 相机特征层的尺度,tiny模型只有一个,base模型有4个
level_start_index: 特征尺度的索引
prev_bev:(2500, 1, 256) 前一时刻的bev_query
shift: 当前bev特征相对于上一时刻bev特征的偏移量
'''
# z轴的采样点坐标 (1, 4, 2500, 3)
ref_3d = self.get_reference_points(...)
# bev_query 特征点的归一化坐标 (1, 2500, 1, 2)
ref_2d = self.get_reference_points(...)
# (6,1,40000,4,2) 像素坐标
reference_points_cam, bev_mask = self.point_sampling(...)
# 当前bev特征坐标等于上一时刻bev特征+偏移量
# 通过偏移量,可以将当前帧的bev特征点与上一帧的bev特征点联系起来
shift_ref_2d += shift[:, None, None, :]
if prev_bev is not None:
# 叠加当前时刻bev_query 和上一时刻的bev_query
prev_bev = torch.stack([prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1)
# 6 × encoder
for lid, layer in enumerate(self.layers):
# 进入到下面的 BEVFormerLayer 的forward中
output = layer(...)
class BEVFormerLayer(MyCustomBaseTransformerLayer)
def __init__(...):
'''
attn_cfgs:来自总体网络配置文件的参数
ffn_cfgs:单层神经网络的参数
operation_order: 'self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm',encode中每个block中包含的步骤
'''
# 注意力模块的个数 2
self.num_attn
# 编码维度 256
self.embed_dims
# ffn层
self.ffns
# norn层
self.norms
...
def forward(...):
'''
query:当前时刻的bev_query,(1, 2500, 256)
key: 当前时刻6个相机的特征,(6, 375, 1, 256)
value:当前时刻6个相机的特征,(6, 375, 1, 256)
bev_pos:每个bev_query特征点 可学习的位置编码
ref_2d:前一时刻和当前时刻bev_query对应的参考点 (2, 2500, 1, 2)
red_3d: 当前时刻在Z轴上的采样的参考点 (1, 4, 2500, 3) 每个特征点在z轴沙漠化采样4个点
bev_h: 50
bev_w: 50
reference_points_cam: (6, 1, 2500, 4, 2)
spatial_shapes:FPN特征层大小 [15,25]
level_start_index: [0] spatial_shapes对应的索引
prev_bev: 上上个时刻以及上个时刻 bev_query(2, 2500, 256)
'''
# 遍历六个 encoder的 block块
for layer in self.operation_order:
# 首先进入tmporal_self_attention
if layer == 'self_attn':
# self.attentions 为 temporal_self_attention模块
query = self.attentions[attn_index]
# 进入到projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
# Spatial Cross-Attention
# 然后进入 Spatial Cross-Attention
elif layer == 'cross_attn':
query = self.attentions[attn_index]
# 进入到 projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
7、projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
class TemporalSelfAttention(...):
def __init__(...):
'''
embed_dims: bev特征维度 256
num_heads: 8 头注意力
num_levels:1 多尺度特征的层数
num_points:4,每个特征点采样四个点进行计算
num_bev_queue:bev特征长度,及上一时刻以及当前时刻
'''
self.sampling_offsets = nn.Linear(...) # 学习偏置的网络
self.attention_weights =nn.Linear(...) # 学习注意力特征的网络
self.value_proj = nn.Linear(...) # 学习vaule特征的网络
self.output_proj = nn.Linear(...) # 输入结果的网络
def forward(...):
'''
query: (1, 2500, 256) 当前时刻的bev特征图
key: (2, 2500, 256) 上一个时刻的以及上上时刻的bev特征
value: (2, 2500, 256) 上一个时刻的以及上上时刻的bev特征
query_pos: 可学习的位置编码
reference_points:每个bev特征点对应的坐标
'''
# 初始帧
if value is None:
assert self.batch_first
bs, len_bev, c = query.shape
value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c)
# 位置编码
if query_pos is not None:
query = query + query_pos
# 将前一时刻的bev和当前时刻的bev特征进行叠加
query = torch.cat([value[:bs], query], -1)
# 学习前一时刻和当前时刻的bev特征 (1, 2500, 128)
value = self.value_proj(value)
# 8 个头的注意力
value = value.reshape(bs*self.num_bev_queue,
num_value, self.num_heads, -1)
# (1, 2500, 128)
# 从当前时刻的bev_query 学习到 参考点的偏置
sampling_offsets = self.sampling_offsets(query)
# (1, 2500, 8, 2, 1, 4, 2)
sampling_offsets = sampling_offsets.view(
bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels, self.num_points, 2)
# (1, 2500, 8, 2, 4) 用于学习每个特征点之间的权重
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels * self.num_points)
# offset_normalizer = (50,50)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
# reference_points (2, 2500, 1, 2) pre_bev 和当前bev 每个特征点的 归一化坐标 0~1之间
# sampling_locations bev上每个特征点与哪些采样点进行注意力计算
sampling_locations = reference_points [][:, :, None, :, None, :] + sampling_offsets + offset_normalizer[None, None, None, :, None, :]
if ...:
...
else:
# 计算deformable attention output (2, 2500, 256)
output = multi_scale_deformable_attn_pytorch(...)
# (2500, 256, 1, 2) 当前时刻与上个时刻的注意力特征
output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
# 将两个时刻的注意力特征取平均值
output = output.mean(-1)
# 线性层
output = self.output_proj(output)
# 残差链接
return self.dropout(output) + identity
返回到 projects/mmdet3d_plugin/bevformer/modules/encoder.py 中
def multi_scale_deformable_attn_pytorch(...):
# 映射到 -1 到 1之间
sampling_grids = 2 * sampling_locations - 1
for level, (H_, W_) in enumerate(value_spatial_shapes):
# 不规则采样
sampling_value_l_ = F.grid_sample
# 相乘注意力操作
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
attention_weights).sum(-1).view(bs, num_heads * embed_dims,
num_queries)
8、projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
SpatialCrossAttention(...):
def __init__(...):
'''
embed_dims:编码维度
pc_range:真实世界的尺度
deformable_attention: 配置参数
num_cams:相机数量
'''
self.output_proj = nn.Linear(...) # out网络
def forward(...):
'''
query:tmporal_self_attention的输出加上 self.norms
reference_points:(1, 4, 2500, 3) 由 tmporal_self_attention的输出加上 模块计算的z轴上采样点的坐标,每个bev特征的有三个坐标点(x,y,z)
bev_mask:(6, 1, 2500, 4) 某些特征点的值为false,可以将其过滤掉,2500为bev特征点个数,1为特征尺度,4,为在每个不同尺度的特征层上采样点的个数。
'''
# (6, 375, 1, 256) query 轮巡到 key 上查找特征
# bev_mask.shape (6, 1, 2500, 4)
for i, mask_per_img in enumerate(bev_mask):
# 从每个特征层上找到有效位置的 index
index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
indexes.append(index_query_per_img)
# bev特征层对应每个 相机特征的 最大的特征数的长度
max_len = max([len(each) for each in indexes])
# 将所有 相机的特征点的个数 重建为 最大特征长度
queries_rebatch = query.new_zeros([bs, self.num_cams, max_len, self.embed_dims])
# 将query放到 reference_points_rebatch中
reference_points_rebatch = ...
for j in range(bs):
for i, reference_points_per_img in enumerate(reference_points_cam):
# 将query和 reference_points_cam 中有效的元素提取出来
...
# deformable_attention
queries = self.deformable_attention(...)
# self.deformable_attention
class MSDeformableAttention3D(BaseModule):
def __init__(...):
'''
embed_dims:编码维度
num_heads:注意力头数
num_levels: 4
每个z轴上的点要到每一个相机特征图上寻找两个点,所以会有8个点
'''
# 学习特征点偏移的网络
self.sampling_offsets = nn.Linear(...)
# 提取特征网络
self.attention_weights(...)
# 输出特征网络
self.value_proj = nn.Linear(...)
def forward(...):
'''
query: (1,604,256), queries_rebatch 特征筛选过后的query
query_pos:挑选的特征点的归一化坐标
'''
# mlp
value = self.value_proj(value)
value = value.view(bs, num_value, self.num_heads, -1)
# 从bev_query 学习到的偏置
sampling_offsets = ...
# 注意力权重
attention_weights
...
if ...:
else:
output = ...
...
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
...
返回到encoder.py中
9、projects/mmdet3d_plugin/bevformer/modules/decoder.py
class DetectionTransformerDecoder(...):
def __init__(...):
...
def forward(...):
'''
query: [900,1,256] bev 特征
reference_points: [1, 900, 3] 每个query 对应的 x,y,z坐标
'''
# 重复6次decoder
for lid, layer in enumerate(self.layers):
# 取x,y
reference_points_input = reference_points[..., :2].unsqueeze(2)
output = layer(...)
# 进入到 CustomMSDeformableAttention
# 在获得查询到的特征后,会利用回归分支(FFN 网络)对提取的特征计算回归结果,预测 10 个输出
# (xc,yc,w,l,zc,h,rot.sin(),rot.cos(),vx,vy);[预测框中心位置的x方向偏移,预测框中心位置的y方向偏移,预测框的宽,预测框的长,预测框中心位置的z方向偏移,预测框的高,旋转角的正弦值,旋转角的余弦值,x方向速度,y方向速度]
# 然后根据预测的偏移量,对参考点的位置进行更新,为级联的下一个 Decoder 提高精修过的参考点位置
new_reference_points = torch.zeros_like(reference_points)
# 预测出来的偏移量是绝对量
# 框中心处的 x, y 坐标
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points[..., :2])
# 框中心处的 z 坐标
new_reference_points[..., 2:3] = tmp[..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
# 计算归一化坐标
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
return output, reference_points
# 返回到 projects/mmdet3d_plugin/bevformer/modules/transformer.py
class CustomMSDeformableAttention(...):
def forward(...):
'''
query: (900, 1, 256)
query_pos:(900, 1, 256) 可学习的位置编码
'''
output = multi_scale_deformable_attn_pytorch(...)
output = self.output_proj(output)
return self.dropout(output) + identity
损失函数的计算在https://zhuanlan.zhihu.com/p/543335939中讲的比较详细了,因此本文不再进行叙述,通过对BEVFoer论文以及代码的阅读,基本上弄清楚了工作流程,主要是弄清楚了TSA、SCA是如何实现的,这是笔者详细了解的第一个BEV模型,细节上可能还会有些问题,但BEV模型还在不断更新,不得不去卷其他模型了。