描述
注意力机制(Attention)被应用于医学图像分割领域,以提高神经网络对空间特征及通道特征的关注度,从而提高分割任务的精度。本文根据文末参考文献,实现一种将 Spatial Attention 与 Channel Attention 融合使用的注意力层。
代码展示
1.Spatial Attention Layer
class SpatialAttention(Layer):
    """Gate a low-level feature map with a learned spatial mask.

    The mask is computed from both inputs inside an internal functional
    ``Model`` (``self.SpatialAtten``):

    1. the low-level feature is projected to ``feature_dim_high`` channels
       by a stride-2 convolution (the third axis is left untouched in 3D),
    2. the high-level feature is projected by a 1x1 convolution,
    3. the two projections are summed, passed through a ReLU convolution
       and a sigmoid 1x1 convolution producing ``outChan`` channels,
    4. the mask is upsampled by 2 and multiplied with the low-level feature.

    Because of the element-wise sum in step 3, the high-level feature must
    have the spatial size of the low-level feature after the stride-2
    convolution.

    Args:
        outChan: channel count of the low-level input and of the output.
        feature_size_high: spatial size of the high-level input; the first
            two (2D) or three (3D) entries are used.
        feature_size_low: spatial size of the low-level input; same
            indexing as ``feature_size_high``.
        feature_dim_high: channel count of the high-level input, also used
            as the width of the intermediate convolutions.
        mode: ``'2D'`` or ``'3D'``.
        **kwargs: standard ``Layer`` arguments (``name``, ``dtype``, ...).

    Raises:
        ValueError: if ``mode`` is neither ``'2D'`` nor ``'3D'``.
    """

    def __init__(self, outChan, feature_size_high, feature_size_low,
                 feature_dim_high, mode, **kwargs):
        # Configs produced by earlier versions of get_config() carried the
        # sub-model under this key; discard it so those configs still load
        # now that the remaining kwargs are forwarded to the base class.
        kwargs.pop('SpatialAtten', None)
        if mode not in ('2D', '3D'):
            raise ValueError("mode must be '2D' or '3D', got %r" % (mode,))
        super(SpatialAttention, self).__init__(**kwargs)

        self.outChan = outChan
        self.feature_size_high = feature_size_high
        self.feature_size_low = feature_size_low
        self.feature_dim_high = feature_dim_high
        self.mode = mode

        if mode == '2D':
            rank = 2
            conv, upsample = Conv2D, UpSampling2D
            down_kernel, down_stride = 2, 2
            mix_kernel = 3
            up_size = (2, 2)
        else:
            rank = 3
            conv, upsample = Conv3D, UpSampling3D
            # The third spatial axis is never down- or upsampled.
            down_kernel, down_stride = (2, 2, 1), (2, 2, 1)
            mix_kernel = (3, 3, 1)
            up_size = (2, 2, 1)

        high_shape = tuple(feature_size_high[i] for i in range(rank))
        low_shape = tuple(feature_size_low[i] for i in range(rank))
        inputs_high = Input(shape=high_shape + (feature_dim_high,))
        inputs_low = Input(shape=low_shape + (outChan,))

        # Layers are created in a fixed order (low projection, high
        # projection, ReLU conv, sigmoid conv) so the weight order of the
        # sub-model stays stable.
        Fx_ = conv(feature_dim_high, down_kernel, strides=down_stride,
                   padding='same')(inputs_low)
        Fy_ = conv(feature_dim_high, 1, padding='same')(inputs_high)
        M_Spatial = conv(feature_dim_high, mix_kernel, activation='relu',
                         padding='same')(Fx_ + Fy_)
        M_Spatial = conv(outChan, 1, activation='sigmoid',
                         padding='same')(M_Spatial)
        # Bring the mask back to the low-level resolution before gating.
        M_Spatial = upsample(size=up_size)(M_Spatial)
        M_Spatial = Multiply()([M_Spatial, inputs_low])

        self.SpatialAtten = Model(inputs=[inputs_high, inputs_low],
                                  outputs=M_Spatial)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        The internal sub-model is deliberately not included: it is rebuilt
        by ``__init__`` from these arguments.
        """
        config = super(SpatialAttention, self).get_config().copy()
        config.update(
            {
                'outChan': self.outChan,
                'feature_size_high': self.feature_size_high,
                'feature_size_low': self.feature_size_low,
                'feature_dim_high': self.feature_dim_high,
                'mode': self.mode,
            }
        )
        return config

    def call(self, inputs):
        """Apply the attention gate.

        Args:
            inputs: pair ``[Fy, Fx]`` — the high-level feature ``Fy``
                followed by the low-level feature ``Fx``.

        Returns:
            The low-level feature multiplied by the spatial mask.
        """
        Fy, Fx = inputs
        return self.SpatialAtten([Fy, Fx])
2.Channel Attention Layer
class ChannelAttention(Layer):
    """Re-weight the channels of a low-level feature map.

    The per-channel weights are computed inside an internal functional
    ``Model`` (``self.ChanAtten``):

    1. both inputs are projected to ``feature_dim_high`` channels by 1x1
       convolutions,
    2. each projection is average-pooled (pool size 2 on the first two
       spatial axes) and then averaged over all remaining spatial
       positions, giving one vector per sample,
    3. the two vectors are summed and passed through a ReLU ``Dense`` layer
       and a sigmoid ``Dense`` layer producing ``outChan`` weights,
    4. the weights are tiled to the low-level spatial size and multiplied
       with the low-level feature.

    Args:
        outChan: channel count of the low-level input and of the output.
        feature_size_low: spatial size of the low-level input; the first
            two (2D) or three (3D) entries are used.
        feature_size_high: spatial size of the high-level input; same
            indexing as ``feature_size_low``.
        feature_dim_high: channel count of the high-level input, also used
            as the width of the projections and of the first ``Dense``.
        mode: ``'2D'`` or ``'3D'``.
        **kwargs: standard ``Layer`` arguments (``name``, ``dtype``, ...).

    Raises:
        ValueError: if ``mode`` is neither ``'2D'`` nor ``'3D'``.
    """

    def __init__(self, outChan, feature_size_low, feature_size_high,
                 feature_dim_high, mode, **kwargs):
        # Configs produced by earlier versions of get_config() carried the
        # sub-model under this key; discard it so those configs still load
        # now that the remaining kwargs are forwarded to the base class.
        kwargs.pop('ChanAtten', None)
        if mode not in ('2D', '3D'):
            raise ValueError("mode must be '2D' or '3D', got %r" % (mode,))
        super(ChannelAttention, self).__init__(**kwargs)

        self.outChan = outChan
        self.feature_size_high = feature_size_high
        self.feature_size_low = feature_size_low
        self.feature_dim_high = feature_dim_high
        self.mode = mode

        if mode == '2D':
            rank = 2
            conv, pool = Conv2D, AveragePooling2D
            pool_size = (2, 2)
        else:
            rank = 3
            conv, pool = Conv3D, AveragePooling3D
            # The third spatial axis is not pooled.
            pool_size = (2, 2, 1)

        high_shape = tuple(feature_size_high[i] for i in range(rank))
        low_shape = tuple(feature_size_low[i] for i in range(rank))
        inputs_high = Input(shape=high_shape + (feature_dim_high,))
        inputs_low = Input(shape=low_shape + (outChan,))

        # Layers with weights are created in a fixed order (low projection,
        # high projection, Dense, Dense) so the sub-model's weight order
        # stays stable.
        Fx_ = conv(feature_dim_high, 1, padding='same')(inputs_low)
        Fy_ = conv(feature_dim_high, 1, padding='same')(inputs_high)
        Zx = self._pooled_mean(Fx_, pool(pool_size=pool_size), rank)
        Zy = self._pooled_mean(Fy_, pool(pool_size=pool_size), rank)

        FC1 = Dense(feature_dim_high, activation='relu')(Zx + Zy)
        FC2 = Dense(outChan, activation='sigmoid')(FC1)

        # Reshape (batch, channels) to (batch, 1, ..., 1, channels) and
        # tile it over every spatial axis of the low-level feature.
        M_chan = FC2
        for axis in range(1, rank + 1):
            M_chan = K.expand_dims(M_chan, axis=axis)
        for axis in range(1, rank + 1):
            M_chan = K.repeat_elements(M_chan, feature_size_low[axis - 1],
                                       axis=axis)
        M_chan = Multiply()([M_chan, inputs_low])

        self.ChanAtten = Model(inputs=[inputs_high, inputs_low],
                               outputs=M_chan)

    @staticmethod
    def _pooled_mean(features, pool_layer, rank):
        """Average-pool ``features`` and reduce them to one vector per sample.

        Args:
            features: tensor with ``rank`` spatial axes between the batch
                axis and the channel axis.
            pool_layer: the average-pooling layer instance to apply first.
            rank: number of spatial axes (2 or 3).

        Returns:
            A ``(batch, channels)`` tensor: the sum over all pooled spatial
            positions divided by their count.
        """
        pooled = pool_layer(features)
        dims = K.shape(pooled)
        count = dims[1]
        for axis in range(2, rank + 1):
            count = count * dims[axis]
        count = tf.cast(count, dtype=tf.float32)

        total = pooled
        # Reduce from the last spatial axis down to the first so the
        # remaining axis indices stay valid after each reduction.
        for axis in range(rank, 0, -1):
            total = K.sum(total, axis=axis)
        return total / count

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        The internal sub-model is deliberately not included: it is rebuilt
        by ``__init__`` from these arguments.
        """
        config = super(ChannelAttention, self).get_config().copy()
        config.update(
            {
                'outChan': self.outChan,
                'feature_size_high': self.feature_size_high,
                'feature_size_low': self.feature_size_low,
                'feature_dim_high': self.feature_dim_high,
                'mode': self.mode,
            }
        )
        return config

    def call(self, inputs):
        """Apply the channel attention.

        Args:
            inputs: pair ``[Fy, Fx]`` — the high-level feature ``Fy``
                followed by the low-level feature ``Fx``.

        Returns:
            The low-level feature multiplied by the channel weights.
        """
        Fy, Fx = inputs
        return self.ChanAtten([Fy, Fx])
3.Spatial-Channel Attention Layer
class ChannelSpatialAttention(Layer):
    """Apply channel attention, then spatial attention, to a low-level feature.

    The layer owns one ``ChannelAttention`` and one ``SpatialAttention``
    sub-layer built from the same arguments. In ``call`` the low-level
    feature is first re-weighted by the channel attention, and that result
    is then gated by the spatial attention; the high-level feature is fed
    to both sub-layers.

    Args:
        outChan: channel count of the low-level input and of the output.
        feature_size_low: spatial size of the low-level input.
        feature_size_high: spatial size of the high-level input.
        feature_dim_high: channel count of the high-level input.
        mode: ``'2D'`` or ``'3D'``.
        **kwargs: standard ``Layer`` arguments (``name``, ``dtype``, ...).

    Raises:
        ValueError: if ``mode`` is neither ``'2D'`` nor ``'3D'``.
    """

    def __init__(self, outChan, feature_size_low, feature_size_high,
                 feature_dim_high, mode, **kwargs):
        # Configs produced by earlier versions of get_config() carried the
        # sub-layers under these keys; discard them so those configs still
        # load now that the remaining kwargs are forwarded to the base class.
        kwargs.pop('ChannelAttention', None)
        kwargs.pop('SpatialAttention', None)
        if mode not in ('2D', '3D'):
            raise ValueError("mode must be '2D' or '3D', got %r" % (mode,))
        super(ChannelSpatialAttention, self).__init__(**kwargs)

        self.outChan = outChan
        self.feature_size_low = feature_size_low
        self.feature_size_high = feature_size_high
        self.feature_dim_high = feature_dim_high
        self.mode = mode

        # Keyword arguments are used on purpose: the two sub-layers declare
        # feature_size_low / feature_size_high in a different positional
        # order.
        shared_args = dict(
            outChan=outChan,
            feature_size_low=feature_size_low,
            feature_size_high=feature_size_high,
            feature_dim_high=feature_dim_high,
            mode=mode,
        )
        self.ChannelAttention = ChannelAttention(**shared_args)
        self.SpatialAttention = SpatialAttention(**shared_args)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        The two sub-layers are deliberately not included: they are rebuilt
        by ``__init__`` from these arguments.
        """
        config = super(ChannelSpatialAttention, self).get_config().copy()
        config.update(
            {
                'outChan': self.outChan,
                'feature_size_low': self.feature_size_low,
                'feature_size_high': self.feature_size_high,
                'feature_dim_high': self.feature_dim_high,
                'mode': self.mode,
            }
        )
        return config

    def call(self, inputs):
        """Run channel attention followed by spatial attention.

        Args:
            inputs: pair ``[Fy, Fx]`` — the high-level feature ``Fy``
                followed by the low-level feature ``Fx``.

        Returns:
            The output of the spatial attention applied to the
            channel-attended low-level feature.
        """
        Fy, Fx = inputs
        channel_refined = self.ChannelAttention([Fy, Fx])
        return self.SpatialAttention([Fy, channel_refined])
https://doi.org/10.1016/j.knosys.2021.106754