本文是 2021 年发表在 MICCAI 上的一篇文章,在当年的多项医学图像分割任务挑战中都获得了不错的成绩。本文主要介绍了这篇论文提出的 Medical Transformer 的结构及相关内容。
原论文链接:Medical Transformer: Gated Axial-Attention for Medical Image Segmentation
卷积架构存在着固有的归纳偏差(归纳偏差指的是神经网络模型会产生具有偏好的预测结果,也就是说归纳偏差会使得学习算法优先考虑具有某些特定属性的解),它们缺乏对图像中长程依赖性的理解。
文章提出了用 transformer 来做医学图像分割。要解决的问题是,transformer 在图像任务上相比卷积神经网络需要更大的数据集来训练,而医学图像处理的一个难题就是数据不足,数据集不够大。
本文主要的贡献:
在 ConvNets 中,每个卷积核只关注整个图像中像素的局部子集,并迫使网络关注局部模式而不是全局上下文。虽然后续提出了一些弥补的 trick 如图像金字塔、Atrus 卷积和注意机制等,仍然无法完全解决这个问题。
由于图像的背景是分散的,学习对应于背景的像素之间的 long-range dependencies 可以帮助网络防止将一个像素错误地归类为掩码,从而减少假阳性(将 0 视为背景,1 视为分割掩码)。同样,当分割遮罩很大时,学习遮罩对应的像素之间的长距离依赖关系也有助于进行有效预测。
数字图像处理中,分割掩码主要用于:
本文动机:
MedT 有两个分支结构,一个全局分支结构和一个局部分支结构,这两个分支的输入是从初始 conv 块提取的特征图。该块有 3 个 conv 层,每个 conv 层后面都有 batch normalization
和 ReLU
激活函数。
网络整体结构如下图所示,为两分支的 U-shape 结构,结构中的 Encoder 与 Decoder:
Transformer
层
Transformer
机制在 U-Net
结构的 Encoder 部分的 self-attention
机制上,并且不像其它 Transformer
用于 cv 的方法一样依赖于大数据集预训练的权重,本方法不需要预训练batch normalization
)和两层 multi-head attention block
,其中一层沿高度轴操作,另一层沿宽轴操作,每个 multi-head attention block
由提出的门控轴向注意层组成。
multi-head attention block
具有 8 8 8 个门控的轴向 multi-head
multi-head attention block
的输出通过另一个 1 × 1 1 \times 1 1×1 卷积层被添加到残差输入图中以产生输出注意图conv
块
ReLU
激活函数skip connections
具有高度 H H H、权重 W W W 和通道 $C_{in} $的输入特征映射 x ∈ R C i n × H × W x \in R^{C_{in} \times H \times W} x∈RCin×H×W借助投影输入,使用以下公式计算自注意力层的输出 y ∈ R C o u t × H × W y \in R^{C_{out} \times H \times W} y∈RCout×H×W:
参数含义:
自注意力机制的局限:
token
会对其他所有的每个 token
都计算注意力,所以是 ( h w ) 2 (hw)^2 (hw)2 次计算,这是非常庞大的计算量。有关 ViT 的其他内容可以参考我的另一篇 blog:CV-Model【6】:Vision Transformerposition embedding
,就是用一个 onehot
的位置向量,经过一个全连接的 embedding
,产生位置编码,这个全连接是可训练的
加上轴向注意力和多个位置编码的 trick 后,注意力机制如下所示(文章中给出的是宽度方向 w w w 上的注意里,高度方向 h h h 上的注意力类似):
参数含义:
高度方向同理
然而上述的 trick 需要大量数据集进行训练,小量的数据不足以训练 QKV 的三个 position embedding
,而医学数据集多数情况下就是少量的
在这种情况下,不准确的 position embedding
会给网络准确率带来负面影响,为此文章提出了个方法用来控制这个影响的程度,修改上述公式如下:
这里三个 G Q , G K , G V G_Q, G_K, G_V GQ,GK,GV 都是可学习的参数,当数据集不足以使得网络预测准确的 position embedding
时,网络的 G
会小一点,反之会大一点,因此起到一个所谓的 Gated 的作用。position embedding
只要相对位置一样,对不同的样本应该是一样的,因为 position embedding
只是位置信息,没有包含语义信息
通常,如果一个相对位置编码被准确学习,相对于那些没有被准确学习的编码,门控机制会赋予它较高的权重。
Transformer 做图像分割可以用 patch-wise
的方式去做,也就是说把一张完整图片切割成多个 patch
,每个 patch
和这个 patch
对应的 mask
作为一个样本,用来训练 transformer,这样十分快
然而问题在于,一张图片的一个病灶可能比一个 patch
大,这样的话这个 patch
看起来就会很奇怪,因为被病灶充满了。这限制了网络学习 patch
间像素的任何信息或依赖性
Local-Global 这个部分的思路有点像多尺度的一个思考,他将网络分成了两个分支(branch):
patch
,每个 patch
单独送 transformer block 前向传播,patch
和 patch
之间没有任何联系,最后再把这 4 × 4 4 \times 4 4×4 个 patch
的 feature map
通过 concat
操作拼接在一起
patch
通过网络进行前向反馈,并根据其位置对输出特征图进行重新采样,以获得输出特征图将两个分支的输出特征图相加并通过 1 × 1 1 \times 1 1×1 卷积层以产生输出分割掩码。这种在图像的全局上下文上操作较浅的模型和在 patch
上操作较深的模型的策略提高了性能,因为全局分支关注的是高层信息,而局部分支关注的是更精细的细节
MedT 使用预测和 ground truth 之间的二进制交叉熵 (CE) 损失来训练网络
参数含义:
作者探索使用基于 transformer 的编码器架构来分割医学图像,而不需要任何预训练: