本文主要是学习整理,结合DETR3D的模型结构与MMDetection3D的模型构建方法,首先介绍model dict的模型参数设置,然后介绍逐个介绍DETR3D中的子结构,过程中简单讲解mmdetection3d的模型构建流程。
model部分:定义按照backbone,neck,head的顺序设置模型参数。
# 此处省略关键参数,实际以具体的配置文件为准
model = dict(
type='Detr3D',
use_grid_mask=True,
# resnet提取0,1,2,3层的特征
img_backbone=dict(),
img_neck=dict(),
# transformer head定义,本层的dict所指代的类负责对包含在内的 下一层dict实体 进行实例化
pts_bbox_head=dict(
type='Detr3DHead',
# head中只有decoder
transformer=dict(),
# loss,bbox,position_embedding
bbox_coder=dict(),
positional_encoding=dict(),
loss_cls=dict(),
train_cfg=dict()))
)
MMDetection3D利用类之间的包含关系(head中包含transformer, transformer中包含decoder等)递归实例化每个组件, 在build_model后,通过registry这种注册机制,递归地实例化每个registry model。
具体如何初始化呢? 编者在第一次看源码时也遇到了问题,框架的抽象程度很高,但是逐步推进到底层源码,了解registry的注册、调用、初始化方式,可以清楚了解整个流程,这里以transformer与decoder为例:
@TRANSFORMER.register_module()
class Detr3DTransformer(BaseModule):
def __init__(self,
num_feature_levels=4,
num_cams=6,
two_stage_num_proposals=300,
decoder=None,
**kwargs):
super(Detr3DTransformer, self).__init__(**kwargs)
# 初始化decoder
self.decoder = build_transformer_layer_sequence(decoder)
def build_from_cfg(cfg, registry, default_args=None):
# obj_type:transformer
obj_type = args.pop('type')
if isinstance(obj_type, str):
# get registry for dataset
# 查询并获得registry注册好的decoder类
obj_cls = registry.get(obj_type)
return obj_cls
总结来说:
img_backbone=dict(
type='ResNet',
# resnet101
depth=101,
# bottom-up结构特征图的C0,1,2,3
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN2d', requires_grad=False),
norm_eval=True,
style='caffe',
dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, False, True, True)),
img_neck=dict(
type='FPN',
# FPN的输入channel
in_channels=[256, 512, 1024, 2048],
# 最终的四个特征图都是256维
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=4,
relu_before_extra_convs=True)
head继承自mmdet3d提供的DetrHead
pts_bbox_head=dict(
type='Detr3DHead',
num_query=900,
num_classes=10,
in_channels=256,
sync_cls_avg_factor=True,
with_box_refine=True,
as_two_stage=False,
# head中只有decoder
transformer=dict(),
# loss,bbox,position_embedding
bbox_coder=dict(
type='NMSFreeCoder',
post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
pc_range=point_cloud_range,
max_num=300,
voxel_size=voxel_size,
num_classes=10),
positional_encoding=dict(
type='SinePositionalEncoding',
num_feats=128,
normalize=True,
offset=-0.5),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=2.0),
loss_bbox=dict(type='L1Loss', loss_weight=0.25),
loss_iou=dict(type='GIoULoss', loss_weight=0.0))
最底层的部分,完成了论文中的主要创新点部分:
transformer=dict(
type='Detr3DTransformer',
decoder=dict(
type='Detr3DTransformerDecoder',
num_layers=6,
return_intermediate=True,
# 设置单个decoder layer参数
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
dict(
type='Detr3DCrossAtten',
pc_range=point_cloud_range,
num_points=1,
embed_dims=256)
],
feedforward_channels=512,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm','ffn', 'norm'))))
负责DETR3D的关键部分:reference points,特征抓取,queries refinement,objects cross attention
@TRANSFORMER.register_module()
class Detr3DTransformer(BaseModule):
def forward(self,
mlvl_feats,
query_embed,
reg_branches=None,
**kwargs):
"""
mlvl_feats (list(Tensor)): [bs, embed_dims, h, w].
query_embed (Tensor): [num_query, c].
mlvl_pos_embeds (list(Tensor)): [bs, embed_dims, h, w].
reg_branches (obj:`nn.ModuleList`): Regression heads
with_box_refine
"""
bs = mlvl_feats[0].size(0)
# 256 -> 128, 128
query_pos, query = torch.split(query_embed, self.embed_dims , dim=1)
# -1为保持原样
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
# query_pos作为输入通过reg_branches回归参考点对应的2d position
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid()
init_reference_out = reference_points
# decoder
query = query.permute(1, 0, 2)
query_pos = query_pos.permute(1, 0, 2)
# decoder
inter_states, inter_references = self.decoder(
query=query,
key=None,
value=mlvl_feats,
query_pos=query_pos,
reference_points=reference_points,
reg_branches=reg_branches,
**kwargs)
inter_references_out = inter_references
return inter_states, init_reference_out, inter_references_out
decoder部分关键在于如何完成论文中提出的object queries refinement,这里着重进行介绍:
Decoder Block
每一个decoder block流程:预测上一层queries对应的reference points后对queries进行refinement后,进行self-attention,作为下一个block输入:self.dropout(output) + inp_residual + pos_feat,即输出=原始输入+双线性插值特征+query位置特征
如何对提取后的多尺度特征进行处理呢?
这里的提取的图像特征,从shape=(bs, c, num_query, num_cam, 1, len(num_feature_level))到shape=(bs, c, num_query),通过三个连续的sum(-1),将不同视角的相机特征,不同尺度的相机特征,进行求和,得到最终的图像特征,然后通过project将图像特征投影到与query同维度,最后直接求和作为下一个Decoder Block的输入。
output = output.sum(-1).sum(-1).sum(-1)
@ATTENTION.register_module()
class Detr3DCrossAtten(BaseModule):
def forward(self,
query,
key,
value,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
query = query.permute(1, 0, 2)
bs, num_query, _ = query.size()
attention_weights = self.attention_weights(query).view(
bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
# 双线性插值
reference_points_3d, output, mask = feature_sampling(
value, reference_points, self.pc_range, kwargs['img_metas'])
output = torch.nan_to_num(output)
mask = torch.nan_to_num(mask)
attention_weights = attention_weights.sigmoid() * mask
output = output * attention_weights
output = output.sum(-1).sum(-1).sum(-1) # sum后缩减三个维度:shape:[bs, c, num_query]
output = output.permute(2, 0, 1) # [num_query, bs, c]
output = self.output_proj(output) # (num_query, bs, embed_dims),将reference3d的dim转换到256
# output作为fetch的feature,与经过encoder后的query、原始query直接相加作为refinement query
pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)
return self.dropout(output) + inp_residual + pos_feat
# 特征采样部分
# 特征采样部分, Input queries from different level. Each element has shape [bs, embed_dims, h, w] 也就是[4, bs, embed_dims, h, w]
def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
lidar2img = []
# lidar2img:3D坐标以lidar为中心,求出3D点到img的转换关系也就是求出lidar到img的转换关系
for img_meta in img_metas:
lidar2img.append(img_meta['lidar2img'])
lidar2img = np.asarray(lidar2img)
# N = 6,referrence_points:[bs, num_query, 3]
lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4)
reference_points = reference_points.clone()
reference_points_3d = reference_points.clone()
# recompute top-left(x,y) and bottom-right(x)
reference_points[..., 0:1] = reference_points[..., 0:1]*(pc_range[3] - pc_range[0]) + pc_range[0]
reference_points[..., 1:2] = reference_points[..., 1:2]*(pc_range[4] - pc_range[1]) + pc_range[1]
reference_points[..., 2:3] = reference_points[..., 2:3]*(pc_range[5] - pc_range[2]) + pc_range[2]
# reference_points [bs, num_query, 3]
reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
B, num_query = reference_points.size()[:2]
# num_cam = 6
num_cam = lidar2img.size(1)
# from [b,1,num_query,4] to [b,num_cam,num_query, 4, 1]
reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
# shape:[b, num_cam, num_query, 4, 4]
lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)
# project 3d -> 2d
# shape:[b, num_cam, num_query, 4]
reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
eps = 1e-5
mask = (reference_points_cam[..., 2:3] > eps)
# cam坐标归一化: reference_points_cam.shape:[b,num_cam,num_query,2]
reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3])*eps)
# 0,1分别代表camera像素坐标系下的x,y坐标,并进行归一化
reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
reference_points_cam = (reference_points_cam - 0.5) * 2
mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
& (reference_points_cam[..., 0:1] < 1.0)
& (reference_points_cam[..., 1:2] > -1.0)
& (reference_points_cam[..., 1:2] < 1.0))
mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
mask = torch.nan_to_num(mask)
sampled_feats = []
# 对四个特征层分别求出线性插值后的feature,其中N为num_query, [4, bs, embed_dims, h, w]
for lvl, feat in enumerate(mlvl_feats):
B, N, C, H, W = feat.size() # (num_key, bs, embed_dims)
# N=num_cam
feat = feat.view(B*N, C, H, W)
# [b,num_cam,num_query,2] -> [b, num_cam, num_query, 1, 2]
reference_points_cam_lvl = reference_points_cam.view(B*N, num_query, 1, 2)
# F.grid_sample return:[b*n,c,num_query,1]每个query对应着一个grid采样(bilinear incorparation)后返回的值
sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
# b,c,n_q,n,1
sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
sampled_feats.append(sampled_feat)
# [b,n,c,num_query,len(mlvl_feats)]
sampled_feats = torch.stack(sampled_feats, -1)
sampled_feats = sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
return reference_points_3d, sampled_feats, mask
关于F.grid_sample()