【图像分割】TransUNet学习笔记

论文名称:TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation
论文地址:https://arxiv.org/pdf/2102.04306.pdf
代码地址:https://github.com/Beckschen/TransUNet


前言:

TransUNet将Transformer和U-Net结合了起来。由于卷积操作本身存在的局限性,U-Net不能很好地建模长距离依赖关系,而Transformer这种全局自注意力机制可以有效地获取全局信息,但其对于低层次细节信息获取不充分,导致其定位能力方面受到限制。所以作者将这两者结合起来,提出了TransUNet网络。同时为了能将高分辨率的特征图通过跳级连接与上采样后的特征图联合以获得充分的信息,作者没有像ViT那样将原图直接打成patch块输入到Transformer模块中,而是先通过CNN进行特征提取得到特征图,再将其变换后输入Transformer编码器模块,最后仿照U-Net解码器逐级上采样并进行跳级连接,最后得到分割结果。【图像分割】TransUNet学习笔记_第1张图片


总体结构:

TransUNet总体上还是一个U型的Encoder-Decoder结构。在编码器部分,将原图输入CNN进行特征提取,线性投影之后进行Patch Embedding将特征图序列化并加上位置编码,输入transformer编码器。在解码器部分,将编码器输出的序列进行reshape然后通过1x1卷积变换通道数之后进行级联上采样,中途通过编码器CNN的各级分辨率特征图进行跳级连接,最后得到分割结果,这部分与U-Net解码器类似。

在论文中作者说道刚开始是直接应用Transformer编码器对原图进行编码,然后将输出的特征图直接上采样到原分辨率,但效果并不是最好的。作者分析输入编码器的\frac{H}{P}\times \frac{H}{P}分辨率对于原图HxW的分辨率来说还是太小了,导致损失了一些低层次的细节信息,比如边界信息。因此作者应用了一个联合CNN-Transformer的结构作为编码器,并在解码器中加入可以获得精确位置信息的级联上采样操作。

作者选用CNN-Transformer这一混合结构设计的原因有两点:1)为了将中间高分辨率的CNN特征图加入到解码器路径中以获得更多的信息以及更精确的位置。2)作者发现使用CNN-Transformer编码器要比但纯的Transformer编码器效果要好。(个人不理解的地方:TransUNet和SETR基本都是受ViT启发,对比了 CNN-Transformer Hybrid 和pure Transformer,为什么TransUNet说混合模型更好而SETR说纯Transformer模型更好呢?)
 


编码器:

ViT+ResNet50

这里的ResNet50与原ResNet50有些不同,首先卷积采用的是StdConv2d而不是传统的Conv2d,然后是用GroupNorm层代替了原来的BatchNorm层,然后我在代码中看到BottleNeck层也变成了PreActivation版本,也就是将ReLU和Normalization层前置了。在原Resnet50网络中,stage1有3个重复堆叠的Block,stage2中是4个,stage3中是6个,stage4中是3个,但在这里的ResNet50中,把stage4中的3个Block移至stage3中,所以stage3中有9个重复堆叠的Block。还有原ResNet50输出的特征图分辨率是从224x224降低到了7x7,输出为原图的1/32,而TransUNet中输出特征图分辨率为14x14,只为原图的1/16,应该是在将stage4中的3个Block拿到stage3中的时候将stride=2改成了stride=1以去掉降采样操作。

【图像分割】TransUNet学习笔记_第2张图片

维度变化:经过Stem,分辨率变为原图1/4,[224, 224, 3]-->[56, 56, 64]。经过Stage1,分辨率不变,[56, 56, 64]-->[56, 56, 256]。经过Stage2,变为原图1/8,[56, 56, 256]-->[28, 28, 512]。经过Stage3,变为原图1/16,[28, 28, 512]-->[14, 14, 1024]。然后通过一个1x1的卷积缩减维度后进行序列化输入Transformer,[14, 14, 1024]-->[14, 14, 768]-->[196, 768],这里的维度就是Transformer需要的序列的维度,也就是论文中由[H, W, 3]变为了[\frac{HW}{P^{2}}, P^{2}\cdot C]。N= \frac{HW}{P^{2}} = 196 即为序列的长度。

最后经过Patch Embedding和位置编码,经过Patch Embedding后,[196, 768] x [768, 768]-->[196, 768],也即[\frac{HW}{P^{2}}, P^{2}\cdot C] x [P^{2}\cdot C, D]-->[\frac{HW}{P^{2}}, D]。

【图像分割】TransUNet学习笔记_第3张图片

这里的E\in R^{(P^{2}\cdot C)\times D}是一个线性投影,目的是将序列映射到[N, D]。

更多Transformer编码器的细节后续再记录。


解码器:

编码器输出的特征图为Z_{L}\in R^{\frac{HW}{P^{2}}\times D},将其进行Reshape,然后应用1x1卷积进行维度缩减。[196, 768]-->[14, 14, 768]-->[14, 14, 512],也即论文图中的[N, D]-->[\frac{H}{P}, \frac{H}{P},D]-->[\frac{H}{P}\frac{H}{P}, 512]。后续操作和U-Net几乎相同。

,  


整体结构和代码对应图,感谢原作者!

【图像分割】TransUNet学习笔记_第4张图片


由于本人水平非常有限,如有错误,恳请指正,欢迎大家一起交流学习!


参考:

(基础)CNN网络结构_Chan_Zeng的博客-CSDN博客

Vision Transformer详解_霹雳吧啦Wz-CSDN博客_wz框架

TransUnet: 结构解析_ripple970227的博客-CSDN博客_unet结构详解

TransUNet_zjiafbaodaozmj的博客-CSDN博客_transunet

你可能感兴趣的:(深度学习,pytorch)