Patcher: Patch Transformers with Mixture of Experts for Precise Medical Image Segmentation

Contextual Transformer Networks for Visual Recognition

  • 1. 摘要
  • 2. 目的
  • 3. 网络设计
    • 3.1 Overall Architecture
    • 3.2 Patcher Block
    • 3.3 Patcher Encoder
    • 3.4 Mixture of Experts Decoder

代码地址

1. 摘要

We present a new encoder-decoder Vision Transformer architecture, Patcher, for medical image segmentation. Unlike standard Vision Transformers, it employs Patcher blocks that segment an image into large patches, each of which is further divided into small patches. Transformers are applied to the small patches within a large patch, which constrains the receptive field of each pixel. We intentionally make the large patches overlap to enhance intra-patch communication. The encoder employs a cascade of Patcher blocks with increasing receptive fields to extract features from local to global levels. This design allows Patcher to benefit from both the coarse-to-fine feature extraction common in CNNs and the superior spatial relationship modeling of Transformers. We also propose a new mixture-of-experts (MoE) based decoder, which treats the feature maps from the encoder as experts and selects a suitable set of expert features to predict the label for each pixel. The use of MoE enables better specializations of the expert features and reduces interference between them during inference. Extensive experiments demonstrate that Patcher outperforms state-of-the-art Transformer- and CNN-based approaches significantly on stroke lesion segmentation and polyp segmentation. Code for Patcher is released to facilitate related research.

我们提出了一种新的用于医学图像分割的编码器-解码器 Vision Transformer 架构——Patcher。与标准的 Vision Transformer 不同,它采用了 Patcher blocks,将图像分割成大的 patches,每个 patch 又进一步细分为小的 patch。Transformer 应用于一个大的 patch 中的小的 patches,这限制了每个像素的 receptive field。我们有意让大 patches 重叠,以加强 intra-patch communication。该编码器采用了一个级联的 Patcher blocks 增加 receptive fields,以提取从局部到全局级别的特征。这种设计使得 Patcher 既可以从 cnn 中常见的从粗到细的特征提取中获益,又可以从 Transformer 的优越空间关系建模中获益。我们还提出了一种新的基于解码器的 mixture-of-experts (MoE),该解码器将来自编码器的特征图视为 experts,并选择合适的 experts 特征集来预测每个像素的标签。使用 MoE 可以更好地特殊化 experts 特征,减少推理过程中 experts 特征之间的干扰。大量的实验表明,Patcher 在脑卒中病变和息肉的分割上明显优于基于 Transformer 和 cnn 的最先进的方法。

2. 目的

到目前为止,之前的大多数工作主要使用 Transformer 来提取补丁级的特征,而不是精细的像素级特征。考虑到 Transformer 在建模 spatial relationships 方面的强大能力,我们相信有机会充分利用 Transformer 来提取细粒度的像素级特征,而无需将其委托给卷积层。

为此,我们提出了一种新的编码器-解码器 Vision Transformer 架构 Patcher,它使用 Transformer 来提取 global features 之外的 fine-grained local features。它的关键组成部分是 Patcher 块,它将图像分割成大的 patches(如 32 × 32),每个 patch 又被细分成小的 patches(如 2 × 2)。对每个大 patch 中的小 patches 应用 Transformer 来提取像素级特征。每个大的 patches 都限制了内部像素的 receptive fields,我们有意地让大的 patches 重叠,以增强 patches 内部的通信。该编码器使用一个级联的 Patcher 块增加 receptive fields,以输出从局部到全局层次提取的特征图序列。此外,我们观察到图像分割模型主要对一些像素(如边缘像素)要求局部特征,而对其他像素(如全局形状内的像素)则更多地依赖全局特征。这促使我们进一步提出一种新的基于解码器的 mixture-of-experts (MoE)。它将来自编码器的特征图视为 experts,并学习一个门控网络来选择一组合适的 experts 特征来预测每个像素的标签。该模型可以学习更加 specialized and disentangledexperts 特征图,减少推理过程中 experts 特征图之间的干扰。

3. 网络设计

3.1 Overall Architecture

Patcher: Patch Transformers with Mixture of Experts for Precise Medical Image Segmentation_第1张图片
给定一个大小为 H×W×C 的输入图像,Patcher 首先使用编码器从输入图像中提取特征。该编码器包含一个基于 Transformer 的 Patcher 块级联,它产生一个特征图序列,捕捉从局部到全局层次的视觉特征,receptive fields 不断增加。然后,这些特征图被输入到一个采用 mixture-of-experts (MoE) 的解码器中,其中编码器的每个特征映射都充当一个 expert。解码器中的门控网络输出 expert feature mapsweight maps,并使用权值得到组合特征图。然后使用多层感知器(MLP)和上采样层处理组合的特征图,得到最终的分割输出。基于 MoE 的设计增加了不同级别特征的特例化,同时减少了它们之间的干扰。它允许网络通过选择一组合适的 expert features 对每个像素进行预测。例如,网络可能需要特定全局形状内像素的全局特征,而它可能需要局部特征来捕捉分割边界的细节。最后,我们使用标准二值交叉熵(BCE)来训练 Patcher。

3.2 Patcher Block

Patcher: Patch Transformers with Mixture of Experts for Precise Medical Image Segmentation_第2张图片

Patcher 编码器内部的关键组件是 Patcher 块,这是一个通用的基于 Transformer 的模块,可以从不同空间尺寸的块输入中提取视觉特征。我们首先将输入沿着空间维度 H × W H × W H×W 分为 N h × N w N_h × N_w Nh×Nw 的大块网格。每个大块的尺寸为 L × L L × L L×L,其中 N h = H / L , N w = W / L N_h = H/L,N_w=W/L Nh=H/LNw=W/L。我们进一步用邻近块的 P 个像素填充每个大块的边,形成尺寸为 ( L + 2 P ) × ( L + 2 P ) (L+2P)×(L+2P) (L+2P)×(L+2P) 的大块。我们将带有重叠信息的带块沿着 batch 维度进行堆叠(batch 尺寸为 B = B 0 N h N w B=B_0N_hN_w B=B0NhNw),这样不同大小的大块在后续的操作中就不会互相干扰。每个大块定义了一块感受野,这类似于 CNN 中的卷积核,不同之处在于大块中的所有像素点共享相同的感受野。因此,padded context 是十分重要的,因为她增大了像素的感受野,这对块的边界处的像素点尤为重要。接下来,我们将堆叠在一起的大块进一步划分为 M h × M w M_h×M_w Mh×Mw 的小块网络,每个小块的大小为 S × S S × S S×S。与前面的工作类似,我们将每个小块内的所有像素点线性嵌入到一个 token 中,所有小块的 token 形成一个序列。然后使用 N v N_v Nv Vision Transformer blocks 对序列进行处理,对 patch 之间的关系进行建模,提取有用的视觉特征。受 SegFormer 的启发,我们没有使用位置编码,而是在 Transformer 的 MLP 中混合卷积层来捕获空间关系。我们还使用了 SegFormer 中有效的自注意力机制来进一步降低计算成本。Transformer 块的输出特征图的空间尺寸为 M h × M w M_h × M_w Mh×Mw,batch 尺寸为 B。我们从特征图中选取中心的 K × K K × K K×K 区域,其中 K = L / S K = L / S K=L/S,这排除了 padded context,对应于实际大小的大 patch。然后我们根据大 patch 在原始图片中的位置将其进行重新组装,形成最终的输出( H S × W S \frac{H}{S} × \frac{W}{S} SH×SW)。

Patcher block 有两个重要的超参数:

  1. 大 patch 的大小 L L L,它定义了接受域,允许在局部或全局水平上进行特征提取;
  2. padded context 的大小 P P P,它控制有多少来自邻近大块的信息被使用;

3.3 Patcher Encoder

Patcher 编码器使用四个 Patcher 块级联生成四个空间维度递减和接收域递增的特征图。所有 block 的小 patch 尺寸 S S S 设置为2,即每个 block 之后空间尺寸减半。大 patch 尺寸 L L Lpadding context 尺寸 P P P 分别设置为32和8。通过对所有的 block 设置相同的 L L L P P P,我们允许更深的 Patcher block 有更大的接受域,从而逐渐将block的焦点从捕获局部特征转移到全局特征。这种类似于基于 CNN 的 U-Net 的编码器设计已经被证明是有效的。因此,Patcher 编码器结合了这两个优点:

  1. 受益于 Transformer 的空间关系建模能力;
  2. 受益于 CNN 的从粗到细的有效特征提取能力;

3.4 Mixture of Experts Decoder

解码器遵循 MoE 设计,它将编码器的四个特征图视为 experts。解码器首先使用逐像素的 MLP 处理每个特征图,然后将其上采样到第一个特征图的大小,即 H 2 × W 2 × D \frac{H}{2} × \frac{W}{2} × D 2H×2W×D,其中 D D D 为 MLP 之后的通道数。我们用 [ F 1 , F 2 , F 3 , F 4 ] [F_1, F_2, F_3, F_4] [F1,F2,F3,F4] 表示上采样的特征,这些特征也称为 experts features。然后,门控网络将 experts features 作为输入,生成 experts feature maps 的权值图 [ W 1 , W 2 , W 3 , W 4 ] [W_1, W_2, W_3, W_4] [W1,W2,W3,W4],每个权值图的大小为 H 2 × W 2 \frac{H}{2} × \frac{W}{2} 2H×2W。每个像素的权值图和为1,即 W 1 + W 2 + W 3 + W 4 = 1 W_1 +W_2 +W_3 +W_4 = 1 W1+W2+W3+W4=1。门控网络首先将所有的 experts feature maps 沿着通道连接起来,然后使用几个卷积层和一个最终的 softmax 层将连接起来的特征处理成权值图。然后我们使用权值图生成合并的特征图 O O O:
O = ∑ i = 1 4 W i ∗ F i O = \sum^4_{i=1}W_i*F_i O=i=14WiFi
合并后的特征图 O O O 通过另一个 MLP 预测 segmentation logits,然后上采样到原始图像大小。解码器的 MoE 设计使网络能够学习更专业的特征图,减少特征图之间的干扰。对于每个像素的预测,门控函数通过权衡全局和局部特征的重要性来选择一组合适的特征。

你可能感兴趣的:(transformer,人工智能)