MaskFormer源码解析

最近MaskFormer的升级版mask2former的源码也开源了,作者不仅论文写的扎实代码也写的整洁工整,不论是论文还是代码都值得我们好好学习。而且作者的实验也做的异常扎实,啥也不说了 瑞士拜!
mask2former相对与前辈maskformer主要是在Decorder阶段做了改动,两个模型在整体架构和损失函数计算上基本保持一致,所以本篇先写MaskFormer,后续再写Mask2Former。其实如果MaskFormer的代码搞懂了,Mask2Former也就基本没问题啦。本篇是MaskFormer的源码解析,论文建议直接看原文,作者写的非常好墙裂推介阅读原文。

一、概述

整个代码结构基于detectron2框架,所以会有很多注册的指令和from_config()函数,这两个都不影响代码的逻辑,在看源码的时候不必纠结,把所有的from_config()都看成从配置文件读取相关变量的值即可,具体的值可以在config/xx/xx.yaml文件种找到;注册指令是为了detectron2可以检测到,看源码的时候可以直接忽略这条指令。有了两个前提,我们就可以来看作者是如何把网络结构图用代码一步步实现的了。

MaskFormer源码解析_第1张图片

MaskForer结构图

二、整体结构

所有核心代码都在maskforer/mask_former/文件夹下。
总共四个核心类分别是:

  • mask_former_model.py中的MaskFormer
  • headers/mask_former_head.py中的MaskFormerHead
  • headers/pixel_decoder.py中的TransformerEncoderPixelDecoder
  • transformer/transformer_predictor.py中的TransformerPredictor
    其中,MaskFormer类是模型的入口,整体结构都在这个类中。在这个类的构造函数中可以看到模型是由backbone、sem_seg_head、criterion构成,其中backbone对应图1中的backbone,sem_seg_head包含了除backbone外的所有啦,criterion是损失函数的类,对应这图1中的两个Loss。

2.1 backbone

backbone没什么好说的,作者使用了resnet和swin两种,backbone不论用的哪种网络只要知道输出的是一个字典,关键字是:‘res2’、'res3’这种就行了,分别对应着backbone提取到的特征。可以参看swin.py的forward()函数,注意看返回值outs,有一句:outs[“res{}”.format(i +2)]= out

def forward(self, x):
    """Forward function."""
    x = self.patch_embed(x)

    Wh, Ww = x.size(2), x.size(3)
    if self.ape:
        # interpolate the position embedding to the corresponding size
        absolute_pos_embed = F.interpolate(
            self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
        )
        x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
    else:
        x = x.flatten(2).transpose(1, 2)
    x = self.pos_drop(x)

    outs = {}
    for i in range(self.num_layers):
        layer = self.layers[i]
        x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

        if i in self.out_indices:
            norm_layer = getattr(self, f"norm{i}")
            x_out = norm_layer(x_out)

            out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
            outs["res{}".format(i + 2)] = out

    return outs

2.2 MaskFormerHead

根据上面的分析,可以看到整个模型最核心的部分都在MaskFormerHead中,这里提供了图1中的Decoder功能,直接看MaskFormerHead的forward()函数

    def forward(self, features):
        return self.layers(features)

    def layers(self, features):
        mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
        if self.transformer_in_feature == "transformer_encoder":
            assert (
                transformer_encoder_features is not None
            ), "Please use the TransformerEncoderPixelDecoder."
            predictions = self.predictor(transformer_encoder_features, mask_features)
        else:
            predictions = self.predictor(features[self.transformer_in_feature], mask_features)
        return predictions

发现很简单,只是返回了layers()。那就仔细看layers(),里面是predictor根据pixel_decoder的结果直接输出了,辣么核心的类就来到了pixel_decoder和predictor。

2.3 pixel_decoder

    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
            },
            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
            "transformer_predictor": TransformerPredictor(
                cfg,
                cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
                if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
                else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
                mask_classification=True,
            ),
        }

根据from_config()可以看到pixel_decoder是根据配置文件中的cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME指定的(不过我在配置文件中没找到这一项,也可能是我找的地方不对),不过根据代码结构可以定位到mask_former/heads/pixel_decoder.py文件,TransformerEncoderPixelDecoder整个类就是我们要的pixel_decoder,在TransformerEncoderPixelDecoder的forward_features()函数中返回了两个值self.mask_features(y), transformer_encoder_features正好对应着layers()中的
mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)

2.4 predictor

接着看MaskFormerHead的from_config()函数,可以看到有一项是

"transformer_predictor": TransformerPredictor()

OK,我们需要的predictor就是modeling/transformer/transformer_predictor.py中的TransformerPredictor这个类了。

三、模型细节

根据上面的分析,已经把核心代码定位到TransformerEncoderPixelDecoder和TransformerPredictor这两个类了,接下来就仔细研究这两个类。

3.1 TransformerEncoderPixelDecoder

根据类定义可以看到TransformerEncoderPixelDecoder继承自BasePixelDecoder,根据forward_features()的返回值定位到两个关键的函数self.mask_features和self.transformer

    def forward_features(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                transformer = self.input_proj(x)
                pos = self.pe_layer(x)
                transformer = self.transformer(transformer, None, pos)
                y = output_conv(transformer)
                # save intermediate feature as input to Transformer decoder
                transformer_encoder_features = transformer
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), transformer_encoder_features

self.mask_features()和计算y的lateral_conv都来源于基类BasePixelDecoder,那么就先看BasePixelDecoder。

3.2 BasePixelDecoder

    def forward_features(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                y = output_conv(x)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), None

这里仔细看一下就会发现其实就是一个常规的FPN结构,把backbone提取到的特征通过
y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode=“nearest”)
逐步add到一起,如果是最后一层则使用output_conv()输出即可。BasePixelDecoder类的构造函数中基本都是一定算子的定义,在梳理代码逻辑的时候可以不必纠结这些算子的参数,但是有个地方需要特别注意,就是源码的77行,当idx时features的最后一个特征序号的时候既’res5’这个特征,lateral_convs.append(None),因为根据图1,'res5’这个特征是要经过一个TransformerEncoder的(请对照源码modeling/heads/pixel_decoder.py)
所以,TransformerEncoderPixelDecoder返回的第一个参数就是FPN的结果,只是我们一般在做语义分割的时候就直接用这个结果去计算损失和预测了,但是这里这个结果仅仅作为mask的feature,所以这里mask_features的维度也不是我们通常说的class_num,而是一个超参mask_dim,具体值是多少可以在config文件中cfg.MODEL.SEM_SEG_HEAD.MASK_DIM看到,源码中定义的是256。

3.3 生成transformer_encoder_features

    def forward_features(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                transformer = self.input_proj(x)
                pos = self.pe_layer(x)
                transformer = self.transformer(transformer, None, pos)
                y = output_conv(transformer)
                # save intermediate feature as input to Transformer decoder
                transformer_encoder_features = transformer
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), transformer_encoder_features

还回到TransformerEncoderPixelDecoder的forward_features()函数,一步步的跟踪可以找到transformer_encoder_features是由self.transformer()生成的,这里接收两个参数transformer和pos。transformer由self.input_proj(x)提供,查看self.input_proj的定义,其实就是一个卷积算子,为了将res5的通道数变换到conv_dim;pos这个参数如果熟悉Transformer的结构,肯定不会陌生,就是一个位置Embedding。所以现在的焦点就到了self.transformer这个算子,在构造函数中有一句:

        self.transformer = TransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            normalize_before=transformer_pre_norm,
        )

其实这个self.transformer就是使用 Transformer做了一个编码而已。在TransformerEncoderOnly类的构造函数中使用TransformerEncoderLayer构造了TransformerEncoder,核心的工作都在TransformerEncoderLayer中,定位到TransformerEncoderLayer的forward()函数中,不要被forward_post和forward_pre迷惑,无非就是一个最后noml一个先noml而已,随便看一个函数,其实就是一个标准的self-attention。现在一直到这里如果对标准的Transmorfer有了解的话应该都没什么问题。不过,有个地方需要特别注意,在TransformerEncoderOnly的forward()中有个维度调整的操作!

    def forward(self, src, mask, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        if mask is not None:
            mask = mask.flatten(1)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        return memory.permute(1, 2, 0).view(bs, c, h, w)

首先得到输入的形状[b,c, h, w],然后将输入的feature将[h,w]维度拉平变成了[b, c, hw],然后做了维度调整,所以src输入到Transformer的形状是[hw, b, c]。辣么问题来了,为什么要这样调整?
我们先看一下TransformerEncoderLayer中是如何实现MultiheadAttention的,定位到modeling/transformer/transformer.py中的TransformerEncoderLayer类,可以看到是直接调用了torch的nn.MultiheadAttention,继续跟踪到nn.MultiheadAttention的内部,在MultiheadAttention.forward()中可以看到如下的注释,不用我翻译了吧。

    Shapes for inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.

那么就很清楚了,nn.MultiheadAttention()要求的输入数据就是需要将b放在中间这个维度上,第一个维度L序列的长度,这里就是我们的hw,即我们总共有hw个像素;最后一个维度E就是我们一个像素在所有通道上的值,可以理解成一个像素的embedding,这样我们计算的就是每个像素间的attention啦。

3.4 predictor

现在再回头看看MaskFormerHead的forward()函数,发现就还剩最后的predictor这个模型就输出了。根据2.4节的分析predictor就是TransformerPredictor这个类。如同上面的分析思路,我们直接从这个类的forward()看起。

    def forward(self, x, mask_features):
        pos = self.pe_layer(x)

        src = x
        mask = None
        hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)

        if self.mask_classification:
            outputs_class = self.class_embed(hs)
            out = {"pred_logits": outputs_class[-1]}
        else:
            out = {}

        if self.aux_loss:
            # [l, bs, queries, embed]
            mask_embed = self.mask_embed(hs)
            outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks[-1]
            out["aux_outputs"] = self._set_aux_loss(
                outputs_class if self.mask_classification else None, outputs_seg_masks
            )
        else:
            # FIXME h_boxes takes the last one computed, keep this in mind
            # [bs, queries, embed]
            mask_embed = self.mask_embed(hs[-1])
            outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks
        return out

可以肯定的是这个类提供了MaskFormer模型的输出,如果这里想搞明白为什么这样输出,就需要先看下论文了。我简单的介绍一下,精华还请参考原论文。
我们在做语义分割的时候,通常做法都是给每张特征图上的像素一个类,所以输出的格式是[b, num_class, h, w],这样输出在语义分割任务中没什么问题,在实例分割任务中如果一张图片中两个同类别的实例有重叠部分就很难分开了。换个思路,输出的特征图变成[b, q, h, w]格式,这里的q是查询的个数,每张特征图都是一个二值化的mask,具体这个mask的类别由另外一个预测类别的分支来确定,预测类别分支的输出格式就是[b, q, num_class]。总结一下,把之前每个像素点一个类别的输出修改成两个分支,一个分支用于预测类别,输出格式是[b, q, num_class];一个分支用于预测一张图像上各个实例的mask,输出格式是[b, q, h, w],其中q代表查询的个数,或者理解成一张图像上最多预测的实例个数。
ok,有了上述的铺垫就可以开开心心的看predictor的输出啦。首先需要确定forward函数中的入参(x, mask_features)是什么,回到2.2节,在调用pixel_decoder之后将得到的返回值输入给了predictor
predictions = self.predictor(transformer_encoder_features, mask_features)
所以在predictor的forward入参中x=transformer_encoder_features,mask_features=mask_features。接下去有一个transformer的函数,在构造函数中有transformer的定义,这里的transformer就是一个标准的Transformer结构了,包含了编码器和解码器,输出的两个值hs和memory分别是Transformer的输出和Transformer中Encoder的输出。(如果对这里不熟悉的,建议先去熟悉Transformer)。这里唯一需要注意的地方在Transformer的Decoder中会把每一层的输出都记录下来,供后续aux_loss计算使用。代码中的self.return_intermediate值由配置文件中的DEEP_SUPERVISION控制,默认为TRUE。

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)

transformer()输出的两个值hs,memory搞定了,接着往下看就看见上述说的类别预测分支了

        hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)

        if self.mask_classification:
            outputs_class = self.class_embed(hs)
            out = {"pred_logits": outputs_class[-1]}
        else:
            out = {}

self.class_embed()函数就是类别预测分支啦,self.class_embed = nn.Linear(hidden_dim, num_classes + 1)所以输出的就是[b, q, num_classes + 1]。
接下去就是计算mask的输出,发现有个if self.aux_loss,这里就是说是否利用Transformer.Decoder的各层输出来计算损失还是只使用Transformer的最终输出来计算损失,论文中提到使用aux_loss能涨点,所以还是要利用各层的输出啦。

        if self.aux_loss:
            # [l, bs, queries, embed]
            mask_embed = self.mask_embed(hs)
            outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks[-1]
            out["aux_outputs"] = self._set_aux_loss(
                outputs_class if self.mask_classification else None, outputs_seg_masks
            )
        else:
            # FIXME h_boxes takes the last one computed, keep this in mind
            # [bs, queries, embed]
            mask_embed = self.mask_embed(hs[-1])
            outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
            out["pred_masks"] = outputs_seg_masks

mask_embed一个MLP就是为了将Transformer的输出映射成和mask_features的维度一致,然后和mask_features做一次矩阵运算得到最终的mask分支输出,可以看到mask的形状是[b, q, h, w],源码中的l是Transformer.Decoder的层数,在计算aux_loss时有用,可以看到最终的输出"pred_masks"只要了最后一层 out[“pred_masks”]= outputs_seg_masks[-1]那maskformer模型的结构和输出就全部搞定了。

3.5 损失函数计算

本篇论文的精华都在损失函数的地方,作者也说了只是提供了一种统一语义分割和实例分割的思路,作为统一分割方式的一种进行探索并不是说只有本文的这种方式,而如何统一就在于损失函数的计算和网络输出的后续处理。
损失函数的计算在mask_former/modeling/criterion.py文件中。根据3.4节的描述,我们的输出其实是[b, q, h, w]格式,q个[h, w]特征图中并不一定都有输出,那怎么计算呢?额…… 我们先补充点背景知识吧。
假设有N个任务,同时有N个人可以来完成这些任务,但是每个人完成每个任务的耗时(cost)是不同的,那么如何找到一个合理的匹配使得最终每个人分到一个任务且总的耗时(cost)最小?这里就不细说算法细节了,知道匈牙利算法可以用来解决这类问题就行,具体算法细节请参考二部图的匹配算法。
有了上面的背景知识,我们就好办了。现在有q个特征图和i个目标(一张图像上有i个目标,每张图的i不固定),我们需要在q个特征图中找到对应的i个目标,使得总的cost最小,所以我们就在源码中看到了mask_former/modeling/matcher.py文件,源码中实际的匈牙利算法的实现是直接调用了scipy库中的linear_sum_assignment(),阅读源码的时候需要搞清楚每一个变量的形状,然后结合代码的注释基本问题不大,在调用linear_sum_assignment()之前都是为了生成花费矩阵的,这里有一个地方需要特别注意,tgt_ids这个变量的值里面存的是每一张图像中出现的类别序号,比如图像中有人、车、飞机,对应的class_id是1,3,7,则tgt_ids=[1,3,7],剩下的就都比较容易理解啦。

   def memory_efficient_forward(self, outputs, targets):
        """More memory-friendly matching"""
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # Work out the mask padding size
        masks = [v["masks"] for v in targets]
        h_max = max([m.shape[1] for m in masks])
        w_max = max([m.shape[2] for m in masks])

        indices = []

        # Iterate through batch size
        for b in range(bs):

            out_prob = outputs["pred_logits"][b].softmax(-1)  # [num_queries, num_classes]
            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]

            tgt_ids = targets[b]["labels"]  # [1,2,3, ……]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["masks"].to(out_mask)

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be ommitted.
            cost_class = -out_prob[:, tgt_ids] # [num_queries, num_total_targets]

            # Downsample gt masks to save memory
            tgt_mask = F.interpolate(tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest")

            # Flatten spatial dimension
            out_mask = out_mask.flatten(1)  # [num_queries, H*W]
            tgt_mask = tgt_mask[:, 0].flatten(1)  # [num_total_targets, H*W]

            # Compute the focal loss between masks
            cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask) # [num_queries, num_total_targets]

            # Compute the dice loss betwen masks
            cost_dice = batch_dice_loss(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()  # [num_queries, num_total_targets]

            indices.append(linear_sum_assignment(C))
        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

最终送到linear_sum_assignment()函数的C的形状是[num_queries, num_total_targets],函数返回两个序列,第一个序列对应的是行的序号,第二个序列是对应的列序号。还剩最后一个问题,实际中num_queries不会刚好等于num_total_targets,即输入的不是一个方阵,那么输出是什么?当不是方阵的时候,输出的第一个和第二个值的长度等于min(num_queries, num_total_targets),作者在match的forward函数有说明,已经很清楚了。

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """

经过上面的匈牙利匹配算法,我们在输出的query中匹配到了最相似的类别,接下来就是使用得到的query计算损失。回到criterion,分别计算了类别损失和mask的损失,其中mask loss包含了dice loss和每个像素值的focal_loss,只是这里计算focal_loss是需要对每个像素做sigmoid操作的,而不是我们平时做的softmax,因为我们的类别信息是由类别分支提供的,图像上的像素点仅仅提供是否前景还是背景信息,所以用sigmoid函数,类别损失就是正常的CE loss啦。整个损失函数的计算基本上脉络就清楚了,这里有个地方可能容易被迷惑,源码中在计算损失之前有个计算src_idx和tgt_idx的操作,看函数的实现可能一时半会搞不清楚到底在干什么,其实很简单输出的就是类似batch_idx=[1,1,2,3] src_idx = [2, 2,3,4]这种形式。_get_src_permutation_idx返回的是match函数第一个返回值对应的batch的序号,同理_get_tgt_permutation_idx返回的是match函数第二个返回值对应的batch的序号。

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

那整个模型就搞定啦。还剩最后一个在推理阶段怎么输出,如果把上述的分析都看明白后,基本上也就明白怎么弄了,我就不写啦。模型的前向运算输出在mask_former/mask_former_model.py中,具体实现参见self.semantic_inference()和self.panoptic_inference()两个函数。应该问题不大。
最后,还是墙裂建议老老实实的把原论文看一遍,顺带把Mask2Former的文章一起看,一定会有收获的。

你可能感兴趣的:(深度学习,python,语义分割,MaskFormer)