Paper: https://arxiv.org/pdf/2207.13085.pdf
Code: GitHub - Atten4Vis/ConditionalDETR, the official implementation of the ICCV 2021 paper "Conditional DETR for Fast Training Convergence" (https://arxiv.org/abs/2108.06152)
The Group DETR code is integrated into the Conditional DETR repository; just use the group_detr branch.
The end-to-end detector DETR needs no hand-designed post-processing (e.g. NMS), but it requires a long training schedule to converge. In this paper the authors revisit DETR's slow convergence and find that the one-to-one label assignment used in DETR is partly responsible. In short, one-to-one matching leaves DETR with little supervision during training (there are few positive object queries), so a longer schedule is needed to reach good accuracy. One-to-many label assignment can fix the lack of supervision and make the network converge faster, but it relies on NMS to remove duplicate predictions, which breaks the elegant end-to-end design of the DETR family.
To address this, the authors propose Group DETR, a new label-assignment strategy for DETR-style detectors: group-wise one-to-many assignment. The idea is to decouple the "one-to-many assignment" problem into "multiple groups of one-to-one assignment". During training, K groups of queries are used and each group performs one-to-one assignment independently, so overall every ground truth is matched to K queries. Group DETR speeds up the convergence of DETR variants: it supports multiple positive queries while still removing duplicate predictions and keeping detection end-to-end. The authors experiment on several DETR variants, including Conditional DETR, DAB-DETR, DN-DETR, DINO, and Mask2Former, and obtain significant gains in convergence speed and accuracy with no extra cost at inference.
Quick summary:
The queries fed to the decoder are expanded from 300 to 300 × 11 (groups) = 3300 and processed by the decoder together. When computing the loss, the Hungarian matching step splits them back into 11 groups and matches each group separately; each group's matched query indices are then offset by group index × 300 (for example, local index 57 in group 3 becomes 3 × 300 + 57 = 957) and the per-group results are merged. Since each group has its own query parameters, the matching results differ across groups, which effectively simulates one-to-many matching; the remaining losses are then computed over the merged matches as usual.
One-to-one assignment is elegant but limited in accuracy; one-to-many assignment boosts accuracy by brute force but needs NMS post-processing. This paper aims to exploit one-to-many label assignment without NMS, making full use of positive queries while adding no inference cost.
The core idea of Group DETR is to assign one ground truth to multiple positive queries. To avoid duplicate predictions, the authors decouple "one-to-many assignment" into "multiple groups of one-to-one assignment". As shown in figure (b), K groups of queries are fed to the decoder during training. Self-attention is performed within each group separately (with shared parameters), and each group then goes through the rest of the decoder. For label assignment, one-to-one matching is applied to each group, so every ground truth is assigned to K positive queries. At inference only the first group of queries is kept (any single group would do; all groups give similar results), so no part of the original pipeline changes and no extra computation is added. A minimal sketch of the grouped queries follows.
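As a minimal sketch of how the K groups of queries can be realized (mirroring the query_embed handling shown in the code later; 300 queries and K = 11 as in the default config):

from torch import nn

# K * num_queries learned query embeddings at training time,
# only the first num_queries at inference, so inference cost matches the single-group baseline
num_queries, K, hidden_dim = 300, 11, 256
query_embed = nn.Embedding(num_queries * K, hidden_dim)
train_queries = query_embed.weight                 # [3300, 256], all K groups
infer_queries = query_embed.weight[:num_queries]   # [300, 256], only the first group is kept
print(train_queries.shape, infer_queries.shape)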
It is most intuitive to just walk through the code directly.
The default backbone is ResNet-50. It extracts the image feature map that is fed to the transformer: after the ResNet-50 convolutions the channel count goes from 3 to 2048 and the spatial size shrinks from H × W to H/32 × W/32; the channels are then reduced to 256 before the features enter the encoder.
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.tensors.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model
Assume the input images have H=1130, W=768. After the backbone we get the features, which hold the padding mask (shape [2,36,24]) together with the convolutional feature map (shape [2,2048,36,24]), and pos, the positional encoding computed from the mask, with shape [2,256,36,24]. A quick shape check is sketched below.
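As a rough sanity check of these shapes, here is a small sketch that runs torchvision's resnet50 directly (not the repo's Backbone wrapper) on a dummy batch of the size above:

import torch
import torchvision

# stride-32 trunk of resnet50: drop the final avgpool / fc
body = torch.nn.Sequential(*list(torchvision.models.resnet50().children())[:-2])
x = torch.randn(2, 3, 1130, 768)
with torch.no_grad():
    feat = body(x)
print(feat.shape)                              # torch.Size([2, 2048, 36, 24])
input_proj = torch.nn.Conv2d(2048, 256, kernel_size=1)
print(input_proj(feat).shape)                  # torch.Size([2, 256, 36, 24]) -- channels reduced before the encoder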
The mask is built from the maximum height and width within the batch. In the example above one image has H=1130, W=768 and the other image is smaller in both dimensions, so a template of the maximum size is created and each image is pasted into its top-left corner; whatever the image does not cover towards the bottom-right becomes padding. Positions not covered by the image are marked True and covered positions are marked False, roughly as illustrated in the figure below:
The code that produces this padded batch and mask:
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], :img.shape[2]] = False
else:
raise ValueError('not supported')
return NestedTensor(tensor, mask)
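A quick usage check of this helper on two images of different sizes (assuming the DETR-style repo layout where it lives in util/misc.py):

import torch
from util.misc import nested_tensor_from_tensor_list  # assumed import path, as in DETR

imgs = [torch.randn(3, 1130, 768), torch.randn(3, 800, 600)]
nested = nested_tensor_from_tensor_list(imgs)
print(nested.tensors.shape)      # torch.Size([2, 3, 1130, 768]) -- both images padded to the per-batch max size
print(nested.mask.shape)         # torch.Size([2, 1130, 768])
print(nested.mask[0].any())      # tensor(False): the large image fills the template completely
print(nested.mask[1, 900, 700])  # tensor(True): this position is padding for the smaller image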
First, an overview of the DETR model code:
class ConditionalDETR(nn.Module):
""" This is the Conditional DETR module that performs object detection """
def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False, group_detr=1):
""" Initializes the model.
Parameters:
backbone: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
num_classes: number of object classes
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
group_detr: Number of groups to speed detr training. Default is 1.
"""
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
hidden_dim = transformer.d_model
self.class_embed = nn.Linear(hidden_dim, num_classes)
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
self.query_embed = nn.Embedding(num_queries * group_detr, hidden_dim)
self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
self.backbone = backbone
self.aux_loss = aux_loss
self.group_detr = group_detr
# init prior_prob setting for focal loss
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        # init bbox_embed
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
def forward(self, samples: NestedTensor):
""" The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
It returns a dict with the following elements:
- "pred_logits": the classification logits (including no-object) for all queries.
Shape= [batch_size x num_queries x num_classes]
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
(center_x, center_y, width, height). These values are normalized in [0, 1],
relative to the size of each individual image (disregarding possible padding).
See PostProcess for information on how to retrieve the unnormalized bounding box.
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
dictionnaries containing the two above keys for each decoder layer.
"""
if isinstance(samples, (list, torch.Tensor)):
            # samples: {mask: [2,1130,768], tensors: [2,3,1130,768]}
            samples = nested_tensor_from_tensor_list(samples)
        # pos is the positional encoding of the mask produced by PositionEmbeddingSine
        features, pos = self.backbone(samples)  # features: {mask: [2,36,24], tensors: [2,2048,36,24]}, pos: [2,256,36,24]
        src, mask = features[-1].decompose()  # the mask records where the original (un-padded) image sits inside the padded image, which is built to the batch's maximum size
assert mask is not None
if self.training:
            query_embed_weight = self.query_embed.weight  # produced by nn.Embedding, shape (300 * group, 256), group = 11
        else:
            # only use one group in inference
            query_embed_weight = self.query_embed.weight[:self.num_queries]
        # Transformer inputs:
        # src: backbone feature map [2,256,36,24]
        # mask: records where the un-padded image lies inside the padded image [2,36,24]
        # query_embed_weight: produced by nn.Embedding [3300,256]
        # pos: positional encoding of the mask from PositionEmbeddingSine [2,256,36,24]
        hs, reference = self.transformer(self.input_proj(src), mask, query_embed_weight, pos[-1])  # self.input_proj -> Conv2d(2048, 256)
        # hs: outputs of all decoder layers [6,2,3300,256]; reference: reference points (box centers) obtained by passing the query_pos embedding through an MLP [2,3300,2]
reference_before_sigmoid = inverse_sigmoid(reference)
outputs_coords = []
for lvl in range(hs.shape[0]):
            tmp = self.bbox_embed(hs[lvl])  # bbox_embed -> Linear(256,256), Linear(256,256), Linear(256,4): [2,3300,256] -> [2,3300,4]
            tmp[..., :2] += reference_before_sigmoid  # add the reference point (center) to the predicted x, y
            outputs_coord = tmp.sigmoid()
            outputs_coords.append(outputs_coord)  # boxes predicted by every decoder layer
        outputs_coord = torch.stack(outputs_coords)  # [6,2,3300,4]
        outputs_class = self.class_embed(hs)  # class_embed -> Linear(256,91): [6,2,3300,256] -> [6,2,3300,91]
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
if self.aux_loss:
out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out  # the boxes and class logits from the last decoder layer, plus those of the five intermediate decoder layers (used for the auxiliary losses)
@torch.jit.unused
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{'pred_logits': a, 'pred_boxes': b}
for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
The backbone outputs features and pos: features holds the mask and the ResNet-50 feature map (for an input of [2,3,1130,768] the mask has shape [2,36,24] and the feature tensor [2,2048,36,24]), while pos is the positional encoding of the mask, with shape [2,256,36,24].
The transformer takes the following inputs:
src: the backbone feature map, projected from [2,2048,36,24] to [2,256,36,24] before entering the transformer
mask: records where the un-padded image lies inside the padded image, [2,36,24]
query_embed_weight: produced by nn.Embedding, [3300,256]
pos: positional encoding of the mask from PositionEmbeddingSine, [2,256,36,24]
class Transformer(nn.Module):
def __init__(self, d_model=512, nhead=8, num_queries=300, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False,
return_intermediate_dec=False, group_detr=1):
super().__init__()
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before,
group_detr=group_detr)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
self.dec_layers = num_decoder_layers
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src, mask, query_embed, pos_embed):
        # Transformer inputs:
        # src: backbone feature map [2,256,36,24]
        # mask: records where the un-padded image lies inside the padded image [2,36,24]
        # query_embed: produced by nn.Embedding [3300,256]
        # pos_embed: positional encoding of the mask from PositionEmbeddingSine [2,256,36,24]
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)  # [2,256,36,24] -> [864,2,256]
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)  # [2,256,36,24] -> [864,2,256]
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # [3300,256] -> [3300,2,256]
        mask = mask.flatten(1)  # [2,36,24] -> [2,864]
        tgt = torch.zeros_like(query_embed)  # zero-initialized [3300,2,256]
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # Decoder inputs:
        # tgt: zero-initialized [3300,2,256]
        # memory: encoder output [864,2,256]
        # mask: records where the un-padded image lies inside the padded image [2,864]
        # query_embed: produced by nn.Embedding [3300,2,256]
        # pos_embed: positional encoding of the mask from PositionEmbeddingSine [864,2,256]
hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed)
return hs, references
        # hs: outputs of all decoder layers [6,2,3300,256]
        # references: reference points (centers) obtained by passing query_pos through an MLP [2,3300,2]
This part of the code is straightforward; the key steps are annotated in the comments. The inputs are flattened and permuted before being fed to the encoder, as the short check below illustrates.
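For example, a toy check of the reshaping with the shapes from the running example:

import torch

src = torch.randn(2, 256, 36, 24)
src = src.flatten(2).permute(2, 0, 1)                               # [2,256,36,24] -> [864,2,256]
mask = torch.zeros(2, 36, 24, dtype=torch.bool).flatten(1)          # [2,36,24] -> [2,864]
query_embed = torch.randn(3300, 256).unsqueeze(1).repeat(1, 2, 1)   # [3300,256] -> [3300,2,256]
print(src.shape, mask.shape, query_embed.shape)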
The encoder is the same as in DETR. Its inputs are:
src: the backbone feature map [864,2,256]
mask: None
src_key_padding_mask: records where the un-padded image lies inside the padded image [2,864]
pos: positional encoding of the mask from PositionEmbeddingSine [864,2,256]
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
        # src: backbone feature map [864,2,256]
        # mask: None
        # src_key_padding_mask: records where the un-padded image lies inside the padded image [2,864]
        # pos: positional encoding of the mask from PositionEmbeddingSine [864,2,256]
output = src
for layer in self.layers:
output = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
        # src: backbone feature map [864,2,256]
        # src_mask: None
        # src_key_padding_mask: records where the un-padded image lies inside the padded image [2,864]
        # pos: positional encoding of the mask from PositionEmbeddingSine [864,2,256]
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(self, src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
src is added to the positional encoding, self-attention is computed over the result, the residual sum goes through LayerNorm, and then the feed-forward MLP (again with residual and LayerNorm) is applied. The original post illustrates this with a code diagram; a minimal sketch of the post-norm flow is shown below.
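A minimal, self-contained sketch of one post-norm encoder step with the assumed sizes d_model=256, nhead=8 (an illustration of the flow, not the repo's exact TransformerEncoderLayer):

import torch
from torch import nn

d_model, nhead = 256, 8
self_attn = nn.MultiheadAttention(d_model, nhead)
ffn = nn.Sequential(nn.Linear(d_model, 2048), nn.ReLU(), nn.Linear(2048, d_model))
norm1, norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

src = torch.randn(864, 2, d_model)   # HW x N x C tokens from the backbone
pos = torch.randn(864, 2, d_model)   # sine positional encoding
q = k = src + pos                    # the positional encoding is added to Q and K only; V stays src
src = norm1(src + self_attn(q, k, value=src)[0])   # self-attention + residual + LayerNorm
src = norm2(src + ffn(src))                         # FFN + residual + LayerNorm
print(src.shape)                     # torch.Size([864, 2, 256])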
The decoder inputs are:
tgt: zero-initialized [3300,2,256]
memory: encoder output [864,2,256]
tgt_mask: None
memory_mask: None
tgt_key_padding_mask: None
memory_key_padding_mask: records where the un-padded image lies inside the padded image [2,864]
pos: positional encoding of the mask from PositionEmbeddingSine [864,2,256]
query_pos: produced by nn.Embedding [3300,2,256]
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, d_model=256):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
self.query_scale = MLP(d_model, d_model, d_model, 2)
self.ref_point_head = MLP(d_model, d_model, 2, 2)
for layer_id in range(num_layers - 1):
self.layers[layer_id + 1].ca_qpos_proj = None
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
        # tgt: zero-initialized [3300,2,256]
        # memory: encoder output [864,2,256]
        # tgt_mask: None
        # memory_mask: None
        # tgt_key_padding_mask: None
        # memory_key_padding_mask: records where the un-padded image lies inside the padded image [2,864]
        # pos: positional encoding of the mask from PositionEmbeddingSine [864,2,256]
        # query_pos: produced by nn.Embedding [3300,2,256]
output = tgt
intermediate = []
reference_points_before_sigmoid = self.ref_point_head(query_pos) # [num_queries, batch_size, 2] ref_point_head->MLP Linear(256,256) Linear(256,2) [3300,2,256]->[3300,2,2]
reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1) # [3300,2,2]->[2,3300,2]
for layer_id, layer in enumerate(self.layers):
obj_center = reference_points[..., :2].transpose(0, 1) # [num_queries, batch_size, 2] # [2,3300,2]->[3300,2,2]
# For the first decoder layer, we do not apply transformation over p_s
if layer_id == 0:
pos_transformation = 1
else:
pos_transformation = self.query_scale(output) # query_scale ->MLP Linear(256,256) Linear(256,256) [3300,2,256]->[3300,2,256]
# get sine embedding for the query vector
            query_sine_embed = gen_sineembed_for_position(obj_center)  # sine embedding of the reference point (center) produced by the MLP: [3300,2,2] -> [3300,2,256]
# apply transformation
query_sine_embed = query_sine_embed * pos_transformation
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
is_first=(layer_id == 0)) # output:[3300,2,256]
if self.return_intermediate:
                intermediate.append(self.norm(output))  # stores the output of every intermediate decoder layer
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return [torch.stack(intermediate).transpose(1, 2), reference_points]
return output.unsqueeze(0)
Note that the reference points are not dynamically refined across decoder layers here; combining this with DAB-DETR's refinement strategy should bring a further gain. A toy sketch of that idea follows.
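For reference, a toy sketch of what DAB-style layer-by-layer reference refinement could look like; this is only an illustration of the remark above, not something this branch implements, and bbox_head / decoder_layer below are stand-ins for the real modules:

import torch

def inverse_sigmoid(x, eps=1e-5):
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

bbox_head = torch.nn.Linear(256, 2)     # stand-in for the per-layer box/offset head
decoder_layer = torch.nn.Identity()     # stand-in for a decoder layer

output = torch.randn(2, 300, 256)
reference = torch.rand(2, 300, 2)       # normalized (cx, cy) anchors
for _ in range(6):                      # one update per decoder layer
    output = decoder_layer(output)
    delta = bbox_head(output)           # offset predicted in inverse-sigmoid space
    reference = (inverse_sigmoid(reference) + delta).sigmoid().detach()
print(reference.shape)

For each decoder layer: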
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False, group_detr=1):
super().__init__()
# Decoder Self-Attention
self.sa_qcontent_proj = nn.Linear(d_model, d_model)
self.sa_qpos_proj = nn.Linear(d_model, d_model)
self.sa_kcontent_proj = nn.Linear(d_model, d_model)
self.sa_kpos_proj = nn.Linear(d_model, d_model)
self.sa_v_proj = nn.Linear(d_model, d_model)
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model)
# Decoder Cross-Attention
self.ca_qcontent_proj = nn.Linear(d_model, d_model)
self.ca_qpos_proj = nn.Linear(d_model, d_model)
self.ca_kcontent_proj = nn.Linear(d_model, d_model)
self.ca_kpos_proj = nn.Linear(d_model, d_model)
self.ca_v_proj = nn.Linear(d_model, d_model)
self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
self.cross_attn = MultiheadAttention(d_model*2, nhead, dropout=dropout, vdim=d_model)
self.nhead = nhead
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.group_detr = group_detr
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
query_sine_embed = None,
is_first = False):
        # tgt: zero-initialized at first, then the output of the previous decoder layer [3300,2,256]
        # memory: encoder output [864,2,256]
        # tgt_mask: None
        # memory_mask: None
        # tgt_key_padding_mask: None
        # memory_key_padding_mask: records where the un-padded image lies inside the padded image [2,864]
        # pos: positional encoding of the mask from PositionEmbeddingSine [864,2,256]
        # query_pos: produced by nn.Embedding [3300,2,256]
        # query_sine_embed: sine embedding of the reference point (center) obtained from query_pos via the MLP [3300,2,256]
# ========== Begin of Self-Attention =============
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.sa_qcontent_proj(tgt) # [3300,2,256]->[3300,2,256] # target is the input of the first decoder layer. zero by default.
q_pos = self.sa_qpos_proj(query_pos) # [3300,2,256]->[3300,2,256]
k_content = self.sa_kcontent_proj(tgt) # [3300,2,256]->[3300,2,256]
k_pos = self.sa_kpos_proj(query_pos) # [3300,2,256]->[3300,2,256]
v = self.sa_v_proj(tgt) # [3300,2,256]->[3300,2,256]
        # all of the xxx_proj layers above are Linear(256, 256)
num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape
q = q_content + q_pos # [3300,2,256]
k = k_content + k_pos # [3300,2,256]
if self.training:
q = torch.cat(q.split(num_queries // self.group_detr, dim=0), dim=1) # [3300,2,256]->[300,22,256]
k = torch.cat(k.split(num_queries // self.group_detr, dim=0), dim=1) # [3300,2,256]->[300,22,256]
v = torch.cat(v.split(num_queries // self.group_detr, dim=0), dim=1) # [3300,2,256]->[300,22,256]
tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
if self.training:
tgt2 = torch.cat(tgt2.split(bs, dim=1), dim=0) # [300,22,256]->[3300,2,256]
# ========== End of Self-Attention =============
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# ========== Begin of Cross-Attention =============
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.ca_qcontent_proj(tgt) # [3300,2,256]->[3300,2,256]
k_content = self.ca_kcontent_proj(memory) # [864,2,256]->[864,2,256]
v = self.ca_v_proj(memory) # [864,2,256]->[864,2,256]
num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape
        k_pos = self.ca_kpos_proj(pos)  # Linear(256,256) applied to the positional encoding: [864,2,256] -> [864,2,256]
# For the first decoder layer, we concatenate the positional embedding predicted from
# the object query (the positional embedding) into the original query (key) in DETR.
if is_first:
q_pos = self.ca_qpos_proj(query_pos) # ca_qpos_proj->Linear(256,256) # [3300,2,256]->[3300,2,256]
q = q_content + q_pos # [3300,2,256]
k = k_content + k_pos # [3300,2,256]
else:
q = q_content
k = k_content
q = q.view(num_queries, bs, self.nhead, n_model//self.nhead) # [3300,2,256]->[3300,2,8,32]
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed) # ca_qpos_sine_proj->Linear(256,256) [3300,2,256]->[3300,2,256]
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead) # [3300,2,256]->[3300,2,8,32]
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2) # [3300,2,256*2]
k = k.view(hw, bs, self.nhead, n_model//self.nhead) # [864,2,256]->[864,2,8,32]
k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead) # [864,2,256]->[864,2,8,32]
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2) # [864,2,256*2]
# q:[3300,2,512] k:[864,2,512] v:[864,2,256] -> tgt2:[3300,2,256]
tgt2 = self.cross_attn(query=q,
key=k,
value=v, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
# ========== End of Cross-Attention =============
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
query_sine_embed = None,
is_first = False):
if self.normalize_before:
raise NotImplementedError
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, query_sine_embed, is_first)
Note the grouped self-attention in forward_post:
if self.training:
    q = torch.cat(q.split(num_queries // self.group_detr, dim=0), dim=1)  # [3300,2,256] -> [300,22,256]
    k = torch.cat(k.split(num_queries // self.group_detr, dim=0), dim=1)  # [3300,2,256] -> [300,22,256]
    v = torch.cat(v.split(num_queries // self.group_detr, dim=0), dim=1)  # [3300,2,256] -> [300,22,256]
During training the incoming queries are split into 11 groups and the groups are concatenated along the batch dimension, so self-attention only mixes queries within the same group; the standalone shape check below makes this explicit.
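A standalone shape check of this split-and-fold trick:

import torch

num_queries, bs, d_model, groups = 3300, 2, 256, 11
q = torch.randn(num_queries, bs, d_model)
q_grouped = torch.cat(q.split(num_queries // groups, dim=0), dim=1)   # [3300,2,256] -> [300,22,256]
print(q_grouped.shape)
# after self-attention the groups are unfolded back into the query dimension
q_back = torch.cat(q_grouped.split(bs, dim=1), dim=0)                 # [300,22,256] -> [3300,2,256]
print(torch.equal(q, q_back))                                         # True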
The original post illustrates this with pseudocode and a code diagram, followed by a figure of the overall transformer.
Hungarian matching takes the network's final predictions (class and box) and builds a weighted cost matrix from the classification cost, the L1 box cost, and the GIoU box cost, then solves for a one-to-one assignment (a toy version is sketched below).
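As a toy illustration of the idea for a single image (the real matcher below additionally uses a focal-style classification cost and a GIoU cost, weighted by cost_class / cost_bbox / cost_giou):

import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_gt = 5, 2
cost_class = torch.rand(num_queries, num_gt)                                      # classification cost
cost_bbox = torch.cdist(torch.rand(num_queries, 4), torch.rand(num_gt, 4), p=1)   # L1 box cost
C = 2 * cost_class + 5 * cost_bbox                                                # weighted sum
rows, cols = linear_sum_assignment(C.numpy())   # rows: matched query indices, cols: matched GT indices
print(rows, cols)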
class HungarianMatcher(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
while the others are un-matched (and thus treated as non-objects).
"""
def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
"""Creates the matcher
Params:
cost_class: This is the relative weight of the classification error in the matching cost
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
"""
super().__init__()
self.cost_class = cost_class
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
@torch.no_grad()
def forward(self, outputs, targets, group_detr=1):
""" Performs the matching
Params:
outputs: This is a dict that contains at least these entries:
"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
objects in the target) containing the class labels
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
group_detr: Number of groups used for matching.
Returns:
A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs, num_queries = outputs["pred_logits"].shape[:2] # num_queries: 3300 bs: 2
# We flatten to compute the cost matrices in a batch
out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] # [2,3300,91]->[6600,91]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] # [2,3300,4]->[6600,4]
# Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])  # all GT labels in the batch
        tgt_bbox = torch.cat([v["boxes"] for v in targets])  # all GT boxes in the batch
# Compute the classification cost.
alpha = 0.25
gamma = 2.0
neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] # [6600,len(tgt_ids)]
# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) # out_bbox:[6600,4] tgt_bbox:[len(tgt_bbox),4] -> [6600,len(tgt_ids)]
        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))  # [6600, len(tgt_ids)]
# Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu() # [6600,len(tgt_ids)]->[2,3300,len(tgt_ids)]
        sizes = [len(v["boxes"]) for v in targets]  # number of GT boxes in each image of the batch
        indices = []
        g_num_queries = num_queries // group_detr  # 300
        C_list = C.split(g_num_queries, dim=1)  # tuple of 11 tensors, each [2, 300, len(tgt_ids)]
        for g_i in range(group_detr):  # iterate over the 11 query groups
            C_g = C_list[g_i]
            # indices_g holds one tuple per image in the batch; each tuple contains the row (query) and
            # column (GT) indices of the optimal assignment found by the Hungarian algorithm for that image
            indices_g = [linear_sum_assignment(c[i]) for i, c in enumerate(C_g.split(sizes, -1))]
if g_i == 0:
indices = indices_g
else:
indices = [
(np.concatenate([indice1[0], indice2[0] + g_num_queries * g_i]), np.concatenate([indice1[1], indice2[1]]))
for indice1, indice2 in zip(indices, indices_g)
                ]  # the query (row) indices of every group after the first are offset by g_num_queries * g_i
        # finally, convert the indices to torch tensors
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
When solving for the optimal assignment, the 11 query groups are first split apart; after matching, the row (query) indices of each group are offset by group index × 300 while the column (GT) indices are left unchanged, and the per-group results are concatenated, as the small example below shows.
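A small numeric illustration of the merge (the numbers are made up):

import numpy as np

g_num_queries = 300
indices_g0 = (np.array([17, 204]), np.array([0, 1]))   # (query idx, GT idx) from group 0
indices_g1 = (np.array([45, 131]), np.array([1, 0]))   # the same two GTs matched again in group 1
merged = (np.concatenate([indices_g0[0], indices_g1[0] + g_num_queries * 1]),
          np.concatenate([indices_g0[1], indices_g1[1]]))
print(merged)   # (array([ 17, 204, 345, 431]), array([0, 1, 1, 0])) -- each GT now has two positive queries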
class SetCriterion(nn.Module):
""" This class computes the loss for Conditional DETR.
The process happens in two steps:
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""
def __init__(self, num_classes, matcher, weight_dict, focal_alpha, losses, group_detr=1):
""" Create the criterion.
Parameters:
num_classes: number of object categories, omitting the special no-object category
matcher: module able to compute a matching between targets and proposals
weight_dict: dict containing as key the names of the losses and as values their relative weight.
losses: list of all the losses to be applied. See get_loss for list of available losses.
focal_alpha: alpha in Focal Loss
group_detr: Number of groups to speed detr training. Default is 1.
"""
super().__init__()
self.num_classes = num_classes
self.matcher = matcher
self.weight_dict = weight_dict
self.losses = losses
self.focal_alpha = focal_alpha
self.group_detr = group_detr
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
"""Classification loss (Binary focal loss)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert 'pred_logits' in outputs
src_logits = outputs['pred_logits']
# idx=(batch_idx, src_idx)
idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])  # GT labels reordered according to the column (GT) indices in `indices`
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)  # a [2,3300] tensor filled with 91 (the no-object class)
        target_classes[idx] = target_classes_o  # scatter target_classes_o into the [2,3300] no-object tensor at the matched positions idx
        # one-hot encoding
target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1],
dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) # [2,3300,92]
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
target_classes_onehot = target_classes_onehot[:,:,:-1]
loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
losses = {'loss_ce': loss_ce}
if log:
# TODO this should probably be a separate loss, not hacked in this one here
losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
return losses
@torch.no_grad()
def loss_cardinality(self, outputs, targets, indices, num_boxes):
""" Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
"""
pred_logits = outputs['pred_logits']
device = pred_logits.device
tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
# Count the number of predictions that are NOT "no-object" (which is the last class)
card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
losses = {'cardinality_error': card_err}
return losses
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
"""
assert 'pred_boxes' in outputs
idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]  # gather the predicted boxes from outputs['pred_boxes'] at the matched positions idx
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)  # GT boxes reordered according to the column (GT) indices in `indices`
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
losses = {}
losses['loss_bbox'] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
box_ops.box_cxcywh_to_xyxy(src_boxes),
box_ops.box_cxcywh_to_xyxy(target_boxes)))
losses['loss_giou'] = loss_giou.sum() / num_boxes
return losses
def loss_masks(self, outputs, targets, indices, num_boxes):
"""Compute the losses related to the masks: the focal loss and the dice loss.
targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
"""
assert "pred_masks" in outputs
src_idx = self._get_src_permutation_idx(indices)
tgt_idx = self._get_tgt_permutation_idx(indices)
src_masks = outputs["pred_masks"]
src_masks = src_masks[src_idx]
masks = [t["masks"] for t in targets]
# TODO use valid to mask invalid areas due to padding in loss
target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
target_masks = target_masks.to(src_masks)
target_masks = target_masks[tgt_idx]
# upsample predictions to the target size
src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
mode="bilinear", align_corners=False)
src_masks = src_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
target_masks = target_masks.view(src_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
"loss_dice": dice_loss(src_masks, target_masks, num_boxes),
}
return losses
def _get_src_permutation_idx(self, indices):
        # permute predictions following indices. `indices` is a list with one tuple per image (two tuples for
        # batch size 2); each tuple holds the matched query and GT indices from the Hungarian matching, and the
        # length of either entry equals the number of matched objects in that image (used to build batch_idx)
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])  # src_idx holds the matched query (row) indices
return batch_idx, src_idx
def _get_tgt_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
loss_map = {
'labels': self.loss_labels,
'cardinality': self.loss_cardinality,
'boxes': self.loss_boxes,
'masks': self.loss_masks
}
assert loss in loss_map, f'do you really want to compute {loss} loss?'
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
def forward(self, outputs, targets):
""" This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
group_detr = self.group_detr if self.training else 1
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets, group_detr=group_detr)
        # Compute the average number of target boxes across all nodes, for normalization purposes;
        # with group DETR every GT is matched group_detr times, so the count is scaled accordingly
        num_boxes = sum(len(t["labels"]) for t in targets) * group_detr
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
if is_dist_avail_and_initialized():
torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
# Compute all the requested losses
losses = {}
for loss in self.losses:
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:  # compute the (box, label) losses for the intermediate decoder layers, i.e. all layers except the last
for i, aux_outputs in enumerate(outputs['aux_outputs']):
indices = self.matcher(aux_outputs, targets, group_detr=group_detr)
for loss in self.losses:
if loss == 'masks':
# Intermediate masks losses are too costly to compute, we ignore them.
continue
kwargs = {}
if loss == 'labels':
# Logging is enabled only for the last layer
kwargs = {'log': False}
l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
losses.update(l_dict)
return losses