METD-Medical Transformer:用于医学图像分割的门控轴向注意力Transformer

目录

Title:Medical Transformer: Gated Axial-Attention for Medical Image Segmentation

摘要-Abstract

本文的提出动机

方法-Method

网络整体结构图

自注意力机制回顾

Axial Attention(轴向注意力机制)

门控轴向注意力

局部-全局训练策略


Title:Medical Transformer: Gated Axial-Attention for Medical Image Segmentation

摘要-Abstract

卷积架构存在着固有的归纳偏差(归纳偏差指的是神经网络模型会产生具有偏好的预测结果,也就是说归纳偏差会使得学习算法优先考虑具有某些特定属性的解)它们缺乏对图像中长程依赖性的理解。

本文的提出动机

1)传统的CNN的卷积层缺乏对图像中远程依赖关系的建模能力(即使使用不断的池化层能够提高感受野,但是会引起大量的结构损失)。而Transformer在捕获长程依赖关系方面具有良好的性能。

2)由于带标注的医学数据稀缺是一个瓶颈问题,而Transformer结构往往有需要大量的数据才能够取得较好的性能,所以本文提出了Gated Axial Attention结构来考虑解决这个问题。(主要通过在自注意模块中引入额外的控制机制来扩展现有架构

3)另外为了提高Transform的性能,文章提出了局部-全局的训练策略(具体来说,我们对整个图像和各个Patches进行操作以分别学习全局和局部特征)

方法-Method

网络整体结构图

METD-Medical Transformer:用于医学图像分割的门控轴向注意力Transformer_第1张图片

METD-Medical Transformer:用于医学图像分割的门控轴向注意力Transformer_第2张图片

自注意力机制回顾

考虑一个输入特征图x\epsilon R^{^{c^{in}*H*W}},自注意力层的输出y\epsilon R^{c_{out}*H*W}是使用以下等式进行计算的

其中q,k,v都是从x输入的计算投影,与卷积不同自注意力机制能够从整个特征图中捕获非局部信息。但是这种对于相似度的计算的计算量非常之大。

为了克服这种相似性度量计算量非常大的弊端,本文引入了

Axial Attention(轴向注意力机制)

这个轴向注意力机制将原始的自注意力模块分解为两个自注意模块,第一个模块在高这个轴上单独计算自注意力,第二个模块在宽这个轴上单独计算注意力,同时为了在计算时添加位置信息,作者又添加了一个位置偏置,这个偏置通常被称为相对位置编码,是可以在训练过程中进行学习的。对于任何给定的特定的输入特征图x,具有位置编码和宽度轴的更新自注意力机制可以写成:

其中rq, rk, rv是位置偏置。

门控轴向注意力

 上面提到的轴向注意力机制可以很好的得到非局部的信息,并且计算比较高效,但是这种方法是在大尺度的分割数据上进行计算的,所以能够非常好的学习到q,k,v中的位置偏差。如果使用低尺度的医学图像进行分割,通常难以学习到准确的位置信息,添加位置偏差可能会造成大量误差,所以在本小节提出了门控轴向注意力。

其结构示意图如下

METD-Medical Transformer:用于医学图像分割的门控轴向注意力Transformer_第3张图片

具体过程可以总结为:首先输入X通过三个参数矩阵分别得到对应的q,k,v。然后q,k进行矩阵相乘进行相似度的计算。然后q,k分别和位置编码进行矩阵相乘,将这两个个所得的结果再与门控单元进行矩阵乘法,最终将这三个量的结果相加起来得到第一阶段的结果。

将第一阶段的结果送入Softmax进行处理得到一个结果,将这个结果与位置编码和最开始的k值进行矩阵乘法得到两个结果,这两个结果再与门控单元做矩阵乘法得到的两个结果再相加得到门控注意力机制模块的最终输出。

门控Gated的作用就是控制位置信息r的权重如果位置信息r准确,就给予较大的权重,如果不准确则给予较低的权重。

局部-全局训练策略

 这一部分在网络结构图中体现在输入的双分支部分。

作者想利用将图片根据patches的方式进行输入网络的方式来提高Transformer的运行速度。但仅仅是patch-wise训练针对医学图像的分割并不是一个有效的方法。因为这种方法限制了网络学习不同patches之间的信息相关

为了提高网络对于整个图片的信息提取,作者在网络中使用了双分支,一个全局分支处理全局信息,一个局部分支处理patch信息。其中全局分支减少了部分编码器模块,局部分支将输入长宽分为4份,分成了16个patch进行运算。然后将输出的所有patch进行reshape。最后将两个branches进行元素相加后经过一个1*1卷积层输出

这个策略旨在提高性能,减少运算时间。

你可能感兴趣的:(transformer,深度学习,计算机视觉)