DETR系列中的queries初始化是非常重要的一个问题,其影响了DETR及其变体的收敛速度以及稳定性,在vanilla DETR中,queries是由初始化为0的content_queries和正弦余弦编码的pos_queries组成,这是最初始的queries定义,后面许多文章着重于寻求一种能够使DETR加快收敛和稳定性的queries。
# encoder输出经过attention后的特征序列
memory = self.encoder(
query=feat_flatten,
key=None,
value=None,
query_pos=lvl_pos_embed_flatten,
query_key_padding_mask=mask_flatten,
spatial_shapes=spatial_shapes,
reference_points=reference_points,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
**kwargs)
memory = memory.permute(1, 0, 2)
bs, _, c = memory.shape
# 选择two-stage,也就是通过encoder初始预测结果初始化后面的decoder,这里为true
if self.as_two_stage:
output_memory, output_proposals = \
self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes)
enc_outputs_class = cls_branches[self.decoder.num_layers](
output_memory)
enc_outputs_coord_unact = \
reg_branches[
self.decoder.num_layers](output_memory) + output_proposals
topk = self.two_stage_num_proposals
# This follows the official implementation of Deformable DETR.
topk_proposals = torch.topk(
enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_unact = torch.gather(
enc_outputs_coord_unact, 1,
topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
topk_coords_unact = topk_coords_unact.detach()
reference_points = topk_coords_unact.sigmoid()
init_reference_out = reference_points
# 这里首先对coords也就是预测的初始位置进行sin cos位置编码,然后进行映射,最后对初始化后的query进行laynorm
# self.pos_trans_norm为layernorm()
# pos_trans为linear映射层
pos_trans_out = self.pos_trans_norm(
self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
query_pos, query = torch.split(pos_trans_out, c, dim=2)
def get_proposal_pos_embed(self,
proposals,
num_pos_feats=128,
temperature=10000):
"""Get the position embedding of proposal."""
scale = 2 * math.pi
dim_t = torch.arange(
num_pos_feats, dtype=torch.float32, device=proposals.device)
dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
# N, L, 4
proposals = proposals.sigmoid() * scale
# N, L, 4, 128
pos = proposals[:, :, :, None] / dim_t
# N, L, 4, 64, 2 -> N, L, 512
pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
dim=4).flatten(2)
return pos