Medical Transformer: Gated Axial-Attention for Medical Image Segmentation
医学转换器:用于医学图像分割的门控轴向注意力
在过去的十年中,深度卷积神经网络被广泛地应用于医学图像分割,并显示出良好的性能。然而,由于卷积体系结构中存在固有的归纳偏差,它们缺乏对图像中的长期依赖关系的理解。最近提出的基于Transformer的体系结构利用自我注意机制,编码长范围依赖关系,并学习具有高度表现力的表示法。这促使我们探索基于Transformer的解决方案,并研究将基于Transformer的网络架构用于医学图像分割任务的可行性。大多数现有的基于Transformer的视觉应用网络体系结构都需要大规模的数据集来进行适当的训练。然而,与用于视觉应用的数据集相比,用于医学成像的数据样本数量相对较少,这使得有效地训练用于医疗应用的变压器变得困难。为此,我们提出了一种门控轴向注意模型,通过在自我注意模块中引入额外的控制机制来扩展现有的体系结构。此外,为了对模型进行有效的医学图像训练,我们提出了一种局部-全局训练策略(LOGO),进一步提高了模型的性能。具体地说,我们对整个图像和补丁进行操作,分别学习全局特征和局部特征。在三个不同的医学图像分割数据集上对所提出的医疗转换器(MedT)进行了评估,结果表明它比卷积和其他相关的基于Transformer的架构具有更好的性能。代码:https://github.com/jeya-mariajose/Medical-Transformer
开发自动、准确和稳健的医学图像分割技术一直是医学成像领域的主要问题之一,因为它对于计算机辅助诊断和图像引导手术系统至关重要。从医学扫描中分割器官或病变有助于临床医生做出准确的诊断,计划手术过程,并提出治疗策略。早期的医学分割方法使用统计形状模型、基于轮廓的方法和基于机器学习的方法[41,32,9]。随着深度卷积神经网络(ConvNets)在计算机视觉领域的普及[14,8,1,28],ConvNets很快被用于医学图像分割任务。像U-Net[25]、V-Net[19]、3D U-Net[5]、RES-UNET[40]、DenseUNet[17]、Y-Net[18]、U-Net++[44]、Kiu-Net[34,33]和U-Net3+[11]这样的网络已经被特别提出用于执行各种医学成像模式的图像和体积分割。这些方法在许多困难的数据集上也取得了令人印象深刻的性能,证明了ConvNets在学习区分特征以从医学扫描中分割器官或病变方面的有效性。
ConvNets是目前提出的大多数图像分割方法的基本构件,无论是医学成像还是计算机视觉任务。然而,ConvNets缺乏对图像中存在的远程依赖关系进行建模的能力。更准确地说,在ConvNets中,每个卷积核只关注整个图像中像素的局部子集,并迫使网络关注局部模式而不是全局上下文。已经有一些工作集中于使用图像金字塔[42]、Atrus卷积[4]和注意机制[12]来建模ConvNet的长期依赖关系。然而,值得注意的是,由于以前的大多数方法对于医学图像分割任务并没有关注这一方面,因此在对远程依赖关系建模方面仍有改进的余地。
为了首先了解为什么远程依赖关系对医学图像很重要,我们可视化了一个早产儿的超声扫描示例,并根据图1中的扫描对脑室进行分割预测。对于提供有效分割的网络,它应该能够理解哪些像素对应于掩膜,哪些像素对应于背景。在给定单个像素的情况下,网络需要了解它是更接近背景的像素还是更接近分割掩膜的像素。由于图像的背景是分散的,学习对应于背景的像素之间的长范围相关性可以帮助网络防止将像素误分类为掩码导致真负片的减少(将0视为背景,将1视为分割掩码)。类似地,每当分割掩码较大时,了解与该掩码对应的像素之间的长期相关性也有助于做出有效的预测。在图1(B)和©中,我们可以看到卷积网络错误地将背景分类为脑室,而所提出的基于Transformer的方法没有犯这种错误。这发生在我们提出的方法学习像素区域与背景区域的长期相关性的时候。
图1 (a)活体早产儿脑室输入超声。(b)U-Net、©Res-UNET、(d)MedT和(e)Ground Truth的预测。红色方框突出显示了由于缺乏学习的远程依赖关系而被基于ConvNet的方法遗漏分类的区域。这里的ground truth是由一位专业的临床医生分割出来的。虽然它显示脑室内有一些出血,但它与分割区域并不相符。基于Transformer的模型可以正确捕获此信息。
像BERT[6]和GPT[23,24,2]这样的转换器模型最近彻底改变了大多数自然语言处理(NLP)任务,如机器翻译[21]、问题回答[26]和文档分类[22]。Transformer成功的主要原因是它们能够学习输入令牌之间的远程依赖关系。这是可能的,因为自我注意机制可以找到输入中每个令牌之间的依赖关系。随着Transformer在自然语言处理应用中的普及,它们最近也被用于计算机视觉应用。具体地说,Vision Transformers(VIT)[7]成功地使用了位置嵌入的2D图像补丁作为输入序列,类似于语言转换器模型的输入序列。在大型图像数据集上进行预训练时,该算法在图像识别方面取得了与卷积网络相当的性能。数据高效的图像转换器(Deit)[31]被提出,它展示了如何将转换器应用于中等大小的数据集。关于用于分割任务的转换器,Axial-Deeplab[37]利用了轴向注意模块[10],该模块将2D自我注意分解为两个1D自我注意,并引入位置敏感型轴向注意设计用于全景分割。紧随其后的是Max-Deeplab[36],它使用掩模Transformer以端到端的方式解决全景分割问题。在分割Transformer(SETR)[43]中,使用Transformer作为编码器输入一系列图像块,使用ConvNet作为解码器,从而得到一个强大的分割模型。
在医学图像分割中,基于Transformer模型的研究还不多见。最接近的作品是那些使用注意力机制来提高表现的作品[29,20,39]。然而,这些网络的编码器和解码器仍然以卷积层作为主要的构建块。最近,TransUNet[3]被提出,它使用基于Transformer的编码器对图像块序列进行操作,并使用带有跳过连接的卷积解码器来分割医学图像。由于TransUNet受到VIT的启发,它仍然依赖于通过在大型图像语料库上训练而获得的预先训练的权重。与这些工作不同的是,我们探索了使用只使用自我注意机制的变压器作为医学图像分割的编码器的可行性,而不需要任何预训练。
我们观察到,基于Transformer的模型只有在大规模数据集上进行训练时才能很好地工作[7]。当将Transformer用于医学成像任务时,这变得有问题,因为可用于在任何医学数据集中训练的具有对应标签的图像的数量相对稀少。制作标签的过程也很昂贵,需要专业知识。具体地说,使用较少的图像进行训练会导致难以学习图像的位置编码。为此,我们提出了一种门控位置敏感轴向注意机制,其中我们引入了四个门来控制位置嵌入提供给键、查询和值的信息量。这些门是可学习的参数,这使得所提出的机制可以应用于任何大小的数据集。根据数据集的大小,这些门将了解图像的数量是否足以学习正确的位置嵌入。根据位置嵌入学习到的信息是否有用,门参数要么收敛到0,要么收敛到更高的值。此外,我们还提出了一种局部-全局(LOGO)训练策略,在该策略中,我们使用一个浅全局分支和一个深局部分支对医学图像的patch进行操作。这一策略提高了分割性能,因为我们不仅对整个图像进行操作,而且将重点放在局部补丁中存在的更精细的细节上。最后,我们提出了医疗变压器(MedT),它以我们的门控位置敏感轴向注意力为基础,采用我们的标识训练策略。
综上所述,本文**(1)提出了一种适用于较小数据集的门控位置敏感轴向注意机制,(2)引入了有效的局部-全局(LOGO)训练方法,(3)提出了基于上述两个专门用于医学图像分割的概念的医学转换器(MedT),以及(4)**成功地提高了卷积网络和三种不同数据集上的完全注意力结构上的医学图像分割任务的性能。
Medical Transformer(MedT)使用门控轴向注意层作为基本构建块,并使用LOGO策略进行培训。MedT有两个分支机构-一个全局分支机构和一个本地分支机构。这两个分支的输入是从初始卷积块提取的特征地图。该块有3个conv层,每个conv层后面都有批归一化和REU激活。在两个分支的编码器中,我们使用我们提出的Transformer层,而在解码器中,我们使用conv块。编码器瓶颈包括1×1卷积层和两层多头注意层,其中一层沿高度轴操作,另一层沿宽轴操作。每个多头注意块由提出的门控轴向注意层组成。请注意,每个多头注意块具有8个门控的轴向注意头。多头注意块的输出被连接并通过另一个1×1 conv,这1x1 conv被添加到残差输入图中以产生输出注意图。图2(b)给出了建议的编码块的概述。在每个解码器块中,我们有一个卷积层,其后是一个上采样层和RELU激活。我们还在两个分支中的每个编码器和解码器块之间有跳过连接。
图2 (a)使用LoGo策略进行培训的MedT的主要架构图。(b)MedT中使用的gated axial transformer层。©gated axial attention layer,它是在门控轴向transformer层中的高度和宽度gated multi-head attention blocks的基本构件。
**在MedT的全局分支中,我们有2个编码器块和2个解码器块。在本地分支中,我们有5个编码器块和5个解码器块。**这些数字是在实验分析的基础上确定的,在评估研究期间没有改变,可以在补充文件中找到。MedT的整体架构如图2(a)所示。在接下来的内容中,我们将详细讨论MedT的每个组件。
让我们考虑具有高度H、权重W和通道Cin的输入特征映射x∈R^ (Cin×H×w)。借助投影输入,使用以下公式计算自我注意层的输出y∈R^ (Cout×H×W):
其中查询q=Wqx、键k=Wkx和值v=Wvx都是从输入x计算的投影。这里,qij、kij、vij表示查询、键和值在任意位置i∈{1,……,H}和j∈{1,……W}。投影矩阵Wq,Wk,Wv∈R^ (Cin×Cout)是可学习的。如公式1中所示,使用基于Softmax(q^Tk)计算的全局亲和度将值v汇集在一起。因此,与卷积不同,自我注意机制能够从整个特征映射中捕获非本地信息。然而,计算这种亲和力是非常昂贵的,并且随着特征图大小的增加,将自我注意用于视觉模型体系结构通常变得不可行。此外,与卷积层不同,自我注意层在计算非局部上下文时不利用任何位置信息。位置信息在视觉模型中通常用于捕捉对象的结构。
Axial-Attention
为了克服计算亲和力的计算复杂性,将自我注意分解为两个自我注意模块。第一个模块在特征图高度轴上执行自我关注,第二个模块在宽度轴上进行操作。这被称为轴向注意[10]。因此,在高度轴和宽度轴上施加的轴向注意力有效地模拟了原始的自我注意机制,并具有更好的计算效率。为了在通过自注意力机制计算亲和力时增加位置偏置,添加了位置偏置项以使亲和力对位置信息敏感[27]。该偏置项通常被称为相对位置编码。这些位置编码通常可以通过训练来学习,并且已经显示出具有编码图像空间结构的能力。 在[37]中Wang等结合轴向注意力机制和位置编码,提出了一种基于注意力的图像分割模型。另外,与以前的注意力模型不一样,后者仅将相对位置编码用于查询,而Wang等人则提议将其用于所有查询,键和值。查询,键和值中的这种附加位置偏置显示为捕获了具有精确位置信息的远程交互[37]。对于任何给定的输入特征图x,带有位置编码和宽度轴的更新的自我关注机制可以写成:
公式2中的等式遵循文献[37]中提出的注意模型和r^q 、r ^k 、r^v ∈R^ (W×W)宽度轴向注意模型。请注意,公式2描述了沿张量宽度轴施加的轴向注意。一种类似的公式也被用来沿高度轴施加轴向注意,它们一起形成了一个计算效率很高的自我注意模型。
我们讨论了使用[37]中提出的轴向注意机制进行视觉识别的好处。具体地说,[37]中提出的轴注意能够以良好的计算效率计算非局部上下文,能够将位置偏差编码到机制中,并且能够在输入特征映射内编码远程交互。然而,他们的模型是在大规模分割数据集上进行评估的,因此轴注意更容易学习到键、查询和值的位置偏差。我们认为,对于小规模数据集的实验,就像医学图像分割中经常出现的情况一样,位置偏差很难学习,因此在编码远程交互时并不总是准确的。在学习的相对位置编码不够准确的情况下,将它们添加到相应的键、查询和值张量将导致性能降低。因此,我们提出了一种改进的轴向注意块,它可以控制位置偏差对非局部语境编码的影响。通过建议的修改,应用于宽度轴上的自我注意机制可以正式地写成:
其中自我注意公式紧跟着等式2增加了门控机制。此外,GQ、GK、GV1、GV2∈R是可学习的参数,它们共同创建门控机制,控制学习的相对位置编码对非局部语境编码的影响。通常,如果准确地学习了相对位置编码,则与未准确学习的编码相比,门控机制将为其分配较高的权重。图2©说明了典型门控轴向关注层中的前馈。
显然,在patch上训练transformer速度更快,也有助于提取更精细的图像细节。然而,对于医学图像分割这样的任务,仅靠patch训练是不够的。分割掩码很可能会大于patch大小。这限制了网络学习patch间像素的任何信息或依赖性。为了提高对图像的整体理解,我们建议使用网络的两个分支,即处理图像原始分辨率的全局分支和处理图像块的局部分支。
在全局分支中,我们减少了gated axial transformer layers的数量,因为我们发现所提出的transformer模型的前几个块足以模拟长距离依赖关系。在局部分支中,我们创建大小为I/4×I/4的16个图像块,其中I是原始图像的尺寸。在局部分支中,每个patch通过网络进行前向反馈,并根据其位置对输出特征图进行重新采样,以获得输出特征图。然后将两个分支的输出特征图相加并通过1×1卷积层以产生输出分割掩码。这种在图像的全局上下文上操作较浅的模型和在patch上操作较深的模型的策略提高了性能,因为全局分支关注的是高层信息,而局部分支关注的是更精细的细节。
我们使用脑解剖分割(超声)[38,35]、腺体分割(显微)[30]和MoNuSeg(显微)[15,16]数据集来评估我们的方法。脑解剖分割数据集有1300个二维超声扫描用于训练,329个用于测试。腺体分割数据集(GLAS)有85幅图像用于训练,80幅图像用于测试。MoNuSeg数据集有30个用于训练的图像和14个用于测试的图像。有关数据集的更多详细信息可在补充文件中找到。
MedT在Pytorch中实现。我们使用预测和ground truth之间的二进制交叉熵(CE)损失来训练我们的网络。请注意,我们也使用相同的损失函数来训练所有基线。CE损失定义如下
其中w和h是图像的尺寸,p(x,y)对应于图像中的像素,ˆp(x,y)表示在特定位置(x,y)的输出预测。我们使用批大小为4的ADAM优化器[13]和0.001的学习率进行实验。该网络针对400个epoch进行了训练。在训练门控轴向注意层时,我们不会激活前10个epoch的门控训练。我们所有的实验都使用了NVIDIA Quadro 8000图形处理器。
对于基线比较,我们首先对卷积方法和基于transformer的方法进行实验。对于卷积基线,我们与全卷积网络(FCN)[1]、U-NET[25]、U-NET++[44]和Res-UNET[40]进行了比较。对于基于transformer的基线,我们使用具有来自[37]灵感的剩余连接的轴向注意U网。对于我们提出的方法,我们使用所有个人贡献进行实验。在门控轴向注意力网络中,我们使用了轴向注意力U网,将其所有的轴向关注层替换为所提出的门控轴向关注层。在LOGO中,我们对轴向关注U网进行局部全局训练,而不使用门控轴向关注层。在MedT中,我们使用门控轴向注意作为全局分支和轴向注意的基本构件,而不对局部分支进行位置编码。
为了进行定量分析,我们使用F1和IOU分数来将我们提出的方法与基线进行比较。量化结果如表1所示。值得注意的是,对于具有相对较多图像的数据集(如大脑US),基于完全注意力(Transformaer)的基线比卷积基线表现更好。对于GLAS和MoNuSeg数据集,卷积基线比完全注意力基线表现更好,因为很难用更少的数据训练完全注意力模型[7]。在门控轴向注意和LOGO的帮助下,所提出的方法能够克服这一问题,其各自的性能都优于其他方法。我们最终的架构MedT比门控轴向注意、LoGo和所有以前的方法执行得更好。在Brain US、GLAS和MoNuSeg数据集上,比完全注意力基线分别提高了0.92%、4.76%和2.72%。对最佳卷积基线的改善分别为1.32%、2.19%和0.06%。所有这些值都是根据F1分数计算的。
对于定性分析,我们可视化了来自U-Net[25]、ResUNet[40]、Axial Attribute U-Net[37]和我们在图3中提出的方法MedT的预测。可以看出,MedT的预测非常好地捕捉到了长期依赖关系。例如,在图3的第二行中,我们可以观察到红色框上突出显示的小分割掩模在所有卷积基线中都没有被检测到。然而,由于完全注意力模型编码了长范围依赖关系,它能够很好地学习分割,这要归功于编码的全局上下文。在第一行和第四行中,其他方法在突出显示的区域进行错误预测,因为这些像素非常接近分割掩膜。由于我们的方法考虑了用门控机制编码的像素依赖关系,因此它能够比轴向注意Unet更好地学习这些依赖关系。这使得我们的预测更加精确,因为它们不会对分割遮罩附近的像素进行误分类。在补充中可以找到更多这样的定性例子。
Ablation Study
在消融研究中,我们在所有的实验中都使用了大脑的美国数据。我们首先从一个标准的U网开始。然后,我们将剩余连接添加到U-Net,使其成为Res-UNET。现在,我们用轴向关注层替换了RES-UNET编码器中的所有卷积层。此配置是受[37]启发的轴向注意UNET。请注意,在此配置中,我们在前端有一个额外的卷积块,用于特征提取。接下来,我们用门控轴向注意层替换前一配置中的所有轴向注意层。这种配置被表示为门控轴向注意。然后,我们仅使用LOGO策略中的全局分支和局部分支进行实验。这表明在全局分支中只使用2层就足以获得不错的性能。此配置中的本地分支将在从图像中提取的patch程序上进行测试。然后,我们将这两个分支结合起来,以端到端的方式训练网络,表示为LoGo。注意,在该配置中,所使用的关注层仅是轴向关注层[37]。最后,我们将LOGO中的轴向关注层替换为门控轴向关注层,从而导致MedT。消融研究表明,MedT的每个单独组件都为提高性能做出了有益的贡献。
Number of Parameters
虽然MedT是一个多分支网络,但是我们在全局分支中只使用两层编码器和解码器,并使局部分支只对图像块进行操作,从而减少了参数的数量。此外,所提出的门控轴向注意块仅向该层增加了4个可学习参数。在表3中,我们将参数数量与其他方法进行了比较。U-Net对应于根据[25]的原始实现。U-Net(Mod)对应于过滤器数量减少的U-Net配置,以便与MedT中的参数数量相匹配。类似地,RES-UNET和RES-UNET(Mod)通过调整过滤器的数量,对应于具有更多和更少参数的配置。我们这样做是为了表明,即使参数数量更多,基线在性能方面也没有超过MedT,这表明改进并不是由于参数数量的轻微变化。
在这项工作中,我们探索使用基于transformer的编码器架构来分割医学图像,而不需要任何预训练。我们提出了一种门控的轴向注意力层,作为网络编码器多头注意力模型的构建块。我们还提出了一种LoGo训练策略,在这种策略中,我们使用相同的网络架构训练整张图像和图像patch。全局分支通过对长期依赖关系进行建模来帮助网络中学习高级功能,而本地分支则通过对patch进行操作来关注更精细的功能。在此基础上,提出了以轴向关注度为主要构建块的MedT(Medical Transformer)作为编码器的主要构建块,并采用LOGO策略对图像进行训练。我们在三个数据集上进行了大量的实验,在卷积和其他相关的基于变压器的架构上,我们取得了很好的性能。