论文地址:https://arxiv.org/pdf/2106.14385.pdf
源码地址:https://github.com/JiYuanFeng/MCTrans
摘要:最近的视觉transformer(即图像分类)学习不同patch标记的非局部注意相互作用。然而,现有技术错过了学习不同像素的跨尺度依赖性、不同标签的语义对应关系、特征表征和语义嵌入的一致性,这对生物医学分割至关重要。本文通过提出一种统一的transformer网络,称为多复合transformer(MCTrans)来解决上述问题,该网络将丰富的特征学习和语义结构挖掘集成到一个统一的框架中。特别地,MCTrans将多尺度卷积特征嵌入为一个标记序列,并在以前的工作中执行尺度内和尺度间的自我注意,而不是单尺度注意。此外,还引入了一种可学习的代理嵌入,分别利用自注意和交叉注意来建模语义关系和特征增强。MCTrans可以很容易地插入一个类似unet的网络,并在生物医学图像分割方面比最先进的方法有了显著的改进。例如,MCTrans在Pannuke,CVC-Clinic,CVC-Colon,Etis,Kavirs,ISIC2018数据集上分别比UNet高 3.64%,3.71%, 4.34%, 2.8%, 1.88%, 1.57%
解决的问题:MCTrans 克服了传统的vision transformer的局限性
这篇论文的主要贡献有:
模型架构 :
如图2所示,这是在经典的UNet编码器和解码器架构之间引入了MCTransformer,它由Transformer-Self-Attention(TSA)模块和Transformer-Cross-Attention(TCA)模块组成。引入前者来对多个特征之间的上下文信息进行编码,产生丰富而一致的像素级上下文。后者引入了可学习嵌入的语义关系建模,并进一步增强了特征表示。实际上,给定一个图像I,采用深度CNN提取不同尺度的多层次特征,对于层级i,特征以P×P的补丁大小展开,其中P设置为1,即第i特征图的每个位置将被视为“补丁”,总共有个补丁,接下来,将不同层的分割补丁传递到具有相同输出特征维度的单个投影头(即1×1卷积层),并获得嵌入的token,为了补偿缺失的位置信息,位置嵌入
被补充到token中,以提供关于特征在序列中的相对或绝对位置的信息,这可以表述为
。接下来,我们将token输入TSA模块,用于多尺度上下文建模。输出增强的token进一步通过TCA模块,并与代理嵌入进行交互,代理嵌入
M是数据集的类别数。最后,我们将编码的token折叠回金字塔特征,并以自下而上的方式合并,以获得用于预测的最终特征图。
TSA模块
以一维嵌入token T作为输入,利用TSA模块学习多尺度特征之间的像素级上下文依赖关系。
如图2所示,TSA模块由Ks层组成,每个层由多头自注意(MSA)和前馈网络(FFN)组成。
如图3所示,在每个块之前应用层标准化(LN),在每个块之后应用残余连接。FFN包含两个具有ReLU激活的线性层。对于第l层,对自我注意的输入是从输入Tl−1计算出的三元组(q,k,v),如:
是第l层不同线性投影头的参数矩阵,dq,dk,dv是三个输入的维度
SA可以用公式表示为:
MSA是一个具有h个独立的SA操作的扩展,并将它们的连接输出投影为:
是输出线性投影头的一个参数,本文采用h=8、C=128和dq、dk、dv等于C/h=32。如图3(a)所示,整个计算可以表示为:
为了简单起见,方程中省略了LN。需要注意的是,teken T(从多尺度特征中变平)具有非常长的序列长度,而MSA的二次计算复杂度使得其无法处理。为此,本模块中使用[Deformable detr: Deformable transformers for end-to-end object detection]中提出的可变形的自我注意(DSA)机制来代替SA。DSA作为依赖数据的稀疏注意,并不是全成对的,DSA只关注整个序列的稀疏元素集,这在很大程度上降低了计算复杂性,并允许多层次特征映射的交互。
TCA模块
如图2所示,除了增强的token ,还提出了一种可学习的代理嵌入来学习类别间的全局语义关系(即类内/类间)。与TSA模块一样,TCA模块由Kc层组成,但包含两个多头自我注意块。在实践中,对于第j层,代理嵌入由各种线性投影头进行转换,生成第一个MSA块的输入(q,k,v)。在这里,MSA块的自我注意机制与每一对类别进行连接和交互,从而建模不同标签的语义对应关系。接下来,学习到的代理嵌入通过在另一个MSA块中的交叉注意,提取并与输入标记的特征进行交互,其中查询输入是代理嵌入,键,值输入是token
通过交叉注意,token的特征与学习到的全局语义关系进行融合,全面提高了类内一致性和特征表示的类间可辨别性,产生了最新的代理嵌入。注意到过程两个MSA块的计算等于公式2
此外,我们还引入了一种辅助损失来促进代理嵌入学习。特别地,输出
对TCA模块最后一层的证明进一步传递给线性投影头,并产生多类预测
在GT分割掩模的基础上,我们找到了唯一的元素来计算分类标签的监督。这样,驱动代理嵌入学习适当的语义关系,有助于提高同一类别的特征相关性和不同类别之间的特征可辨别性。最后,将编码的标记 折叠回二维特征,并附加不涉及的特征来形成金字塔特征,我们以规则的自下而向上的风格逐步将它们与2×上采样层和3×3卷积合并,以获得分割的最终特征图。
实验
数据集:
设置
主要是在Panunke数据集上进行方法评估,以显示不同网络组件的有效性。最后,将MCTrans与所有数据集上的顶级方法进行了比较。我们报告了DSC的所有结果,得分越高,结果越好。我们采用传统的CNN主干网络,包括VGG-Style编码器和ResNet-34,来提取多尺度的特征表示。对于网络优化,我们使用CE loss损失和Dice loss来惩罚分割的训练误差和权重为0.1的CE loss。我们用简单的翻转来增强训练图像。我们使用初始学习率为3e-4的Adam优化器来训练网络。学习速率在训练过程中呈线性衰减。
消融实验:
网络组件的分析:
我们通过分割精度来评估MCTrans的核心模块的重要性。我们使用VGG风格的网络作为主干。与在Pannke数据集上达到64.92%的UNet基线相比,MCTrans使用TSA和TCA的能力达到了68.40%的准确率。在表1,通过将TSA模块添加到UNet中,性能提升到67.93%。
为了证明构建多尺度像素级依赖性的有效性,我们在UNet的最高层特征上使用了Non-local操作和Transformer-Encoder,以实现单尺度上下文传播,所产生的精度远远落后于我们的方法。我们进一步评估了TCA模块的影响。添加TCA后,学习到的语义先验帮助构建已识别的上下文依赖性,并将基线和MCTrans的得分分别提高到67.16%和68.40%。这说明了学习语义关系对增强特征表征的有效性。我们还研究了消除辅助损失的情况。在这里,我们只隐式地建模类别之间的语义关系。该策略可将性能降至67.87%。
设置的敏感性:我们改变了TSA和TCA模块的数量,并研究了其对分割精度的影响。首先逐步增加TSA模块的数量,以扩大建模能力。如表2,我们可以看到,当TSA的大小增加时,DSC评分首先增加,然后减小。在固定Ns后,我们进一步插入TCA并扩大其尺寸
我们还发现它在Nc=4达到顶部,然后减少。这间接地表明,当在一个小的数据集上进行训练时,基于transformer的模型的容量并没有那么好。
与最先进的方法的比较:
在表3中,我们将MCTrans与Panunke数据集上最先进的方法进行了比较。在第一组中,我们采用传统的VGG-Style网络作为特征提取器。与其他建模机制相比,我们的MCTrans通过跨多层特征探索像素级依赖关系,实现了显著的改进。为了进行更全面的比较,在第二组中,我们采用了一个更强的特征提取器(如ResNet-34)。同样,我们取得了比其他方法更好的准确性。我们在图4中提供了分割结果的例子。
在表4中,我们还分别报告了五种病变分割的结果。我们的方法的结果仍然显著优于其他顶级方法。这些结果证明了所提出的MCTrans在各种分割任务上的通用性。
我们提供了关于计算开销的更多细节(即每秒浮点运算(Flops)和参数的数量)。如表3所示。MCTrans以合理的计算开销为代价,取得了更好的结果。与UNet基线相比,参数几乎相同且计算量略有增加的MCTrans实现了3.64%的显著改善。请注意,其他顶级方法,如UNet++,在许多计算方面超过了MCTrans,同时产生了更低的性能。
结论:
在本文中,我们提出了一种强大的基于transformer的医学图像分割网络。我们的方法通过强大的注意机制结合了丰富的上下文建模和语义关系挖掘,有效地解决了跨尺度依赖性、不同类别的语义对应关系等问题。我们的方法是有效的,并且在几个公共数据集上优于最先进的方法,如TransUnet。