Swin-UNet:基于纯Transformerde的医学图像分割网络
近年来CNN已经成为医学图像分析任务的基础结构,尤其是融合了编解码结构和skip-connection的U型网络广泛应用于各种医学图像分析任务。然而受限于卷积操作的局部性,CNN并不能很好的学习全局信息以及长程语义信息。
本文提出的Swin-UNet,是一个纯基于Transformer搭建的U型网络,可以用于医学图像分割任务。
Swin-UNet取标记后的图像patch作为输入,搭建的U型网络包含编解码结构以及skip-connections,这样可以同时学习局部和全局的语义特征。
尤其,Swin-UNet使用了带有移位窗口(shifted window)的层次Swin-Transformer结构作为编码器来提取上下文信息;对称的设置一个Swin-Transformer作为解码器进行patch的上采样,从而恢复特征图的空间分辨率。
本文中网络深度为四级,即下采样4次后再进行4次上采样恢复分辨率,最终用于多器官分割和心室分割任务,实验结果表明Swin-Transformer结果优于全卷积神经网络和卷积-Transformer混合网络。
目前代码已开源。
Paper
Code
得益于深度学习的快速发展,计算机视觉已广泛用于医学图像分析任务。图像分割是医学图像分析领域中的一类中药任务,尤其是精细、鲁棒性强的图像分割算法对计算机辅助诊断(CAD)以及图像辅助临床手术有重要作用。
目前的医学图像分割主要基于全卷积神经网络(FCNN),其中以U型网络结构最为常用。传统的U型网络,如UNet,是一个对称的编码-解码网络结构,并且包含跨层连接(skip connection).Encoder通过一系列卷积和下采样操作提取不同感受野下的语义特征;Decoder通过上采样操作将提取到的特征图恢复到原始分辨率,进行像素级别的密集预测;Skip-connection则负责将同层的encoder提取的信息与decoder进行融合,从而减少因为下采样操作导致的空间信息丢失问题。通过精巧的结构设计,UNet已经在诸多医学图像任务中取得优异的效果,也有一系列UNet变体,如3D UNet,Res-UNet,UNet++,UNet3+等用于不同的医学任务或处理不同模态的医学图像。这些FCNN网络在心室分割、器官分割任务中展示出优异的性能,充分证明了CNN具有较强的特征识别能力。
虽然上述CNN模型已经取得了优异的性能,但仍然无法满足临床应用的精度。鉴于CNN只能较好的提取局部特征,很难同时学习到全局和长程的语义信息。在这方面,一些研究通过使用空洞卷积、自注意力机制、空间金字塔模型等来进行改善,但仍然无法有效的捕捉长程关联。
受Transformer在NLP领域中的启发,目前研究人员整将Transformer迁移到计算机视觉领域。ViT以平民化hip考生们和图像处理成2D Patch序列并嵌入位置信息后送入Transformer,并且在大规模数据集上预训练后迁移至其他视觉任务,取得了与CNN模型相媲美的性能。
除此之外,DeiT还验证了Transformer在中型数据集上训练后,通过知识蒸馏可以使得模型更具鲁棒性。而Swin-Transformer则是在目标检测、图像分类和图像分割任务中均达到了SOTA。以上研究均证明了Transformer应用于视觉任务的巨大潜力。
受Swin Transformer启发,本文提出了Swin-UNet.据我们所知这是第一次搭建纯Transformer的U型网络,其Encoder,Decoder及bottlenec部分均基于Swin Transformer模块。
输入图像被处理成互不重叠的patch,每个patch被视为一个token被送入encoder提取特征;随后提取的feature map再经过decoder的patch expanding layer(patch扩展层)恢复分辨率;同样skip connection用于多尺度的特征融合。
本文在多器官分割和心室分割任务中均取得了优异的性能,同时展现出较好的泛化能力。本文的工作总结如下:
(1)基于Swin Transformer模块,本文搭建了对称的UNeT具有Encoder-Decoder和Skip connection结构,Encoder中self-attention负责学习局部到全局的关联信息,decoder负责将全局特征上采样恢复至原始分辨率进行像素级别的预测;
(2)提出的patch expanding layer负责完成上采样操作,替代之前UNet中的反卷积或插值操作;
(3)实验发现skip connection同样适用于Transformer。
本文最终将基于纯Transformer搭建的U型网络称之为Swin-UNet.
CNN-based methods:
早期的医学图像分割主要基于传统的轮廓检测、机器学习方法。随着CNN的发展,UNet因其简洁的结构、良好的性能成为医学图像分割任务中的基础框架,并出现了一系列变体,如Res-UNet,Dense-UNet,UNet++;还提出了一系列针对3D图像的如3D-UNet,V-Net。
基于CNN强大的特征表述能力,基于CNN的方法在医学图像分割领域取得了巨大的成功。
Vision Transformers:
Transformer最初被用于NLP领域的机器翻译任务;ViT则是首个将Transformer用于视觉任务的先驱工作。与CNN系列方法相比,Transformer的缺点在于需要在大型数据集上预训练之后的结果。
Deit则介绍通过一些训练策略可以让ViT仅在中性数据集上预训练就能达到较好效果;此外基于ViT也在一些任务中达到了优异的性能,比如Swin-Transformer就是一种高效的层次化Transformer结构,其核心是shifted windows滑窗机制,在多种视觉任务上达到了SOTA。
本文则尝试将Swin Transformer模块作为基础结构来搭建UNet用来作为医学图像分割领域的benchmark.
Self-attention/Transformer:
近年来许多研究人员将自注意力机制引入CNN来提升网络的性能,比如提出将注意力门加入到U型网络中,但这还是基于CNN的网络模型;最近也有的工作尝试将CNN与Transformer融合,结合两类的特征提取能力用于各种视觉任务并且取得了不错的成果,如多模态的脑肿瘤分割和3D图像分割任务。
本文则是尝试使用纯Transformer来做医学图像分割任务。
Fig 1展示了Swin-UNet的示意图,包含encoder,decoder,skip connection三部分。
基础模块为Swin Transformer.
Encoder:
首先将输入切分成44大小互不重叠的patch,因此patch大小为44*3;随后度输入patch进行线性映射。
随后映射过的token被送入SwinTransformer模块和patch merge layer用来生成不同尺度的特征表述;Swin Transformer block负责学习特征,patch merge layer就负责下采样操作。
Decoder:
Decoder由多个Swin Transformer 模块和patch expanding layer组成,skip connection依旧负责特征融合,弥补原始信息的丢失,patch expanding layer则进行的是上采样操作,将改层feature map扩充至2x分辨率;最后一个patch expanding layer会扩充至4X分辨率;然后通过线性层 进行像素级别的预测。
Swin Transformer中的self-attention与传统的MSA结构并不一样,Swin Transformer主要基于shifted window.
Fig 2展示了Swin Transformer模块的结构示意图。
每一个Swin Trabsformer模块包含一个LN层、一个MSA模块、使用残差连接、然后在经过一次LN和2层的MLP,激活函数使用的是GELU。
SW-MAS(Shifted Window based Multi-head Self-Attention):Swin Transformer中连续的block会依次交替的使用W-MSA和SW-MAS,SW-MSA和W-MSA的不同之处在于会将window进行shift,这样可以让相邻的window之间产生交互。
Encoder中,输入的token大小为原始分辨率的1/4,即H/4 * W/4 * C的大小,token被送入两个连续的Swin Transformer模块进行特征提取,这一过程中维度保持不变;在patch merge layer负责下采样将token数目下降到输入的1/2,特征维度变厚为输入的2x。 这一过程会在Encoder中重复3次。
Patch merging layer
输入的patch被分成4部分然后通过patch merging layer层被级联,级联后特征维度变厚4倍,特征分辨率变为原来的1/2,再经过线性层将特征维度变为原来的2倍。
Transformer如果网络层次太深会导致无法收敛的问题,因此在Swin UNet的平静不分,仅使用两个Swin Transformer Block进行特征学习和整合,并不改变特征图谱的维度。
会对照Encoder的层次 对称的搭建Decoder部分,依旧基于Swin Transformer 模块,但是将patch merging layer替换为patch expanding layer执行上采样操作,patch expanding会将相邻尺寸的特征图reshape到更高分辨率,响应的将特征维度降到原始特征的一半。
Patch expanding layer
以第一个patch expanding 为例,输入特征维度为 (W/32 * H/32 * 8C),因为Encoder经历了3次patch merging,首先特征维度经过bottleneck后会变成(W/32 * H/32 * 16C);经过reshape操作后特征变薄,分辨率上升,变成(W/16 * H/16 * 4C)。
本文同原始UNet一样,也引入skip connection 以级联的方式来融合多尺度特征。
Synapse multi-organ segmentation dataset数据集
采集了30例病例的3779张临川CT图像,training : testing = 18:12.评价指标为DSC(Dice相似系数)和HD(Hausdorf 距离),对CT图像中的多个器官进行评价(主动脉、胆囊、左右肾脏、肝脏、胰腺、脾、胃部)
Automated cardiac diagnosis challenge dataset :
MRI心室分割数据集,训练:验证:测试=70:10:20,需要分割左心室、右心室和心肌,评价指标:DSC
Python 3.6
Pytorch 1.7.0
数据增强:flip rotations
输入图像:224 * 224
patch size:4
训练卡:V100
使用在ImageNet预训练后的结果
Synapse:
Table 1展示了CT图多器官分割的结果,对比网络有TransUNet,Att-UNet,UNet等。可以看到本文的Swin Unet获得了最佳的DSC和HD指标。
并且与Att-Unet,TransUNet相比,HD指标更低,说明对边缘分割效果更好。
Fig 3展示了部分分割结果,可以看到基于CNN的分割方法往往存在过度分割的问题,可能是由于CNN过度关注局部特征导致的。本文实验表明,通过基于Transformer搭建UNet可以更好的同时学习全局特征和语义信息之间的长程关联,从而提升分割效果。
ACDC:
Table 2展示了在ACDC数据集上的分割结果,可以看到Swin UNet达到了90%的DSC指标,显示出 SwinUNet超强的泛化性能和鲁棒性。
为了进一步探究Swin Unet中各部分的作用,本文进行了消融实验。
Up-sampling:
为了探究patch expanding layer的作用,本文将其与双线性插值、转置卷积进行了对比,Table 3展示了对比结果,可以看到使用patch expanding layer的Swin UNet获得了最佳的性能。
Skip connection:
本文还探究了skip connection的数目对性能的影响,分别在不同层级加入skip connection,发现随着skip connection的增多,性能也获得了提升,因此为了增加模型的鲁棒性,本文在所有层次都引入skip connection,对比实验结果参见Table 4.
Input size:
Table 5展示了不同输入尺寸对性能的影响,patch size依旧保持为4,则不同输入大小影响的是序列的长度;可以看出随着输入尺寸的增大,模型精度略有提高,但是整个网络负载也随之增大。因此综合考虑本文依旧采用224大小的输入。
Model Scale:
本文同样探讨了网络层次对性能的影响,对比结果参见Table 6,可以看到进一步增加网络层次几乎无法提升网络性能,但却使得计算成本显著增加。出于性能-网络效率的综合权衡,本文采用tiny版本的Swin UNet进行医学图像分割任务。
众所周知,Transformer的性能严重依赖于预训练的结果。而本文直接使用在ImageNet上训练后的权重来初始化我们的Swin UNet,这种初始化方案也许不是最优的。因此未来本文还将探索端到端的方法;此外本文的输入图像是二维图像,虽然大多数医学图像是二维的,但本文还将进一步探索Swin UNet在三维医学图像分割中的应用。
本文提出了一种全新的基于纯Transformer搭建的UNet,称之为Swin UNet,其中以Swin Transformer 模块为基本单元用来学习特征和长程语义关联。通过将Swin UNet用于多器官分割和心室分割任务展示了Swin UNet优秀的性能和泛化能力。