标题:Multi-compound Transformer for Accurate Biomedical Image Segmentation
作者:Yuanfeng Ji,香港大学;Ping Luo,商汤科技
来源:MICCAI 2021
代码:https://github.com/JiYuanFeng/MCTrans MCTrans-master.zip
主题:Transformer;注意力机制;医学图像分割
由于卷积核函数的局部性,传统的基于 CNN 的分割模型缺乏对长期依赖关系的建模能力。
Due to the local property of the convolutional kernels, the traditional CNN-based segmentation models (e.g. FCN) lack the ability for modeling long-term dependencies.
为了解决上述问题,一些方法已经被用于进行强大的关系建模:
基于空间金字塔的方法
基于 UNet 的编码器-解码器网络
TransUNet
However, such a design is still sub-optimal for medical image segmentation for the following reasons. First, it only uses the self-attention mechanism for context modeling on a single scale but ignores the cross-scale dependency and consistency. The latter usually plays a critical role in the segmentation of lesions with dramatic size changes. Second, beyond the context modeling, how to learn the correlation between different semantic categories and how to ensure the feature consistency of the same category region are still not taken into account. But both of them have become critical for CNN-based segmentation scheme design.
本文提出了 Multi-Compound Transformer (MCTrans) 网络,它构建了跨尺度的上下文依赖关系,并挖掘了语义关系,用于准确的生物医学图像分割;
引入了 Transformer-Self-Attention (TSA) 模块,通过自注意力机制实现跨尺度像素级上下文建模,从而实现不同尺度上的更全面的特征增强。
开发了 Transformer-Cross-Attention (TCA) 模块,通过引入一种可学习的代理嵌入来自动学习不同语义类别的语义对应。然后进一步将这种代理嵌入通过交叉注意力机制与特征表示进行交互。通过为代理嵌入引入辅助损失,它可以有效地提高同一类别的特征相关性和不同类别之间的特征可辨别性。
We propose the Multi-Compound Transformer (MCTrans), which incorporates rich context modeling and semantic relationship mining for accurate biomedical image segmentation. MCTrans overcomes the limitations of conventional vision transformers by: (1) introducing the Transformer-Self-Attention (TSA) module to achieve cross-scale pixel-level contextual modeling via the self-attention mechanisms, leading to a more comprehensive feature enhancement for different scales. (2) developing the Transformer-Cross-Attention (TCA) to automatically learn the semantic correspondence of different semantic categories by introducing the proxy embedding. We further use such proxy embedding to interact with the feature representations via the cross-attention mechanism. By introducing auxiliary loss for the updated proxy embedding, we find that it could effectively improve feature correlations of the same category and the feature discriminability between different classes.
如上图所示,本文在经典的 UNet 编码器和解码器架构之间引入了 MCTransformer,它由 Transformer-Self-Attention (TSA) 模块和 Transformer-Cross-Attention (TCA) 模块组成。前者用于对多个特征之间的上下文信息进行编码,从而产生丰富且一致的像素级上下文;后者引入了可学习的嵌入(embedding),为了语义关系建模并进一步增强特征表示。
The former is introduced to encode the contextual information between the multiple features, yielding rich and consistent pixel-level context. And the latter introduces learnable embedding for semantic relationship modeling and further enhances feature representations.
给定图像 I ∈ R H × W I \in \mathbb{R}^{H \times W} I∈RH×W,采用深度 CNN 提取不同尺度的多级特征 { X i ∈ R H 2 i × W 2 i × C i } \left\{X_{i} \in \mathbb{R}^{\frac{H}{2^{i}} \times \frac{W}{2^{i}} \times C_{i}}\right\} {Xi∈R2iH×2iW×Ci}。对于层级 i i i,特征以 P × P P \times P P×P 的大小展开成各个块(patch),其中 P P P 在本文中设置为 1,即第 i i i 个特征图的每个位置都将被视为一个块,得到 L i = H W 2 2 ∗ i × P 2 L_{i}=\frac{H W}{2^{2 * i} \times P^{2}} Li=22∗i×P2HW 个块。接下来,将不同层级中的块输入具有相同输出特征维度 C C C 的线性投影头(linear projections heads)(即 1×1 卷积层),得到嵌入标记(token) T i ∈ R L i × C T_{i} \in \mathbb{R}^{L_{i} \times C} Ti∈RLi×C。然后,我们将 i i i = 2, 3, 4 层级的特征进行拼接,形成整体标记 T ∈ R L × C T \in \mathbb{R}^{L \times C} T∈RL×C,其中 L = ∑ i = 2 4 L i L=\sum_{i=2}^{4} L_{i} L=∑i=24Li。为了弥补丢失的位置信息,位置嵌入(positional embedding) E p o s ∈ R L × C E_{p o s} \in \mathbb{R}^{L \times C} Epos∈RL×C 被添加到标记中,以提供关于特征在序列中的相对或绝对位置的信息,这样标记可以表示为 T = T + E pos T=T+E_{\text {pos }} T=T+Epos 。接下来,我们将标记 T T T 输入 TSA 模块以进行多尺度上下文建模。将输出的增强后的标记进一步输入 TCA 模块并与代理嵌入(proxy embedding) E pro ∈ R M × C E_{\text {pro }} \in \mathbb{R}^{M \times C} Epro ∈RM×C 进行交互,其中 M M M 是数据集的类别数。最后,我们将编码后的标记折叠回金字塔特征图,并以自下而上的方式合并它们,以获得最终的特征图进行预测。
将一维嵌入标记 T T T 作为输入,TSA 模块用于学习多尺度特征之间的像素级上下文依赖关系。如上图所示,TSA 模块由 K s K_{s} Ks 层组成,每层由多头自注意力(multi-head self-attention,MSA)和前馈网络(feed forward networks,FFN)组成,在每个块之前应用层归一化(layer normalization,LN) ,在每个块之后应用残差连接(residual connection)。 FFN 包含两个带有 ReLU 激活的线性层。
对于第 i i i 层,多头自注意力的输入是从输入 T l − 1 T^{l-1} Tl−1 计算得到的元组**(query, key, value)**:
query = T l − 1 W Q l , key = T l − 1 W K l , value = T l − 1 W V l \text { query }=T^{l-1} \mathbf{W}_{Q}^{l}, \text { key }=T^{l-1} \mathbf{W}_{K}^{l}, \text { value }=T^{l-1} \mathbf{W}_{V}^{l} query =Tl−1WQl, key =Tl−1WKl, value =Tl−1WVl
其中 W O l ∈ R C × d q \mathbf{W}_{O}^{l} \in \mathbb{R}^{C \times d_{q}} WOl∈RC×dq, W K l ∈ R C × d k \mathbf{W}_{K}^{l} \in \mathbb{R}^{C \times d_{k}} WKl∈RC×dk, W V l ∈ R C × d v \mathbf{W}_{V}^{l} \in \mathbb{R}^{C \times d_{v}} WVl∈RC×dv 是第 i i i 层不同线性投影头的参数矩阵, d q d_{q} dq, d k d_{k} dk, d v d_{v} dv 是三个输入的维度。
**自注意力(SA)**可以表示为:
S A ( T l − 1 ) = T l − 1 + Softmax ( T l − 1 W Q l ( T l − 1 W K l ) ⊤ d k ) ( T l − 1 W V l ) \mathrm{SA}\left(T^{l-1}\right)=T^{l-1}+\operatorname{Softmax}\left(\frac{T^{l-1} \mathbf{W}_{Q}^{l}\left(T^{l-1} \mathbf{W}_{K}^{l}\right)^{\top}}{\sqrt{d_{k}}}\right)\left(T^{l-1} \mathbf{W}_{V}^{l}\right) SA(Tl−1)=Tl−1+Softmax(dkTl−1WQl(Tl−1WKl)⊤)(Tl−1WVl)
**多头自注意力(MSA)**是具有 h h h 个独立 SA 操作的扩展,将它们的输出进行拼接和投影:
MSA ( T l − 1 ) = Concat ( S A 1 , … , S A h ) W O l \operatorname{MSA}\left(T^{l-1}\right)=\operatorname{Concat}\left(\mathrm{SA}_{1}, \ldots, \mathrm{SA}_{\mathrm{h}}\right) W_{O}^{l} MSA(Tl−1)=Concat(SA1,…,SAh)WOl
其中 W O ∈ R h d k × C \mathbf{W}_{O} \in \mathbb{R}^{h d_{k} \times C} WO∈Rhdk×C 是输出线性投影头的参数矩阵。
在本文中,我们采用 h h h = 8, C C C = 128 , d q d_{q} dq, d k d_{k} dk, d v d_{v} dv 等于 C / h C/h C/h = 32。
Transformer-Self-Attention 整个计算过程可以表示为:
T l = MSA ( T l − 1 ) + FFN ( MSA ( T l − 1 ) ) ∈ R L × C T^{l}=\operatorname{MSA}\left(T^{l-1}\right)+\operatorname{FFN}\left(\operatorname{MSA}\left(T^{l-1}\right)\right) \in \mathbb{R}^{L \times C} Tl=MSA(Tl−1)+FFN(MSA(Tl−1))∈RL×C
为简单起见,在上式中省略了层归一化 LN。
需要注意的是,由多尺度特征展开得到的标记 T T T 的序列长度极长,MSA 的二次方计算复杂度使其无法处理。为此,在这个模块中,我们用**可变形自注意力(Deformable Self Attention,DSA)**机制来代替 SA。作为依赖于数据的稀疏注意力(不是全成对的),DSA 只关注整个序列中的一个稀疏元素集,而不管其序列长度如何,这在很大程度上降低了计算复杂度并允许多级特征图的交互。更多关于 Deformable Self Attention 的细节请参考 Deformable DETR: Deformable transformers for end-to-end object detection。
如上图所示,除了增强的标记 T K S T^{K_{S}} TKS 之外,我们还提出了一种可学习的代理嵌入(proxy embedding) E p r o E_{pro} Epro 来学习类别之间(即类内/类间)的全局语义关系。在结构上,代理嵌入是具有特定维度(128)的张量。
与 TSA 模块一样,TCA 模块由 K c K_{c} Kc 层组成,但包含两个多头自注意力块(multi-head self-attention blocks)。
对于第 j j j 层,代理嵌入 E p r o j − 1 E_{ pro }^{j-1} Eproj−1 由不同的线性投影转换成第一个 MSA 块的输入(query, key, value)。在这里,MSA 块的自注意力机制与每一对类别连接并交互,从而对各种类别的语义对应进行建模。
Here, the MSA block’s self-attention mechanism connects and interacts with each pair of categories, thus modeling the semantic correspondence of various labels.
接下来,学习到的代理嵌入通过另一个 MSA 块中的交叉注意力机制提取输入标记 T K S T^{K_{S}} TKS 的特征并与之交互,其中输入的 query 是代理嵌入,输入的 key 和 value 是标记 T K S T^{K_{S}} TKS。通过交叉注意力机制,标记的特征与学习到的全局语义关系进行通信,全面提高了特征表示的类内一致性和类间可区分性,产生了更新后的代理嵌入 E p r o j E_{ pro }^{j} Eproj。
Through the cross-attention, the features of tokens communicate with the learned global semantic relationship, comprehensively improving intra-class consistency and the inter-class discriminability of feature representation, yielding updated proxy embedding E p r o j E_{ pro }^{j} Eproj.
此外,我们引入了辅助损失 L o s s a u x Loss_{aux} Lossaux 来促进代理嵌入的学习。TCA 模块的最后一层的输出 E p r o K c E_{ pro }^{K_{c}} EproKc 被进一步传递给线性投影头,并产生多类预测 P r e d a u x ∈ R M Pred_{aux} \in \mathbb{R}^{M} Predaux∈RM。基于真实分割掩膜,可以找到计算分类标签以进行监督的独特元素#疑问#。通过这种方式,代理嵌入被驱使学习适当的语义关系,并有助于提高同一类别的特征相关性和不同类别之间的特征可辨别性。
Base on the ground-truth segmentation mask, we find the unique elements to compute classification labels for supervision. In this way, the proxy embedding is driven to learn appropriate semantic relationship, and help to improve feature correlations of the same category and the feature discriminability between different categories.
最后,将编码后的标记 T K s T^{K_{s}} TKs 折叠回二维特征图并附加未处理的特征以形成金字塔特征 { X 0 , X 1 , X 2 ′ , X 3 ′ , X 4 ′ } \left\{X_{0}, X_{1}, X_{2}^{\prime}, X_{3}^{\prime}, X_{4}^{\prime}\right\} {X0,X1,X2′,X3′,X4′}。我们使用 2 倍上采样层和 3×3 卷积以常规自下而上的方式逐步合并它们,以获得用于分割的最终特征图。
本文提出的 MCTrans 在三种类型的六个分割数据集上进行了评估。每个数据集都有不同的数据模式、数据大小和前景类,使其适合用来评估 MCTrans 的有效性和泛化性。
(1) 细胞分割:Pannuke 数据集(7904 例,6 类);
(2)息肉分割: CVC-Clinic 数据集(612 例,2 类),CVC-ColonDB 数据集(380 例,2 类),ETIS-Larib 数据集(196 例,2 类),Kvasir 数据集(1000 例,2 类);
(3)皮肤病变分割:ISIC2018 数据集(2594 例,2 类)。
我们采用 VGG 网络作为主干网络(backbone)来提取多尺度特征表示。为了证明构建多尺度像素级依赖关系的有效性,我们在 UNet 的最高级别特征上使用非局部操作和 Transformer-Encoder 来实现单尺度的上下文传播,结果的精度远远落后于我们的方法。添加 TCA 后, MCTrans 的得分提高到了 68.40%,这表明学习语义关系以增强特征表示的有效性。去除辅助损失后,网络只对类别之间的语义关系进行隐式建模。该策略将性能降低到了 67.87%。
我们在 Pannuke 数据集上比较了 MCTrans 和最先进的方法。我们采用传统 VGG 网络作为特征提取器。为了进行更全面的比较,在第二组中,我们采用了更强的特征提取器(例如 ResNet-34)。在下表中,MCTrans 以合理的计算开销为代价获得了更好的结果。与 UNet 相比,MCTrans 以几乎相同的参数和略微增加的计算量实现了 3.64% 的显著改善。请注意,其他先进方法(例如 UNet++)的计算量比 MCTrans 多很多,但性能较低。
下图提供了一些分割结果的示例。
可变形自注意力(Deformable Self Attention,DSA)
Zhu, Xizhou, et al. “Deformable DETR: Deformable transformers for end-to-end object detection.” arXiv preprint arXiv:2010.04159 (2020).
官方评审意见 :https://miccai2021.org/openaccess/paperlinks/2021/09/01/319-Paper0786.html
评分:6,7,8