基于Keras的Channel-Spatial Attention Layers的实现

基于Keras的Channel-Spatial Attention Layers的实现

描述
注意力机制（Attention）已被应用于医学图像分割领域，用于增强神经网络对空间特征和通道特征的关注，从而提高分割任务的精度。本文参照参考文献，实现了一种将Spatial Attention与Channel Attention融合使用的注意力层。
（图1：基于Keras的Channel-Spatial Attention Layers的实现）

代码展示
1.Spatial Attention Layer

class SpatialAttention(Layer):
    """Spatial attention gate: reweight a low-level feature map position-wise.

    The high-level feature map ``F_y`` and the low-level feature map ``F_x``
    are both projected to ``feature_dim_high`` channels. ``F_x`` goes through a
    strided convolution that halves its first two spatial axes. The two
    projections are summed and turned into a sigmoid mask with ``outChan``
    channels. The mask is upsampled by the same factor and multiplied
    element-wise with ``F_x``.

    Shape assumption (not checked here): the first two spatial axes of ``F_x``
    are twice those of ``F_y``. In ``'3D'`` mode the third spatial axis is not
    resampled, so it must be equal for both inputs.

    Args:
        outChan: Channel count of ``F_x``; also the channel count of the output.
        feature_size_high: Spatial size of ``F_y`` (2 entries for ``'2D'``,
            3 entries for ``'3D'``).
        feature_size_low: Spatial size of ``F_x`` (same convention).
        feature_dim_high: Channel count of ``F_y``; also the width of the
            intermediate convolutions.
        mode: ``'2D'`` or ``'3D'``.
        **kwargs: Standard Keras ``Layer`` keyword arguments (``name``,
            ``trainable``, ``dtype``, ...).

    Raises:
        ValueError: If ``mode`` is neither ``'2D'`` nor ``'3D'``.
    """

    def __init__(self, outChan, feature_size_high, feature_size_low, feature_dim_high, mode, **kwargs):
        if mode not in ('2D', '3D'):
            # Without this check an unknown mode reached Model(...) with
            # undefined tensors and failed with an UnboundLocalError.
            raise ValueError("mode must be '2D' or '3D', got %r" % (mode,))
        # Configs written by earlier versions of get_config() also carried the
        # inner Model. It is rebuilt below, so drop it before forwarding kwargs.
        kwargs.pop('SpatialAtten', None)
        # Forward the standard Layer arguments (previously discarded), so that
        # e.g. a custom ``name`` survives a get_config()/from_config() round trip.
        super(SpatialAttention, self).__init__(**kwargs)
        self.outChan = outChan
        self.feature_size_high = feature_size_high
        self.feature_size_low = feature_size_low
        self.feature_dim_high = feature_dim_high
        self.mode = mode

        # Per-mode building blocks. In 3D the third spatial axis is never
        # down- or up-sampled, hence the trailing 1 in the scale/kernel tuples.
        if mode == '2D':
            n_spatial = 2
            ConvND, UpSamplingND = Conv2D, UpSampling2D
            scale, kernel = (2, 2), (3, 3)
        else:
            n_spatial = 3
            ConvND, UpSamplingND = Conv3D, UpSampling3D
            scale, kernel = (2, 2, 1), (3, 3, 1)

        inputs_high = Input(shape=tuple(feature_size_high[:n_spatial]) + (feature_dim_high,))
        inputs_low = Input(shape=tuple(feature_size_low[:n_spatial]) + (outChan,))
        # Bring F_x down to the resolution of F_y so the two can be added.
        Fx_ = ConvND(feature_dim_high, scale, strides=scale, padding='same')(inputs_low)
        Fy_ = ConvND(feature_dim_high, 1, padding='same')(inputs_high)
        M_Spatial = ConvND(feature_dim_high, kernel, activation='relu', padding='same')(Fx_ + Fy_)
        # Sigmoid mask with one value per position and per channel of F_x.
        M_Spatial = ConvND(outChan, 1, activation='sigmoid', padding='same')(M_Spatial)
        # Return the mask to F_x's resolution, then gate F_x with it.
        M_Spatial = UpSamplingND(size=scale)(M_Spatial)
        M_Spatial = Multiply()([M_Spatial, inputs_low])
        self.SpatialAtten = Model(inputs=[inputs_high, inputs_low], outputs=M_Spatial)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer.

        Only plain hyper-parameters are stored. The inner Model is rebuilt by
        ``__init__``, so it is deliberately not part of the config (a Model
        object is not a serializable config value).
        """
        config = super().get_config().copy()
        config.update(
            {
                'outChan': self.outChan,
                'feature_size_high': self.feature_size_high,
                'feature_size_low': self.feature_size_low,
                'feature_dim_high': self.feature_dim_high,
                'mode': self.mode,
            }
        )
        return config

    def call(self, inputs):
        """Apply spatial attention.

        Args:
            inputs: Pair ``[F_y, F_x]``: high-level feature map first, then the
                low-level feature map.

        Returns:
            ``F_x`` multiplied element-wise by the spatial attention mask.
        """
        Fy, Fx = inputs
        return self.SpatialAtten([Fy, Fx])

2.Channel Attention Layer

class ChannelAttention(Layer):
    """Channel attention gate: reweight the channels of a low-level feature map.

    The processing steps are:

    1. Both inputs are projected to ``feature_dim_high`` channels with 1x1
       convolutions.
    2. Each projection is average-pooled by 2 along the first two spatial axes.
    3. Each pooled map is averaged over all of its spatial axes.
    4. The two resulting channel descriptors are summed and passed through two
       Dense layers (ReLU, then sigmoid), giving one weight per channel of
       ``F_x``.
    5. These weights are tiled over the spatial extent of ``F_x`` and
       multiplied with it.

    Note: the positional parameter order differs from ``SpatialAttention``
    (``feature_size_low`` comes before ``feature_size_high`` here). Prefer
    keyword arguments.

    Args:
        outChan: Channel count of ``F_x``; also the channel count of the output.
        feature_size_low: Spatial size of ``F_x`` (2 entries for ``'2D'``,
            3 entries for ``'3D'``). Must be concrete ints, since they are used
            as repeat counts.
        feature_size_high: Spatial size of ``F_y`` (same convention).
        feature_dim_high: Channel count of ``F_y``; also the width of the 1x1
            projections and the hidden Dense layer.
        mode: ``'2D'`` or ``'3D'``.
        **kwargs: Standard Keras ``Layer`` keyword arguments (``name``,
            ``trainable``, ``dtype``, ...).

    Raises:
        ValueError: If ``mode`` is neither ``'2D'`` nor ``'3D'``.
    """

    def __init__(self, outChan, feature_size_low, feature_size_high, feature_dim_high, mode, **kwargs):
        if mode not in ('2D', '3D'):
            # Without this check an unknown mode reached Model(...) with
            # undefined tensors and failed with an UnboundLocalError.
            raise ValueError("mode must be '2D' or '3D', got %r" % (mode,))
        # Configs written by earlier versions of get_config() also carried the
        # inner Model. It is rebuilt below, so drop it before forwarding kwargs.
        kwargs.pop('ChanAtten', None)
        # Forward the standard Layer arguments (previously discarded), so that
        # e.g. a custom ``name`` survives a get_config()/from_config() round trip.
        super(ChannelAttention, self).__init__(**kwargs)
        self.outChan = outChan
        self.feature_size_high = feature_size_high
        self.feature_size_low = feature_size_low
        self.feature_dim_high = feature_dim_high
        self.mode = mode

        # Per-mode building blocks. In 3D the third spatial axis is not pooled.
        if mode == '2D':
            n_spatial = 2
            ConvND, AvgPoolND, pool_size = Conv2D, AveragePooling2D, (2, 2)
        else:
            n_spatial = 3
            ConvND, AvgPoolND, pool_size = Conv3D, AveragePooling3D, (2, 2, 1)
        # Spatial axes of a (batch, *spatial, channels) tensor.
        spatial_axes = list(range(1, n_spatial + 1))

        inputs_high = Input(shape=tuple(feature_size_high[:n_spatial]) + (feature_dim_high,))
        inputs_low = Input(shape=tuple(feature_size_low[:n_spatial]) + (outChan,))
        Fx_ = ConvND(feature_dim_high, 1, padding='same')(inputs_low)
        Fy_ = ConvND(feature_dim_high, 1, padding='same')(inputs_high)
        # Global average over the pooled spatial positions -> (batch, feature_dim_high).
        Zx = K.mean(AvgPoolND(pool_size=pool_size)(Fx_), axis=spatial_axes)
        Zy = K.mean(AvgPoolND(pool_size=pool_size)(Fy_), axis=spatial_axes)
        FC1 = Dense(feature_dim_high, activation='relu')(Zx + Zy)
        # One sigmoid weight per channel of F_x -> (batch, outChan).
        M_chan = Dense(outChan, activation='sigmoid')(FC1)
        # Reshape to (batch, 1, ..., 1, outChan), then tile to F_x's spatial size.
        for axis in spatial_axes:
            M_chan = K.expand_dims(M_chan, axis=axis)
        for axis in spatial_axes:
            M_chan = K.repeat_elements(M_chan, feature_size_low[axis - 1], axis=axis)
        M_chan = Multiply()([M_chan, inputs_low])
        self.ChanAtten = Model(inputs=[inputs_high, inputs_low], outputs=M_chan)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer.

        Only plain hyper-parameters are stored. The inner Model is rebuilt by
        ``__init__``, so it is deliberately not part of the config (a Model
        object is not a serializable config value).
        """
        config = super().get_config().copy()
        config.update(
            {
                'outChan': self.outChan,
                'feature_size_high': self.feature_size_high,
                'feature_size_low': self.feature_size_low,
                'feature_dim_high': self.feature_dim_high,
                'mode': self.mode,
            }
        )
        return config

    def call(self, inputs):
        """Apply channel attention.

        Args:
            inputs: Pair ``[F_y, F_x]``: high-level feature map first, then the
                low-level feature map.

        Returns:
            ``F_x`` with each channel scaled by its attention weight.
        """
        Fy, Fx = inputs
        return self.ChanAtten([Fy, Fx])

3.Channel-Spatial Attention Layer

class ChannelSpatialAttention(Layer):
    """Channel attention followed by spatial attention on a low-level feature map.

    ``call([F_y, F_x])`` first applies a ``ChannelAttention`` sub-layer to
    ``F_x`` (guided by ``F_y``). It then applies a ``SpatialAttention``
    sub-layer to that result (again guided by ``F_y``). Both sub-layers are
    built from the same hyper-parameters.

    Args:
        outChan: Channel count of ``F_x``; also the channel count of the output.
        feature_size_low: Spatial size of ``F_x`` (2 entries for ``'2D'``,
            3 entries for ``'3D'``).
        feature_size_high: Spatial size of ``F_y`` (same convention).
        feature_dim_high: Channel count of ``F_y``.
        mode: ``'2D'`` or ``'3D'``.
        **kwargs: Standard Keras ``Layer`` keyword arguments (``name``,
            ``trainable``, ``dtype``, ...).

    Raises:
        ValueError: If ``mode`` is neither ``'2D'`` nor ``'3D'``.
    """

    def __init__(self, outChan, feature_size_low, feature_size_high, feature_dim_high, mode, **kwargs):
        if mode not in ('2D', '3D'):
            raise ValueError("mode must be '2D' or '3D', got %r" % (mode,))
        # Configs written by earlier versions of get_config() also carried the
        # sub-layer objects. They are rebuilt below, so drop them before
        # forwarding kwargs to Layer.
        kwargs.pop('ChannelAttention', None)
        kwargs.pop('SpatialAttention', None)
        # Forward the standard Layer arguments (previously discarded), so that
        # e.g. a custom ``name`` survives a get_config()/from_config() round trip.
        super(ChannelSpatialAttention, self).__init__(**kwargs)
        self.outChan = outChan
        self.feature_size_low = feature_size_low
        self.feature_size_high = feature_size_high
        self.feature_dim_high = feature_dim_high
        self.mode = mode
        # Keyword arguments are used because the two sub-layers declare their
        # size parameters in different positional orders.
        self.ChannelAttention = ChannelAttention(outChan=outChan, feature_size_low=feature_size_low, feature_size_high=feature_size_high, feature_dim_high=feature_dim_high, mode=mode)
        self.SpatialAttention = SpatialAttention(outChan=outChan, feature_size_low=feature_size_low, feature_size_high=feature_size_high, feature_dim_high=feature_dim_high, mode=mode)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer.

        Only plain hyper-parameters are stored. The two sub-layers are rebuilt
        by ``__init__``, so the layer objects themselves are deliberately not
        part of the config.
        """
        config = super().get_config().copy()
        config.update(
            {
                'outChan': self.outChan,
                'feature_size_low': self.feature_size_low,
                'feature_size_high': self.feature_size_high,
                'feature_dim_high': self.feature_dim_high,
                'mode': self.mode,
            }
        )
        return config

    def call(self, inputs):
        """Apply channel attention, then spatial attention.

        Args:
            inputs: Pair ``[F_y, F_x]``: high-level feature map first, then the
                low-level feature map.

        Returns:
            The channel-reweighted ``F_x`` further gated by the spatial mask.
        """
        Fy, Fx = inputs
        M_Chan = self.ChannelAttention([Fy, Fx])
        return self.SpatialAttention([Fy, M_Chan])

References

https://doi.org/10.1016/j.knosys.2021.106754

你可能感兴趣的:(Keras,Modules,深度学习,keras,python,图像处理,计算机视觉)