注意力机制可以说是深度学习研究领域上的一个热门领域,它在很多模型上都有着不错的表现,比如说BERT模型中的自注意力机制。本博客仅作为本人在看了一些Attention UNet相关文章后所作的笔记,希望能给各位带来一点思考,注意力机制是怎么被应用在医学图像分割的。
参考文章:
UNet是一个用于分割领域的架构,自2015年被提出以来,在医学图像领域取得了不错的表现,成为了不少医疗影像语义分割任务的baseline。感兴趣的可以去看一下这一篇博客:Unet神经网络为什么会在医学图像分割表现好?
UNet的网络结构并不复杂,最主要的特点便是U型结构和skip-connection。而Attention UNet则是使用了标准的UNet的网络架构,并在这基础上整合进去了Attention机制。更准确来说,是将Attention机制整合进了跳远连接(skip-connection)。
整个网络架构如下, 注意力block已用红色框出:
与标准的UNet相比,整体结构是很相似的,唯一不同的是在红框内增加了注意力门。为了公式化这个过程,我们将跳远连接的输入称为x,来自前一个block的输入称为g,那么整个模块就可以用以下公式来表示了:
在这个公式里面,Attention就是注意力门,upsample是一个简单上采样模块,采用最近邻插值,而ConvBlock只是由两个(convolution + batch norm + ReLU)块组成的序列。唯一需要解释的是注意力。
接下来让我们看一下整个注意力门是怎么实现的,整个结构图如下:
整个过程不难理解 ,需要注意一下几点:
下面的代码定义了注意力块(简化版)和用于UNet扩展路径的“up-block”。“down-block”与原UNet一样。
class AttentionBlock(nn.Module):
def __init__(self, in_channels_x, in_channels_g, int_channels):
super(AttentionBlock, self).__init__()
self.Wx = nn.Sequential(nn.Conv2d(in_channels_x, int_channels, kernel_size = 1),
nn.BatchNorm2d(int_channels))
self.Wg = nn.Sequential(nn.Conv2d(in_channels_g, int_channels, kernel_size = 1),
nn.BatchNorm2d(int_channels))
self.psi = nn.Sequential(nn.Conv2d(int_channels, 1, kernel_size = 1),
nn.BatchNorm2d(1),
nn.Sigmoid())
def forward(self, x, g):
# apply the Wx to the skip connection
x1 = self.Wx(x)
# after applying Wg to the input, upsample to the size of the skip connection
g1 = nn.functional.interpolate(self.Wg(g), x1.shape[2:], mode = 'bilinear', align_corners = False)
out = self.psi(nn.ReLU()(x1 + g1))
return out*x
class AttentionUpBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(AttentionUpBlock, self).__init__()
self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 2, stride = 2)
self.attention = AttentionBlock(out_channels, in_channels, int(out_channels / 2))
self.conv_bn1 = ConvBatchNorm(in_channels+out_channels, out_channels)
self.conv_bn2 = ConvBatchNorm(out_channels, out_channels)
def forward(self, x, x_skip):
# note : x_skip is the skip connection and x is the input from the previous block
# apply the attention block to the skip connection, using x as context
x_attention = self.attention(x_skip, x)
# upsample x to have th same size as the attention map
x = nn.functional.interpolate(x, x_skip.shape[2:], mode = 'bilinear', align_corners = False)
# stack their channels to feed to both convolution blocks
x = torch.cat((x_attention, x), dim = 1)
x = self.conv_bn1(x)
return self.conv_bn2(x)
整个网络架构完整版实现可以参考 【语义分割系列:七】Attention Unet 论文阅读翻译笔记 医学图像 python实现。