I previously read a paper that used dual attention for natural-image segmentation [1], and later noticed the authors released an improved version called Dual Relation-Aware Attention [2]. It mainly addresses the high computation and memory cost of dual attention (I also kept hitting OOM whenever I ran dual attention). I took a shot at implementing it myself; accuracy is not guaranteed, so discussion and corrections are welcome ^_=
[1] Dual Attention Network for Scene Segmentation
Paper: https://arxiv.org/abs/1809.02983
[2] Scene Segmentation With Dual Relation-Aware Attention Network
Paper: https://ieeexplore.ieee.org/abstract/document/9154612/
Code (PyTorch): https://github.com/junfu1115/DANet (the authors keep the code for both papers in one repository)
Dual Attention's position attention module (PAM) is implemented as an inner product over the feature map (essentially self-attention: matrix multiplication models the global relations between pixels), but when the feature map is large this demands substantial GPU memory and compute. The authors therefore propose a compact position attention module (CPAM): pyramid pooling (pooling at several output sizes) models the relation between each pixel and a small set of gathering centers, and the pooled features are concatenated and used in the self-attention inner product, which cuts both computation and memory considerably.
The authors also note that the channel attention module (CAM) becomes expensive when the number of feature maps (channels) is large. To address this, they propose a compact channel attention module (CCAM), which models the relation between each channel map and a set of channel gathering centers: the input feature is first reduced with a 1x1 convolution, and self-attention is then computed against the reduced feature.
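To make the savings concrete, here is a rough, back-of-the-envelope count of attention-map entries (my own numbers, assuming a 512-channel feature map at 64x64 resolution, and the C/8 channel reduction used in ccam_enc below):

H = W = 64; C = 512
N = H * W                     # 4096 pixels
S = 1*1 + 2*2 + 3*3 + 6*6     # 50 gathering centers from the pooling pyramid
print(N * N)        # PAM  attention: 4096 x 4096 = 16,777,216 entries
print(N * S)        # CPAM attention: 4096 x 50   =    204,800 entries
print(C * C)        # CAM  attention: 512 x 512   =    262,144 entries
print(C * (C//8))   # CCAM attention: 512 x 64    =     32,768 entries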
As you can see, the compact dual attention uses far less compute and memory (on my own network this shows up as finally not hitting OOM; on my few-shot medical image segmentation task I didn't observe an accuracy gain, though). CAM and CCAM give roughly the same gains, so I won't reproduce those numbers. PAM vs. CPAM is worth noting: although the two improve overall accuracy by the same amount, the authors find that PAM helps more on small objects while CPAM helps more on large ones (Table 3 of [2]). My guess is that pooling discards fine pixel-to-pixel relations, which limits small-object recognition.
The authors also compare ways of wiring CPAM and CCAM together, and find that computing CPAM and CCAM on the input feature in parallel and then concatenating the outputs works best; the dra_attention function below follows this scheme.
The paper additionally proposes a cross-level gating decoder, which I haven't studied carefully yet.
import torch
from torch.nn import (Module, Sequential, Conv2d, ReLU,
                      AdaptiveAvgPool2d, Linear, Softmax, Parameter)

class CPAMEnc(Module):
    """
    CPAM encoding module
    """
    def __init__(self, in_channels, norm_layer):
        super(CPAMEnc, self).__init__()
        # pooling pyramid: 1x1 + 2x2 + 3x3 + 6x6 = 50 gathering centers
        self.pool1 = AdaptiveAvgPool2d(1)
        self.pool2 = AdaptiveAvgPool2d(2)
        self.pool3 = AdaptiveAvgPool2d(3)
        self.pool4 = AdaptiveAvgPool2d(6)
        self.conv1 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels),
                                ReLU(True))
        self.conv2 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels),
                                ReLU(True))
        self.conv3 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels),
                                ReLU(True))
        self.conv4 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels),
                                ReLU(True))

    def forward(self, x):
        b, c, h, w = x.size()
        feat1 = self.conv1(self.pool1(x)).view(b, c, -1)
        feat2 = self.conv2(self.pool2(x)).view(b, c, -1)
        feat3 = self.conv3(self.pool3(x)).view(b, c, -1)
        feat4 = self.conv4(self.pool4(x)).view(b, c, -1)
        return torch.cat((feat1, feat2, feat3, feat4), 2)  # (B, C, 50)
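A quick shape check of the encoder (my own sketch, not from the repo): the pyramid yields 1 + 4 + 9 + 36 = 50 descriptors per channel.

import torch
from torch.nn import BatchNorm2d

enc = CPAMEnc(in_channels=512, norm_layer=BatchNorm2d)
x = torch.randn(2, 512, 64, 64)
print(enc(x).shape)  # torch.Size([2, 512, 50])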
class CPAMDec(Module):
    """
    CPAM decoding module
    """
    def __init__(self, in_channels):
        super(CPAMDec, self).__init__()
        self.softmax = Softmax(dim=-1)
        self.scale = Parameter(torch.zeros(1))  # learnable residual weight, initialized to 0
        self.conv_query = Conv2d(in_channels=in_channels, out_channels=in_channels//4, kernel_size=1)
        self.conv_key = Linear(in_channels, in_channels//4)
        self.conv_value = Linear(in_channels, in_channels)

    def forward(self, x, y):
        """
        inputs :
            x : input feature (N, C, H, W)
            y : gathering centers (N, K, M) -- the CPAMEnc output permuted
                to (B, 50, C), so K = 50 centers and M = C
        returns :
            out : compact position attention feature
            attention map : (H*W) x K
        """
        m_batchsize, C, height, width = x.size()
        m_batchsize, K, M = y.size()

        proj_query = self.conv_query(x).view(m_batchsize, -1, height*width).permute(0, 2, 1)  # B x N x d
        proj_key = self.conv_key(y).view(m_batchsize, K, -1).permute(0, 2, 1)                 # B x d x K
        energy = torch.bmm(proj_query, proj_key)                                              # B x N x K
        attention = self.softmax(energy)                                                      # softmax over the K centers
        proj_value = self.conv_value(y).permute(0, 2, 1)                                      # B x C x K
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))                               # B x C x N
        out = out.view(m_batchsize, C, height, width)
        out = self.scale*out + x  # residual connection
        return out
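Continuing the sketch above: in the repo's attention head the encoder output is permuted before decoding, which produces the (N, K, M) = (B, 50, C) layout the docstring expects.

dec = CPAMDec(in_channels=512)
y = enc(x).permute(0, 2, 1)   # (2, 512, 50) -> (2, 50, 512), centers first
print(dec(x, y).shape)        # torch.Size([2, 512, 64, 64])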
class CCAMDec(Module):
    """
    CCAM decoding module
    """
    def __init__(self):
        super(CCAMDec, self).__init__()
        self.softmax = Softmax(dim=-1)
        self.scale = Parameter(torch.zeros(1))  # learnable residual weight, initialized to 0

    def forward(self, x, y):
        """
        inputs :
            x : input feature (N, C, H, W)
            y : gathering centers (N, K, H, W), a channel-reduced copy of x
        returns :
            out : compact channel attention feature
            attention map : C x K
        """
        m_batchsize, C, height, width = x.size()
        x_reshape = x.view(m_batchsize, C, -1)

        B, K, H, W = y.size()
        y_reshape = y.view(B, K, -1)
        proj_query = x_reshape                    # B x C x N
        proj_key = y_reshape.permute(0, 2, 1)     # B x N x K
        energy = torch.bmm(proj_query, proj_key)  # B x C x K
        # same max-subtraction trick as CAM before the softmax
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = y.view(B, K, -1)             # B x K x N
        out = torch.bmm(attention, proj_value)    # B x C x N
        out = out.view(m_batchsize, C, height, width)
        out = x + self.scale*out  # residual connection
        return out
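A corresponding sketch for CCAMDec, continuing the same example. Here y is a channel-reduced copy of x; the 1x1 conv to C/8 channels mirrors the TensorFlow ccam_enc below, and the exact reduction ratio is my assumption.

reduce = Conv2d(512, 512//8, kernel_size=1)  # hypothetical CCAM encoder
ccam = CCAMDec()
print(ccam(x, reduce(x)).shape)  # torch.Size([2, 512, 64, 64])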
I used TensorFlow 2.8.0; version 2.5.0 or later is recommended.
Adaptive average pooling (AdaptiveAveragePooling2D) requires installing tensorflow_addons (alternatively, you can compute the pooling scales yourself and use ordinary pooling, as sketched below).
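For the fixed-input-size case, a minimal stand-in for the tfa layer (my own sketch, assuming H and W are divisible by the target grid size s) looks like this:

import tensorflow as tf
from keras.layers import AveragePooling2D

def adaptive_avg_pool(x, s):
    '''Plain-pooling replacement for AdaptiveAveragePooling2D(output_size=(s, s)).'''
    _, h, w, _ = x.shape
    assert h % s == 0 and w % s == 0, 'this fallback assumes divisible spatial dims'
    return AveragePooling2D(pool_size=(h // s, w // s))(x)  # stride defaults to pool_size

print(adaptive_avg_pool(tf.random.normal((2, 36, 36, 8)), 6).shape)  # (2, 6, 6, 8)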
# DRANet
import tensorflow as tf
import numpy as np
from keras.layers import *
import tensorflow_addons as tfa
def conv_norm_act(input_tensor, filters, kernel_size, dilation=1, norm_type='batch', act_type='relu'):
    '''
    Conv2D + Normalization (norm_type: str) + Activation (act_type: str)
    '''
    output_tensor = Conv2D(filters, kernel_size, padding='same',
                           dilation_rate=(dilation, dilation),
                           use_bias=norm_type is None,
                           kernel_initializer='he_normal')(input_tensor)
    if norm_type == 'batch':  # only batch norm is handled in this sketch
        output_tensor = BatchNormalization()(output_tensor)
    output_tensor = Activation(act_type)(output_tensor)
    return output_tensor
# channels-last inputs only
def cpam_enc(x):
    '''x: input tensor with shape [B, H, W, C]'''
    b, h, w, c = x.shape
    feat1 = tfa.layers.AdaptiveAveragePooling2D(output_size=(1, 1))(x)
    feat2 = tfa.layers.AdaptiveAveragePooling2D(output_size=(2, 2))(x)
    feat3 = tfa.layers.AdaptiveAveragePooling2D(output_size=(3, 3))(x)
    feat4 = tfa.layers.AdaptiveAveragePooling2D(output_size=(6, 6))(x)
    # norm_type/act_type are passed as keywords so they do not land on
    # the positional dilation argument of conv_norm_act
    feat1 = tf.reshape(tf.transpose(conv_norm_act(feat1, c, 1, norm_type='batch', act_type='relu'), [0, 3, 1, 2]), (-1, c, 1))
    feat2 = tf.reshape(tf.transpose(conv_norm_act(feat2, c, 1, norm_type='batch', act_type='relu'), [0, 3, 1, 2]), (-1, c, 4))
    feat3 = tf.reshape(tf.transpose(conv_norm_act(feat3, c, 1, norm_type='batch', act_type='relu'), [0, 3, 1, 2]), (-1, c, 9))
    feat4 = tf.reshape(tf.transpose(conv_norm_act(feat4, c, 1, norm_type='batch', act_type='relu'), [0, 3, 1, 2]), (-1, c, 36))
    return concatenate([feat1, feat2, feat3, feat4], 2)  # (B, C, 50)
def cpam_dec(x, y):
    '''
    inputs :
        x : input feature (N, H, W, C)
        y : gathering centers (N, K, M) -- the cpam_enc output transposed
            to (B, 50, C), so the Dense layers act on the channel axis
    returns :
        out : compact position attention feature
        attention map: (H*W) x K
    '''
    b, h, w, c = x.shape
    b, k, m = y.shape
    # the PyTorch version initializes scale to zero; also note that a bare
    # tf.Variable created here is not tracked as a Keras layer weight, so it
    # will not be trained unless wrapped in a custom Layer
    # scale = tf.Variable(tf.zeros(1))
    scale = tf.Variable(tf.ones(1))
    proj_query = Conv2D(c//4, 1)(x)                                                # (B, H, W, C/4)
    proj_query = tf.transpose(proj_query, [0, 3, 1, 2])
    proj_query = tf.transpose(tf.reshape(proj_query, (-1, c//4, h*w)), [0, 2, 1])  # (B, HW, C/4)
    proj_key = Dense(c//4)(y)                                                      # (B, K, C/4)
    proj_key = tf.transpose(tf.reshape(proj_key, (-1, k, c//4)), [0, 2, 1])        # (B, C/4, K)
    energy = tf.matmul(proj_query, proj_key)                                       # (B, HW, K)
    attention = tf.nn.softmax(energy)                                              # softmax over the K centers
    proj_value = tf.transpose(Dense(c)(y), [0, 2, 1])                              # (B, C, K)
    out = tf.matmul(proj_value, tf.transpose(attention, [0, 2, 1]))                # (B, C, HW)
    out = tf.reshape(out, (-1, c, h, w))
    out = tf.transpose(out, [0, 2, 3, 1])                                          # back to channels-last
    out = out * scale + x
    return out
def ccam_enc(x):
    '''Reduce channels to C/8 with a 1x1 conv, then move to channels-first: (B, K, H, W)'''
    b, h, w, c = x.shape
    x = conv_norm_act(x, c//8, 1, norm_type='batch', act_type='relu')
    x = tf.transpose(x, [0, 3, 1, 2])
    return x
def ccam_dec(x, y):
    '''
    inputs:
        x : input feature (N, H, W, C)
        y : gathering centers (N, K, H, W), the channels-first ccam_enc output
    returns :
        out : compact channel attention feature
        attention map: C x K
    '''
    m_batchsize, height, width, c = x.shape
    x = tf.transpose(x, [0, 3, 1, 2])                 # to channels-first for the matmuls below
    x_reshape = tf.reshape(x, (-1, c, height*width))
    # as in cpam_dec: PyTorch initializes scale to zero, and a bare
    # tf.Variable is not tracked as a Keras layer weight
    # scale = tf.Variable(tf.zeros(1))
    scale = tf.Variable(tf.ones(1))
    b, k, h, w = y.shape
    y_reshape = tf.reshape(y, (-1, k, h*w))
    proj_query = x_reshape                            # (B, C, N)
    proj_key = tf.transpose(y_reshape, [0, 2, 1])     # (B, N, K)
    energy = tf.matmul(proj_query, proj_key)          # (B, C, K)
    # same max-subtraction trick as the PyTorch CCAMDec
    energy_new = tf.reduce_max(energy, -1, keepdims=True)
    energy_new = tf.repeat(energy_new, energy.shape[-1], -1)
    energy_new = energy_new - energy
    attention = tf.nn.softmax(energy_new)
    proj_value = tf.reshape(y, (-1, k, h*w))          # (B, K, N)
    out = tf.matmul(attention, proj_value)            # (B, C, N)
    out = tf.reshape(out, (-1, c, height, width))
    out = x + scale * out
    out = tf.transpose(out, [0, 2, 3, 1])             # back to channels-last
    return out
def dra_attention(x, filters):
    # cpam_enc returns (B, C, 50); transpose to (B, 50, C) before decoding,
    # matching the .permute(0, 2, 1) applied in the PyTorch head
    y1 = tf.transpose(cpam_enc(x), [0, 2, 1])
    y2 = ccam_enc(x)
    att1 = cpam_dec(x, y1)
    att2 = ccam_dec(x, y2)
    # parallel CPAM/CCAM branches fused by concatenation, as in the paper
    att = concatenate([att1, att2], -1)  # channels-last
    att = conv_norm_act(att, filters, 1, norm_type='batch', act_type='relu')
    return att
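A hypothetical smoke test, run eagerly on a random channels-last tensor (the shapes are my own choice):

x = tf.random.normal((2, 32, 32, 64))
print(dra_attention(x, filters=64).shape)  # (2, 32, 32, 64)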
Corrections and bug reports are welcome TAT