MedT: Medical Transformer 论文阅读 MICCAI2021

MedT: Medical Transformer 论文阅读 MICCAI2021

Tranformer医学图像分割


题目:《Medical Transformer: Gated Axial-Attention for Medical Image Segmentation》
论文: https://arxiv.org/abs/2102.10662.
代码: https://github.com/jeya-maria-jose/Medical-Transformer.
期刊: MICCAI2021


这是一篇发表在MICCAI2021上利用Transformer做医学图像分割的文章。该文主要介绍一下文章的创新动机以及网络结构,便于理解, 实验部分不予以介绍。

一: Motivation

  1. 由于传统的CNN网络中的卷积层缺乏对图像中远程依赖关系的建模能力(即使使用不断的池化层能够提高感受野,但是会引起大量的结构损失),故文章选用了最近大火的Transformer结构来完成对远程依赖关系的建模。
  2. 由于数据量不足是医学图像分割项目的一个很大的局限性,而Transformer却需要大量的数据集进行训练拟合,故文章提出了gated axial attention 结构
  3. 同时为了提高Transformer的性能,文章又提出了局部-全局训练策略(LoGo)

二: Medical Transformer (MedT)

先给出文章中最核心的图,其中图a为对应的MedT结构(使用LoGo训练策略),图b为MedT中编码器使用的Gated Axial Transformer layer, 图c展示的Gated Axial Attention layer为图b中的Gated Multi-head Attn Height& Gated Multi-head Attn Width的主要构成模块。
MedT: Medical Transformer 论文阅读 MICCAI2021_第1张图片
文章在网络模块主要有3个小节(Self-Attention Overview,Gated Axial-Attention ,Local-Global Training)对该图进行说明。

1. Self-Attention Overview

文章先是介绍了self-attention模块,这是最一般的计算公式。qkv都是通过输入x经过不同的W计算而得
请添加图片描述
根据该算式,可以发现所需要的计算量非常大,且没有引入位置信息(文章好像分析的时候默认一般的self attention没有入postion embedding)。

同时本节介绍了由Wang et al.[1]提出的Axial-Attention,可以降低上面提到的计算复杂度。Axial-Attention将原始self-attention分解为两个self-attention模块,第一个模块为在高这个轴上单独计算self-attention,第二个模块为在宽这个轴上单独计算self-attention,称为axial attention (轴向注意力)。同时为了在计算时添加位置信息,作者又添加了一个位置偏置(positional bias),该偏置项常被称为相对位置编码,是可以在训练过程中进行学习的。最终的计算公式为下式(下式为width-axis的公式),其中rq, rk, rv是位置偏置。
请添加图片描述

2. Gated Axial-Attention

文章说到在第一节中介绍了Aial-Attention可以很好的得到非局部的信息,且计算比较高效,但是提到了[1]文章是在大尺度的分割数据中进行计算的,所以能够非常好的学习到kqv中的位置偏差,如果使用低尺度的医学图像进行分割,通常难以学习到准确的位置信息,添加位置偏置反而造成大量误差,故提出Gated Axial- Attention,公式为下式:
请添加图片描述
其中GQ, GK, GV1, GV2是所谓的门控(Gated),是一个学习的参数。Gated的作用就是控制位置信息r的权重,如果位置信息r准确,就给予较大权重,如果不准确,就给予较低的权重。
MedT: Medical Transformer 论文阅读 MICCAI2021_第2张图片

3. Local-Global Training

这部分在结构图种展现在输入的双分支部分。
作者想利用将图片根据Patch的方式进行输入网络的方式来提高transformer的运算速度,但是提到仅仅是patch-wise训练针对医学图像分割并不是一个非常有效的方法,因为Patch-wise的训练方法限制了网络学习不同patch之间的信息相关。为了提高网络对整个图片的信息提取,作者提出了在网络里利用了双分支,一个为global branch处理全局信息, 一个为local branch处理每一个patch信息。其中Global branch减少了部分编解码模块(减少运算时间,且前面几层已经能够提取到全局信了),local branch 则将输入长宽各分了4份,分成16个patch进行运算,然后将出来的所有patch进行reshape。最后将两个branch进行element-wise add后经过一个1*1的卷积层输出。该策略旨在提高性能(此时global branch处理全局信息,local branch处理局部信息,类似多尺度分析),减少运算时间(global branch减少了编解码层,local branch分patch训练)

总结

Transformer真牛逼,快去用来发论文


你可能感兴趣的:(医学图像分割,transformer,深度学习,计算机视觉,神经网络)