Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation
论文:https://arxiv.org/pdf/2105.05537.pdf
代码:https://github.com/HuCaoFighting/Swin-Unet
在过去的几年里,卷积神经网络(CNNs)在医学图像分析中取得了里程碑式的成就。尤其是基于U型结构和跳跃连接的深度神经网络在各种医学图像任务中得到了广泛的应用。然而,尽管CNN取得了很好的性能,但由于卷积运算的局部性,它不能很好地学习全局和远程语义信息交互。在本文中,我们提出了Swin-UNET,它是一种类似于UNET的纯转换器,用于医学图像分割。标记化的图像块被送入基于Transformer的UShape编解码器结构中,并带有跳跃连接,用于局部全局语义特征学习。具体地说,我们使用带移位窗口的分层Swin Transformer作为编码器来提取上下文特征。设计了一种基于对称Swin Transformer的带patch expanding layer的解码器,通过上采样恢复特征地图的空间分辨率。在输入输出直接下采样和上采样4倍的情况下,在多器官和心脏分割任务上的实验表明,基于变压器的纯U型编解码器网络的性能优于全卷积和变换与卷积相结合的方法。代码和经过训练的模型将在https://github.com/HuCaoFighting/Swin-Unet.上公开提供。
得益于深度学习的发展,计算机视觉技术在医学图像分析中得到了广泛的应用。图像分割是医学图像分析的重要组成部分。特别是,准确和稳健的医学图像分割可以在计算机辅助诊断和图像引导的临床手术中发挥基石作用[1,2]。
现有的医学图像分割方法主要依赖于U型结构的全卷积神经网络[3,4,5]。典型的U型网络,U-Net[3],由一个带有跳跃连接的对称编解码器组成。在编码器中,采用一系列卷积层和连续下采样层来提取大感受野的深度特征。然后,解码器将提取的深度特征上采样到输入分辨率进行像素级语义预测,并将来自编码器的不同尺度的高分辨率特征通过跳跃连接进行融合,以缓解下采样造成的空间信息损失。凭借这种优雅的结构设计,U-Net在各种医学成像应用中取得了巨大的成功。根据这一技术路线,已经开发了许多算法,如3DU-NET[6]、RES-UNET[7]、U-NET++[8]和UNET3+[9],用于各种医学成像模式的图像和体积分割。这些基于FCNN的方法在心脏分割、器官分割和病变分割中都表现出很好的性能,证明了CNN具有很强的区分特征学习能力。
目前,基于CNN的分割方法虽然在医学图像分割领域取得了很好的效果,但仍不能完全满足医学应用对分割精度的严格要求。图像分割仍然是医学图像分析中的一项具有挑战性的任务。由于卷积运算的固有局部性,基于CNN的方法很难学习显式的全局和远程语义信息交互[2]。一些研究试图通过使用Arous卷积层[10,11]、自我注意机制[12,13]和图像金字塔[14]来解决这个问题。然而,这些方法在对长期依赖关系进行建模方面仍然存在局限性。最近,受Transformer在自然语言处理(NLP)领域的巨大成功的启发[15],研究人员试图将Transformer带入视觉领域[16]。在[17]中,提出了视觉转换器(ViT)来执行图像识别任务。以位置嵌入的二维图像块为输入,在大数据集上进行预训练,VIT取得了与基于CNN的方法相当的性能。此外,文献[18]中提出了数据高效的图像转换器(DeiT),这表明Transformer可以在中等大小的数据集上进行训练,并与蒸馏方法相结合可以得到更健壮的Transformer。在[19]中,开发了一种分层Swin Transformer。文[19]以Swin Transformer为视觉骨干,在图像分类、目标检测和语义分割等方面取得了一流的性能。ViT、DeiT和Swin Transformer在图像识别任务中的成功展示了Transformer在视觉领域的应用潜力。
受到Swin Transformer[19]成功的激励,我们建议Swin-UNET在这项工作中利用Transformer的力量进行2D医学图像分割。据我们所知,Swin-UNET是第一个纯基于transformer的U型架构,由编码器、bottleneck、解码器和跳跃连接组成。编码器、bottleneck和解码器都是基于Swin Transformer块[19]构建的。输入的医学图像被分割成不重叠的图像块。每个patch都被视为一个令牌,并被馈送到基于Transformer的编码器中,以学习深层的特征表示。提取的上下文特征由带patch expanding layer的解码器进行上采样,并通过跳跃连接与来自编码器的多尺度特征进行融合,从而恢复特征地图的空间分辨率并进一步进行分割预测。在多器官和心脏分割数据集上的大量实验表明,该方法具有良好的分割精度和较强的泛化能力。具体来说,我们的贡献可以概括为:(1)基于Swin Transformer块,构建了一种跳跃连接的对称编解码器结构。在编码器中,实现从局部到全局的自关注;在解码器中,将全局特征向上采样到输入分辨率,以进行相应的像素级分割预测。(2)在不使用卷积和插值运算的情况下,开发了一种patch expanding layer,实现了上采样和特征维数的增加。(3)实验中发现跳跃连接对transformer同样有效,因此最终构建了一种基于transformer的纯U型跳跃连接编解码器体系结构Swin-UNET(Swin-UNET)(简称Swin-UNET)。
CNN-based methods:
早期的医学图像分割方法主要是基于轮廓的和传统的基于机器学习的算法[20,21]。随着深度CNN的发展,文献[3]提出了U-net用于医学图像分割。由于U型结构的简单性和优越性能,各种类似UNET的方法不断涌现,如Res-UNET[7]、Dense-UNET[22]、U-net++[8]和UNet3+[9]。并将其引入3D医学图像分割领域,如3D-UNET[6]和V-Net[23]。目前,基于CNN的方法以其强大的表示能力在医学图像分割领域取得了巨大的成功。
Vision transformers:
transformer在[15]中首次被提出用于机器翻译任务。在NLP领域,基于transformer的方法在各种任务中实现了最先进的性能[24]。在Transformer的成功推动下,研究人员在[17]中引入了一种开创性的视觉转换器(VIT),它在图像识别任务中实现了令人印象深刻的速度和准确性之间的权衡。与基于CNN的方法相比,VIT的缺点是它需要在自己的大数据集上进行预训练。为了减轻培训VIT的难度,DeiT[18]描述了几种使VIT能够在ImageNet上进行良好培训的培训策略。最近,基于VIT[25,26,19]已经做了一些优秀的工作。值得一提的是,在[19]中提出了一种高效有效的分层视觉转换器,称为Swin Transformer,作为视觉主干。基于移位窗口机制,Swin Transformer在包括图像分类、目标检测和语义分割在内的各种视觉任务上实现了最先进的性能。本文试图以Swin Transformer块为基本单元,构建一种跳跃连接的U型编解码器架构,用于医学图像分割,为Transformer在医学图像领域的发展提供一个基准比较。
Self-attention/Transformer to complement CNNs:
近年来,研究人员试图将自我注意机制引入CNN以提高网络性能[13]。在文献[12]中,将附加attention gate的跳跃连接集成到U型结构中进行医学图像分割。不过,这仍然是以CNN为基础的方法。目前,人们正在努力将CNN和Transformer结合起来,以打破CNN在医学图像分割中的主导地位[2,27,1]。在文献[2]中,作者将Transformer和CNN相结合,构成了一种用于二维医学图像分割的强编码器。与[2]、[27]和[28]相似,利用Transformer和CNN的互补性来提高模型的分割能力。目前,Transformer和CNN的各种组合被应用于多模态脑肿瘤分割[29]和3D医学图像分割[1,30]。与上述方法不同,本文尝试探索纯变压器在医学图像分割中的应用潜力。
提出的的Swin-UNET的总体架构如图1所示。Swin-UNET由编码器、bottleneck、解码器和跳跃连接组成。Swin-UNET的基本单元是Swin Transformer block[19]。对于编码器,为了将输入转化为序列嵌入,将医学图像分割成大小为4×4的互不重叠的块,通过这种划分方法,每个块的特征维数变为4×4×3=48(这里为什么乘以3,是针对大脑分割的数据集来说的吗,大脑分割的数据集有3个模态,每个模态都分成大小为4x4的块,一共三个模态,维数共有4x4x3=48个?)。此外,将线性嵌入层应用于将特征维度投影到任意维度(表示为C)。转换后的patch令牌通过几个Swin Transformer block和patch merging layer来生成分层特征表示。具体地说,patch merging layer负责下采样和增维,Swin Transformer block负责特征表示学习。受U-Net[3]的启发,我们设计了一种基于对称变压器的解码器。解码器由Swin Transformer block 和patch expanding layer组成。提取的上下文特征通过跳跃连接与来自编码器的多尺度特征融合,以弥补下采样造成的空间信息损失。与patch merging layer不同的是,专门设计了patch expanding layer来执行上采样。patch expanding layer将相邻维度的特征图重塑为分辨率为2倍上采样的大型特征图。最后,利用最后一层patch expanding layer进行4倍上采样,将特征图的分辨率恢复到输入分辨率(W×H),然后在这些上采样的特征上应用线性投影层输出像素级分割预测。我们将在下面详细说明每个模块。
与传统的多头自关注(MSA)模块不同,Swin-Transformer block[19]是基于移位窗口构造的。在图2中,提出了两个连续的swin transformer block。每个swin transformer block由LN层、多头自关注模块、残差连接和具有Gelu非线性的二层MLP组成。基于窗口的多头自关注(W-MSA)模块和基于移位窗口的多头自关注(SW-MSA)模块分别应用于连续的两个变压器块。基于这种窗口划分机制,连续的swin transformer block可以表示为:
其中ˆZl和 Z1分别表示第l块的(S)W-MSA模块和MLP模块的输出。与前面的工作[31,32]类似,自我注意的计算公式如下:
其中Q,K,V∈Rm2×d表示查询、键和值矩阵。M2和d分别表示窗口中的patch数量和Q或K的维度。并且,B中的值取自偏置矩阵ˆB∈R*(2m−1)×(2m+1)^。
在编码器中,将分辨率为H/4×W/4的C维标记化输入送入两个连续的Swin Transform block进行表示学习,其中特征尺寸和分辨率保持不变。同时,patch expanding layer将减少表征数(2倍下采样),并将特征维数提高到原来的2倍。此过程将在编码器中重复三次(dimension和resolution的变化是在patch merging layer完成的,2个Swin Transformer block+patch merging layer是一组,一共重复3次)。
输入的patch被分成4个部分,并通过patch merging layer连接在一起。经过这样的处理,特征分辨率将降低2倍。并且,由于拼接操作导致特征维数增加了4倍,因此在拼接的特征上应用线性层以将特征维数统一到2倍的原始维数(patch merging layer后(或中)有一个linear layer,作用:将dimension的4改为2)。
由于transformer太深而无法收敛[33],因此只使用两个连续的Swin Transformer block来构建学习深层特征表示的bottleneck。在bottleneck中,特征尺寸和分辨率保持不变。
与编码器相对应,对称解码器是基于Swin Transformer块构建的。为此,与编码器中使用的patch merging layer不同,我们在解码器中使用patch expanding layer对提取的深度特征进行上采样。patch expanding layer将相邻维度的特征图重塑为更高分辨率的特征图(2倍上采样),并相应地将特征维数降至原始维数的一半。
Patch expanding layer
以第一个patch expanding layer为例,在上采样之前,对输入特征(W /32×H/32×8C)应用linear layer,将特征维数增加到原来的2倍(W/32×H/32×16C)。然后,使用rearrange operation将输入特征的分辨率扩展到输入分辨率的2倍,并将特征维数降低到输入维数的四分之一(W/32×H/32×16C→W/16×H/16×4C)。我们将在4.5节讨论使用patch expanding layer执行上采样的影响。
与U-Net[3]类似,跳跃连接用于将编码器中的多尺度特征与上采样特征融合。我们将浅层特征和深层特征拼接在一起,以减少下采样造成的空间信息损失。在linear layer之后,串联特征的维度保持与上采样特征的维度相同。在第4.5节中,我们将详细讨论跳跃连接的数量对模型性能的影响。
Synapse multi-organ segmentation dataset (Synapse)
数据集包括30例3779幅腹部轴位临床CT图像。在文献[2,34]的基础上,将18个样本划分为训练集,将12个样本划分为测试集。并以平均Dice-相似性系数(DSC)和平均Hausdorff距离(HD)作为评价指标,对8个腹部器官(主动脉、胆囊、脾、左肾、右肾、肝、胰、脾、胃)进行评价。
Automated cardiac diagnosis challenge dataset (ACDC)
ACDC数据集是使用MRI扫描仪从不同患者那里收集的。对于每个患者的MR图像,标记左心室(LV)、右心室(RV)和心肌(MYO)。数据集分为70个训练样本、10个验证样本和20个测试样本。与[2]类似,在此数据集上只使用平均DSC来评估我们的方法。
Swin-UNET是基于Python3.6和Pytorch 1.7.0实现的。对于所有训练案例,使用诸如翻转和旋转等数据扩充来增加数据多样性。输入图像大小和patch大小分别设置为224×224和4。我们在32 GB内存的NVIDIA V100 GPU上训练我们的模型。使用在ImageNet上预先训练的权值来初始化模型参数。在训练期间,批大小为24,并使用目前流行的动量为0.9、权值衰减为1e-4的SGD优化器对模型进行反向传播优化。
在Synapse多器官CT数据集上,所提出的Swin-UNET与以前最先进的方法的比较如表1所示。与TransUnet[2]不同,我们在Synapse数据集上添加了我们自己实现的U-Net[3]和Att-UNET[37]的测试结果。实验结果表明,我们提出的类UNET纯transformer方法分割效果最好,分割正确率分别为79.13%(dsc↑)和21.55%(hd↓)。与ATT-UNET[37]和最近的方法TransUnet[2]相比,虽然我们的算法在DSC评价指标上没有太大的改进,但是我们在HD评价指标上获得了大约4%和10%的准确率提高,这表明我们的方法可以获得更好的边缘预测。不同方法在Synapse多器官CT数据集上的分割结果如图3所示。从图中可以看出,基于CNN的方法往往存在过度分割问题,这可能是由于卷积运算的局部性造成的。在这项工作中,我们证明了通过将Transformer与带有跳跃连接的U型架构相结合,没有卷积的纯Transformer方法可以更好地学习全局和远程语义信息交互,从而产生更好的分割结果。
与Synapse数据集类似,提出的Swin-UNET算法在ACDC数据集上进行训练,以执行医学图像分割。实验结果汇总在表2中。利用MR模式的图像数据作为输入,SwinUnet仍能达到90.00%的准确率,具有较好的泛化能力和鲁棒性。
为了探索不同因素对模型性能的影响,我们在Synapse数据集上进行了消融研究。具体而言,下文将讨论上采样、跳跃连接的数量、输入大小和模型大小。
Effect of up-sampling
与编码器中的patch merging layer相对应,我们在解码器中专门设计了一个patch expanding layer来进行上采样和特征维数的增加。为了探索提出的patch expanding layer的有效性,我们在Synapse数据集上进行了双线性插值、转置卷积和patch expanding layer的Swin-UNET实验。表3中的实验结果表明,Swin-UNET结合patch expanding layer可以获得更好的分割精度。
Effect of the number of skip connections
我们的SwinUNet的跳跃连接被添加到1/4、1/8和1/16分辨率大小的位置。通过将跳跃连接的数量分别改变为0、1、2和3(这个skip connection 的数量是什么意思?怎么改?0是没有跳跃连接,3是三个位置都有跳跃连接,那么1和2分别是哪一个和哪两个呢?位置不确定),考察了不同跳跃连接对模型分割性能的影响。在表4中,我们可以看到模型的分割性能随着跳跃连接数量的增加而提高。因此,为了提高模型的健壮性,本文将跳跃连接的数量设置为3。
Effect of input size
表5中给出了输入分辨率为224×224、384×384的Swin-UNET的测试结果。当输入大小从224×224增加到384×384,且patch大小保持不变时,transformer的输入令牌序列将变得更大,从而提高模型的分割性能。然而,虽然模型的分割精度略有提高,但整个网络的计算量也明显增加。为了保证算法的运行效率,本文以224×224分辨率尺度作为输入进行实验。
Effect of model scale
类似于[19],我们讨论了网络深化对模型性能的影响。从表2中可以看出。模型规模的增大几乎没有提高模型的性能,反而增加了整个网络的计算成本。考虑到精度和速度的权衡,我们采用了基于Tiny的模型进行医学图像分割。
众所周知,Transformer模型的性能受到模型预训练的严重影响。在这项工作中,我们直接使用Swin Transformer[19]在ImageNet上的训练权重来初始化网络编码器和解码器,这可能是一种次优方案。这种初始化方法很简单,在未来我们将探索对Transformer进行端到端预训练的方法,用于医学图像分割。此外,由于本文的输入图像是二维的,而医学图像数据大多是三维的,因此我们将在后续的研究中探索Swin-UNET在三维医学图像分割中的应用。
本文介绍了一种用于医学图像分割的基于纯变压器的U形编解码器。为了充分利用Transformer的强大功能,我们将Swin Transformer块作为特征表示和远程语义信息交互学习的基本单元。在多器官和心脏分割任务上的大量实验表明,Swin-UNET具有良好的性能和泛化能力。