Group DETR Study Notes: Group-wise One-to-Many Assignment Is the Key to Accelerating DETR Convergence


Paper: https://arxiv.org/pdf/2207.13085.pdf

Code: GitHub - Atten4Vis/ConditionalDETR: This repository is an official implementation of the ICCV 2021 paper "Conditional DETR for Fast Training Convergence". (https://arxiv.org/abs/2108.06152)

The Group DETR code is integrated into the Conditional DETR repository; just use the group-detr branch.

The end-to-end object detector DETR needs no hand-crafted post-processing (e.g., NMS), but it requires long training to converge. In this paper the authors revisit DETR's slow convergence and find that the one-to-one label assignment used in DETR is partly responsible. In short, one-to-one matching leaves DETR with too little supervision during training (there are only a few positive object queries), so longer training is needed to reach good results. One-to-many label assignment can fix the lack of supervision and make the network converge faster, but it relies on NMS to remove duplicate predictions, which contradicts the elegant end-to-end design of the DETR family.

To solve this, the authors propose Group DETR, a new label-assignment strategy for the DETR family: group-wise one-to-many assignment. The method cleverly decouples the "one-to-many assignment" problem into "multiple groups of one-to-one assignment". During training, K groups of queries are used and each group performs one-to-one assignment independently, so overall each ground truth is matched to K queries. Group DETR speeds up the convergence of DETR-style detectors: it supports multiple positive queries while still removing duplicate predictions and staying end-to-end. The authors run experiments on several DETR variants, including Conditional DETR, DAB-DETR, DN-DETR, DINO, and Mask2Former; inference cost is unchanged, but training convergence and accuracy improve significantly.


Quick summary:

        The 300 queries fed to the decoder are expanded to 300 × 11 (groups) = 3300 and passed through the decoder together. When computing the loss, the Hungarian matching stage splits them into 11 groups and matches each group separately; the per-group indices are then offset by group_index × 300 and concatenated to form the final indices. Because each group has its own query parameters, the matches differ across groups, which simulates one-to-many matching; the remaining losses are then computed jointly over all groups.
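A minimal sketch of this idea (assuming num_queries = 300 and group_detr = 11 as in the repo's defaults; the variable names below are illustrative, not the repo's exact code): the query embedding table is simply made K times larger for training and sliced back to a single group at inference.

import torch
import torch.nn as nn

num_queries, group_detr, hidden_dim = 300, 11, 256

# One embedding table holds all K groups of learnable queries.
query_embed = nn.Embedding(num_queries * group_detr, hidden_dim)

# Training: all 3300 queries go through the decoder together.
train_queries = query_embed.weight               # [3300, 256]

# Inference: keep only the first group, so cost is unchanged.
eval_queries = query_embed.weight[:num_queries]  # [300, 256]

print(train_queries.shape, eval_queries.shape)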


        One-to-one assignment is elegant but limited in accuracy; one-to-many assignment brute-forces better accuracy but needs NMS post-processing. This paper tries to exploit one-to-many label assignment without NMS, making full use of positive queries while adding no inference cost.

        The core idea of Group DETR is to assign one ground truth to multiple positive queries. To avoid duplicate predictions, the authors decouple "one-to-many assignment" into "multiple groups of one-to-one assignment". As shown in part (b) of the figure below, K groups of queries are fed to the decoder during training. Self-attention is performed within each group (with shared parameters), and each group then passes through the rest of the decoder. For label assignment, one-to-one matching is applied within each group, so every ground truth is assigned to K positive queries. At test time only the first group is kept (any single group works, since their results are similar), so none of the original pipeline changes and no extra computation is added.

(Figure: group-wise one-to-many assignment, part (b) referenced above)

The most intuitive way to see how this works is to walk through the code directly.

Backbone

        The backbone is the default ResNet-50; it is the trunk that extracts image feature maps before the transformer. After the ResNet-50 convolutions, the feature map has 2048 channels instead of the original 3, and its spatial size shrinks from H × W to H/32 × W/32. The channel dimension is reduced to 256 before the features are passed to the encoder.
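A quick, hedged sanity check of these shapes (illustrative only, not repo code): run a dummy tensor, whose sides are divisible by 32 for clean numbers, through torchvision's ResNet-50 trunk and a 1×1 projection playing the role of input_proj.

import torch
import torch.nn as nn
import torchvision

# ResNet-50 without the avgpool/fc head, i.e. the convolutional trunk.
trunk = nn.Sequential(*list(torchvision.models.resnet50().children())[:-2])
proj = nn.Conv2d(2048, 256, kernel_size=1)   # same role as input_proj in the model below

x = torch.randn(2, 3, 1152, 768)             # dummy batch, sides divisible by 32
feat = trunk(x)                              # [2, 2048, 36, 24]
src = proj(feat)                             # [2, 256, 36, 24]
print(feat.shape, src.shape)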

class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


def build_backbone(args):
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.masks
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model

        Suppose the input images are H = 1130, W = 768. After the backbone we get features, which hold the mask (shape [2, 36, 24]) and the convolutional feature map (shape [2, 2048, 36, 24]), plus pos, the positional encoding computed from the mask, of shape [2, 256, 36, 24].

        The mask is built from the maximum height and width within a batch. With the sizes above, one image is H = 1130, W = 768 and the other image is smaller in both dimensions, so a canvas of the maximum size is created and each image is pasted into its top-left corner; the unfilled bottom-right region becomes padding. Padded positions are marked True and image positions are marked False, roughly as shown below:

(Figure: padding mask layout, image region marked False, padded region marked True)

 The code that generates the padded batch and mask:

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)
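A small usage example (purely illustrative; the image sizes are made up, and it assumes the helpers above from util.misc are in scope) showing how two differently sized images land on a shared canvas with the corresponding boolean mask:

import torch

imgs = [torch.randn(3, 1130, 768), torch.randn(3, 900, 600)]
nested = nested_tensor_from_tensor_list(imgs)

print(nested.tensors.shape)  # torch.Size([2, 3, 1130, 768]) -- max H, max W in the batch
print(nested.mask.shape)     # torch.Size([2, 1130, 768])
# For the second image, everything outside its 900x600 top-left corner is padding (True).
print(nested.mask[1, :900, :600].any(), nested.mask[1, 900:, :].all())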

DETR

First, an overview of the DETR module code:

class ConditionalDETR(nn.Module):
    """ This is the Conditional DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False, group_detr=1):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            group_detr: Number of groups to speed detr training. Default is 1.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries * group_detr, hidden_dim)
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.group_detr = group_detr

        # init prior_prob setting for focal loss
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value

        # init bbox_embed
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)

    def forward(self, samples: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape= [batch_size x num_queries x num_classes]
               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                               (center_x, center_y, width, height). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                                dictionnaries containing the two above keys for each decoder layer.
        """
        if isinstance(samples, (list, torch.Tensor)):
            # samples:{mask,[2,1130,768],tensor_list,[2,3,1130,768]}
            samples = nested_tensor_from_tensor_list(samples)
        # pos is obtained by passing the mask through PositionEmbeddingSine
        features, pos = self.backbone(samples)  # features:{mask,[2,36,24],tensor_list,[2,2048,36,24]},pos:[2,256,36,24]

        src, mask = features[-1].decompose() # mask records where the original (un-padded) image sits inside the padded image; the padded canvas uses the max size in the batch
        assert mask is not None
        if self.training:
            query_embed_weight = self.query_embed.weight  # generated by nn.Embedding, shape (300*group, 256), group=11
        else:
            # only use one group in inference
            query_embed_weight = self.query_embed.weight[:self.num_queries]

        # Transformer inputs:
        # src: backbone features [2,256,36,24]
        # mask: records where the un-padded image sits inside the padded image [2,36,24]
        # query_embed_weight: generated by nn.Embedding [3300,256]
        # pos: positional encoding of the mask via PositionEmbeddingSine [2,256,36,24]
        hs, reference = self.transformer(self.input_proj(src), mask, query_embed_weight, pos[-1]) # self.input_proj-> conv2d(2048,256)
        # hs: outputs of all intermediate decoder layers [6,2,3300,256]  reference: reference points (box centers) produced by passing the nn.Embedding query_pos through an MLP [2,3300,2]
        reference_before_sigmoid = inverse_sigmoid(reference)
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            tmp = self.bbox_embed(hs[lvl])  # bbox_embed->Linear(256,256) Linear(256,256) Linear(256,4) # [2,3300,256]->[2,3300,4]
            tmp[..., :2] += reference_before_sigmoid  # add the reference point (center) to tmp's xy
            outputs_coord = tmp.sigmoid()
            outputs_coords.append(outputs_coord)  # bboxes predicted by every intermediate layer
        outputs_coord = torch.stack(outputs_coords)  # [6,2,3300,4]

        outputs_class = self.class_embed(hs) # class_embed->Linear(256,91) [6,2,3300,256]->[6,2,3300,91]
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out  # bbox and cls from the last decoder layer, plus bbox and cls from the intermediate layers (the five layers before the last) used as auxiliary losses

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

        The backbone outputs features and pos. features contains the mask and the ResNet-50 feature map: for an input of [2, 3, 1130, 768], the mask has shape [2, 36, 24] and the feature tensor has shape [2, 2048, 36, 24]. pos is the positional encoding of the mask and has shape [2, 256, 36, 24].

The transformer takes the following inputs:

src: backbone features, projected from [2, 2048, 36, 24] to [2, 256, 36, 24] before entering the transformer

mask: records where the un-padded image sits inside the padded image [2, 36, 24]

query_embed_weight: generated by nn.Embedding [3300, 256]

pos: positional encoding of the mask via PositionEmbeddingSine [2, 256, 36, 24]

class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_queries=300, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False, group_detr=1):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before, 
                                                group_detr=group_detr)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec,
                                          d_model=d_model)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # Transformer inputs:
        # src: backbone features [2,256,36,24]
        # mask: records where the un-padded image sits inside the padded image [2,36,24]
        # query_embed: generated by nn.Embedding [3300,256]
        # pos_embed: positional encoding of the mask via PositionEmbeddingSine [2,256,36,24]

        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)  # [2,256,36,24]->[864,2,256]
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)  # [2,256,36,24]->[864,2,256]
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # [3300,256]->[3300,2,256]
        mask = mask.flatten(1)  # [2,36,24]->[2,864]

        tgt = torch.zeros_like(query_embed)  # initialized to all zeros [3300,2,256]
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # Decoder inputs:
        # tgt: initialized to all zeros [3300,2,256]
        # memory: encoder output [864,2,256]
        # mask: records where the un-padded image sits inside the padded image [2,864]
        # query_embed: generated by nn.Embedding [3300,2,256]
        # pos_embed: positional encoding of the mask via PositionEmbeddingSine [864,2,256]
        hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs, references
        # hs: outputs of all intermediate decoder layers [6,2,3300,256]
        # references: reference points (centers) produced by passing the nn.Embedding query_pos through an MLP [2,3300,2]

This part is straightforward; the key lines are commented above. Before entering the encoder, the inputs are flattened and permuted into the shapes the transformer expects.
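For reference, the reshaping itself is just a flatten plus permute; a tiny standalone illustration with dummy tensors (not repo code):

import torch

src = torch.randn(2, 256, 36, 24)                       # [B, C, H, W]
src = src.flatten(2).permute(2, 0, 1)                   # -> [H*W, B, C] = [864, 2, 256]

query_embed = torch.randn(3300, 256)                    # [num_queries*groups, C]
query_embed = query_embed.unsqueeze(1).repeat(1, 2, 1)  # -> [3300, B, C]

mask = torch.zeros(2, 36, 24, dtype=torch.bool)
mask = mask.flatten(1)                                  # -> [2, 864]
print(src.shape, query_embed.shape, mask.shape)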

Encoder

The encoder is the same as in DETR. Its inputs are:

src: backbone features [864, 2, 256]

mask: None

src_key_padding_mask: records where the un-padded image sits inside the padded image [2, 864]

pos: positional encoding of the mask via PositionEmbeddingSine [864, 2, 256]

class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        # src: backbone features [864,2,256]
        # mask: None
        # src_key_padding_mask: records where the un-padded image sits inside the padded image [2,864]
        # pos: positional encoding of the mask via PositionEmbeddingSine [864,2,256]
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output

class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        # src: backbone features [864,2,256]
        # src_mask: None
        # src_key_padding_mask: records where the un-padded image sits inside the padded image [2,864]
        # pos: positional encoding of the mask via PositionEmbeddingSine [864,2,256]
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)

src is added to the positional encoding and fed through self-attention; the residual output goes through LayerNorm and then the feed-forward MLP.

Code diagram:

(Figure: encoder code diagram)

Decoder

The decoder inputs are:

tgt: initialized to all zeros [3300, 2, 256]

memory: encoder output [864, 2, 256]

tgt_mask: None

memory_mask: None

tgt_key_padding_mask: None

memory_key_padding_mask: records where the un-padded image sits inside the padded image [2, 864]

pos: positional encoding of the mask via PositionEmbeddingSine [864, 2, 256]

query_pos: generated by nn.Embedding [3300, 2, 256]

class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, d_model=256):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.ref_point_head = MLP(d_model, d_model, 2, 2)
        for layer_id in range(num_layers - 1):
            self.layers[layer_id + 1].ca_qpos_proj = None

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # tgt: initialized to all zeros [3300,2,256]
        # memory: encoder output [864,2,256]
        # tgt_mask: None
        # memory_mask: None
        # tgt_key_padding_mask: None
        # memory_key_padding_mask: records where the un-padded image sits inside the padded image [2,864]
        # pos: positional encoding of the mask via PositionEmbeddingSine [864,2,256]
        # query_pos: generated by nn.Embedding [3300,2,256]

        output = tgt

        intermediate = []
        reference_points_before_sigmoid = self.ref_point_head(query_pos)    # [num_queries, batch_size, 2] ref_point_head->MLP Linear(256,256) Linear(256,2) [3300,2,256]->[3300,2,2]
        reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1) # [3300,2,2]->[2,3300,2]

        for layer_id, layer in enumerate(self.layers):
            obj_center = reference_points[..., :2].transpose(0, 1)  # [num_queries, batch_size, 2] # [2,3300,2]->[3300,2,2]

            # For the first decoder layer, we do not apply transformation over p_s
            if layer_id == 0:
                pos_transformation = 1
            else:
                pos_transformation = self.query_scale(output) # query_scale ->MLP Linear(256,256) Linear(256,256) [3300,2,256]->[3300,2,256]

            # get sine embedding for the query vector
            query_sine_embed = gen_sineembed_for_position(obj_center)  # sine embedding of the reference points (centers) produced by the MLP [3300,2,2]->[3300,2,256]
            # apply transformation
            query_sine_embed = query_sine_embed * pos_transformation
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                           is_first=(layer_id == 0))  # output:[3300,2,256]
            if self.return_intermediate:
                intermediate.append(self.norm(output)) # store the outputs of all intermediate layers

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return [torch.stack(intermediate).transpose(1, 2), reference_points]

        return output.unsqueeze(0)

 The reference_points are not iteratively refined here; combining this with the DAB-DETR update strategy should give a further gain.

For each decoder layer:

class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, group_detr=1):
        super().__init__()
        # Decoder Self-Attention
        self.sa_qcontent_proj = nn.Linear(d_model, d_model)
        self.sa_qpos_proj = nn.Linear(d_model, d_model)
        self.sa_kcontent_proj = nn.Linear(d_model, d_model)
        self.sa_kpos_proj = nn.Linear(d_model, d_model)
        self.sa_v_proj = nn.Linear(d_model, d_model)
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model)

        # Decoder Cross-Attention
        self.ca_qcontent_proj = nn.Linear(d_model, d_model)
        self.ca_qpos_proj = nn.Linear(d_model, d_model)
        self.ca_kcontent_proj = nn.Linear(d_model, d_model)
        self.ca_kpos_proj = nn.Linear(d_model, d_model)
        self.ca_v_proj = nn.Linear(d_model, d_model)
        self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
        self.cross_attn = MultiheadAttention(d_model*2, nhead, dropout=dropout, vdim=d_model)

        self.nhead = nhead

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self.group_detr = group_detr

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None,
                     query_sine_embed = None,
                     is_first = False):
        # tgt: initialized to all zeros, later updated with the previous decoder layer's output [3300,2,256]
        # memory: encoder output [864,2,256]
        # tgt_mask: None
        # memory_mask: None
        # tgt_key_padding_mask: None
        # memory_key_padding_mask: records where the un-padded image sits inside the padded image [2,864]
        # pos: positional encoding of the mask via PositionEmbeddingSine [864,2,256]
        # query_pos: generated by nn.Embedding [3300,2,256]
        # query_sine_embed: sine embedding of the reference points (centers) obtained from query_pos via an MLP [3300,2,256]

        # ========== Begin of Self-Attention =============
        # Apply projections here
        # shape: num_queries x batch_size x 256
        q_content = self.sa_qcontent_proj(tgt)  # [3300,2,256]->[3300,2,256] # target is the input of the first decoder layer. zero by default.
        q_pos = self.sa_qpos_proj(query_pos)  # [3300,2,256]->[3300,2,256]
        k_content = self.sa_kcontent_proj(tgt) # [3300,2,256]->[3300,2,256]
        k_pos = self.sa_kpos_proj(query_pos) # [3300,2,256]->[3300,2,256]
        v = self.sa_v_proj(tgt) # [3300,2,256]->[3300,2,256]
        # every xxx_proj is a Linear(256,256)
        num_queries, bs, n_model = q_content.shape
        hw, _, _ = k_content.shape

        q = q_content + q_pos  # [3300,2,256]
        k = k_content + k_pos  # [3300,2,256]

        if self.training:
            q = torch.cat(q.split(num_queries // self.group_detr, dim=0), dim=1)  # [3300,2,256]->[300,22,256]
            k = torch.cat(k.split(num_queries // self.group_detr, dim=0), dim=1)  # [3300,2,256]->[300,22,256]
            v = torch.cat(v.split(num_queries // self.group_detr, dim=0), dim=1)  # [3300,2,256]->[300,22,256]
            

        tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        
        if self.training:
            tgt2 = torch.cat(tgt2.split(bs, dim=1), dim=0)  # [300,22,256]->[3300,2,256]
        # ========== End of Self-Attention =============

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ========== Begin of Cross-Attention =============
        # Apply projections here
        # shape: num_queries x batch_size x 256
        q_content = self.ca_qcontent_proj(tgt)  # [3300,2,256]->[3300,2,256]
        k_content = self.ca_kcontent_proj(memory)  # [864,2,256]->[864,2,256]
        v = self.ca_v_proj(memory)  # [864,2,256]->[864,2,256]

        num_queries, bs, n_model = q_content.shape
        hw, _, _ = k_content.shape

        k_pos = self.ca_kpos_proj(pos)  # Linear(256,256) applied to the positional encoding  [864,2,256]->[864,2,256]

        # For the first decoder layer, we concatenate the positional embedding predicted from 
        # the object query (the positional embedding) into the original query (key) in DETR.
        if is_first:
            q_pos = self.ca_qpos_proj(query_pos) # ca_qpos_proj->Linear(256,256)  # [3300,2,256]->[3300,2,256]
            q = q_content + q_pos  # [3300,2,256]
            k = k_content + k_pos  # [3300,2,256]
        else:
            q = q_content
            k = k_content

        q = q.view(num_queries, bs, self.nhead, n_model//self.nhead) # [3300,2,256]->[3300,2,8,32]
        query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed) # ca_qpos_sine_proj->Linear(256,256) [3300,2,256]->[3300,2,256]
        query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead) # [3300,2,256]->[3300,2,8,32]
        q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2) # [3300,2,256*2]
        k = k.view(hw, bs, self.nhead, n_model//self.nhead) # [864,2,256]->[864,2,8,32]
        k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead) # [864,2,256]->[864,2,8,32]
        k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2) # [864,2,256*2]
        # q:[3300,2,512] k:[864,2,512] v:[864,2,256] -> tgt2:[3300,2,256]
        tgt2 = self.cross_attn(query=q,
                                   key=k,
                                   value=v, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]               
        # ========== End of Cross-Attention =============

        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,
                query_sine_embed = None,
                is_first = False):
        if self.normalize_before:
            raise NotImplementedError
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, query_sine_embed, is_first)

The key Group DETR change is here:

if self.training:
    q = torch.cat(q.split(num_queries // self.group_detr, dim=0), dim=1)  # [3300,2,256]->[300,22,256]
    k = torch.cat(k.split(num_queries // self.group_detr, dim=0), dim=1)  # [3300,2,256]->[300,22,256]
    v = torch.cat(v.split(num_queries // self.group_detr, dim=0), dim=1)  # [3300,2,256]->[300,22,256]

During training the queries are split into 11 groups along the query dimension and concatenated along the batch dimension, so self-attention is computed independently within each group; the result is then split back and reassembled into the original layout.
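A minimal, self-contained sketch of this reshaping trick (the numbers mirror the shapes above; this is not the repo's code). Folding the group dimension into the batch dimension means nn.MultiheadAttention only ever attends within a single group, because attention never mixes batch elements:

import torch
import torch.nn as nn

num_queries, groups, bs, d = 300, 11, 2, 256
attn = nn.MultiheadAttention(d, num_heads=8)

q = k = v = torch.randn(num_queries * groups, bs, d)   # [3300, 2, 256]

# Fold groups into the batch dimension: [3300, 2, 256] -> [300, 22, 256]
qg = torch.cat(q.split(num_queries, dim=0), dim=1)
kg = torch.cat(k.split(num_queries, dim=0), dim=1)
vg = torch.cat(v.split(num_queries, dim=0), dim=1)

out, _ = attn(qg, kg, vg)                              # attention is group-local

# Unfold back to the original layout: [300, 22, 256] -> [3300, 2, 256]
out = torch.cat(out.split(bs, dim=1), dim=0)
print(out.shape)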

Pseudocode:

(Figure: pseudocode for group-wise self-attention)

Code diagram:

(Figure: decoder code diagram)

The transformer as a whole:

(Figure: overall transformer diagram)

Loss

1. Hungarian matching

        Hungarian matching builds a cost matrix from the network's final predictions (class and bbox) as a weighted sum of the classification cost, the bbox L1 cost, and the bbox GIoU cost, and then computes a one-to-one matching between predictions and ground truths.
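As a toy illustration of what the matcher returns (the numbers are made up; the real cost matrix below combines class, L1, and GIoU terms), scipy's linear_sum_assignment picks the prediction-target pairs with minimal total cost:

import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy cost matrix: 4 predictions (rows) x 2 ground-truth boxes (columns).
cost = np.array([[0.9, 0.1],
                 [0.4, 0.8],
                 [0.2, 0.7],
                 [0.6, 0.3]])

row_ind, col_ind = linear_sum_assignment(cost)
print(row_ind, col_ind)   # [0 2] are the matched query indices, [1 0] the matched GT indices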

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets, group_detr=1):
        """ Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
            group_detr: Number of groups used for matching.
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]  # num_queries: 3300 bs: 2

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes] # [2,3300,91]->[6600,91]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4] # [2,3300,4]->[6600,4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])  # all GT labels in the batch
        tgt_bbox = torch.cat([v["boxes"] for v in targets])  # all GT boxes in the batch

        # Compute the classification cost.
        alpha = 0.25
        gamma = 2.0
        neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]  # [6600,len(tgt_ids)]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)  # out_bbox:[6600,4] tgt_bbox:[len(tgt_bbox),4] -> [6600,len(tgt_ids)]

        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))  # [6600,len(tgt_ids)]

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu() # [6600,len(tgt_ids)]->[2,3300,len(tgt_ids)]

        sizes = [len(v["boxes"]) for v in targets] # number of GT boxes in each image of the batch
        indices = []
        g_num_queries = num_queries // group_detr  # 300
        C_list = C.split(g_num_queries, dim=1)  # tuple([2,300,len(tgt_ids)] * 11)
        for g_i in range(group_detr):  # iterate over the 11 query groups
            C_g = C_list[g_i]
            # indices_g is a list of tuples, one per image in the batch; the two elements of each
            # tuple are the row and column indices of the optimal assignment returned by the
            # Hungarian algorithm (linear_sum_assignment), i.e. the matched prediction and target indices.
            indices_g = [linear_sum_assignment(c[i]) for i, c in enumerate(C_g.split(sizes, -1))]
            if g_i == 0:
                indices = indices_g
            else:
                indices = [
                    (np.concatenate([indice1[0], indice2[0] + g_num_queries * g_i]), np.concatenate([indice1[1], indice2[1]]))
                    for indice1, indice2 in zip(indices, indices_g)
                ]  # for every group except the first, the row indices are offset by g_num_queries * g_i
        # the final indices are converted to torch tensors
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

When computing the optimal assignment, the 11 query groups are matched separately; after matching, the row (query) indices of each group are offset by group_index × 300, the column (target) indices stay unchanged, and the results are concatenated.
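A tiny numerical illustration of this merging step (made-up indices, num_queries = 300 per group): each group's matched query indices live in [0, 300), so they are shifted into their group's slice of the 3300 queries before concatenation.

import numpy as np

g_num_queries = 300
# Matched (query_idx, gt_idx) pairs per group, for one image with 2 GT boxes.
per_group = [
    (np.array([17, 203]), np.array([0, 1])),   # group 0
    (np.array([45,  99]), np.array([1, 0])),   # group 1
    (np.array([12, 280]), np.array([0, 1])),   # group 2
]

indices = per_group[0]
for g_i, (rows, cols) in enumerate(per_group[1:], start=1):
    indices = (np.concatenate([indices[0], rows + g_num_queries * g_i]),
               np.concatenate([indices[1], cols]))

print(indices[0])  # [ 17 203 345 399 612 880] -- every GT ends up matched to 3 queries
print(indices[1])  # [  0   1   1   0   0   1]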

2. Classification and bbox losses

class SetCriterion(nn.Module):
    """ This class computes the loss for Conditional DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, focal_alpha, losses, group_detr=1):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            focal_alpha: alpha in Focal Loss
            group_detr: Number of groups to speed detr training. Default is 1.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha
        self.group_detr = group_detr
        

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (Binary focal loss)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        # idx=(batch_idx, src_idx)
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) # reorder targets["labels"] according to the column indices in indices
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)  # build a [2,3300] tensor filled with 91 (the no-object class)
        target_classes[idx] = target_classes_o # insert target_classes_o into the [2,3300] tensor of 91s at positions idx
        # one-hot encoding
        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1],
                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) # [2,3300,92]
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:,:,:-1]
        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx] # gather the predicted boxes from outputs['pred_boxes'] selected by idx
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) # reorder targets['boxes'] according to the column indices in indices

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
           targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # upsample predictions to the target size
        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices. indices is a list with one tuple per image (two for batch=2);
        # each tuple holds the row/column index arrays from the Hungarian matching. The length of either
        # array equals the number of matched objects in that image, and batch_idx is built from that length.
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices]) # src_idx holds the row (query) indices from the Hungarian matching
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'masks': self.loss_masks
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        group_detr = self.group_detr if self.training else 1
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets, group_detr=group_detr)

        # Compute the average number of target boxes accross all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets) * group_detr
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:  # compute the losses (box, label) for the intermediate decoder layers, i.e. all layers except the last
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets, group_detr=group_detr)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
