Dual Relation-Aware Attention [a Keras implementation of the optimized Dual Attention]

Table of Contents

  • Preface
  • 1. Compact Position Attention Module (CPAM)
  • 2. Compact Channel Attention Module (CCAM)
  • 3. Results
  • 4. Code
    • 1. PyTorch source (imports omitted)
    • 2. Keras implementation


Preface

A while back I read a paper that applied dual attention to natural-image segmentation [1]. The authors later published an optimized version called Dual Relation-Aware Attention [2], which mainly tackles the high computation and memory cost of dual attention (I also kept hitting OOM whenever I ran dual attention). I took a shot at implementing it along the way, but I cannot guarantee its correctness, so discussion and corrections are welcome ^_=

[1] Dual Attention Network for Scene Segmentation
Paper: https://arxiv.org/abs/1809.02983
[2] Scene Segmentation With Dual Relation-Aware Attention Network
Paper: https://ieeexplore.ieee.org/abstract/document/9154612/

Code (PyTorch): https://github.com/junfu1115/DANet (the author keeps the code for both papers in the same repository)


1. Compact Position Attention Module (CPAM)

[Figure: structure of the Compact Position Attention Module (CPAM)]
The position attention module (PAM) in Dual Attention is implemented as an inner product over the feature map (it is essentially self-attention: matrix multiplication models the global relation between every pair of pixels). When the feature map is large, this demands a lot of GPU memory and computation, so the authors propose the Compact Position Attention Module (CPAM). It uses pyramid pooling (pooling at several output sizes) to model the relation between each pixel and a small set of gathering centers: the pooled features are concatenated and serve as the keys and values of the self-attention inner product, which cuts down computation and memory considerably.
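
To get a feel for the savings, here is a back-of-the-envelope comparison (my own numbers, not from the paper). PAM builds an (H·W)×(H·W) affinity matrix, whereas CPAM pools to output sizes 1, 2, 3 and 6, i.e. 1+4+9+36 = 50 gathering centers, so its affinity is only (H·W)×50:

# Rough attention-map size comparison, PAM vs. CPAM (illustrative only)
h = w = 64                     # spatial size of the feature map
n = h * w                      # number of pixels
m = 1*1 + 2*2 + 3*3 + 6*6      # pooled sizes 1, 2, 3, 6 -> 50 gathering centers

print(f"PAM  affinity: {n} x {n} = {n*n:,} entries")   # 4096 x 4096 = 16,777,216
print(f"CPAM affinity: {n} x {m} = {n*m:,} entries")   # 4096 x 50   = 204,800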

2. Compact Channel Attention Module (CCAM)

[Figure: structure of the Compact Channel Attention Module (CCAM)]

The authors note that the channel attention module (CAM) also gets expensive when the number of feature maps (channels) is large. To address this, they propose the Compact Channel Attention Module (CCAM), which models the relation between each channel map and a set of channel gathering centers. The implementation first reduces the channel dimension of the input feature with a 1x1 convolution, then computes self-attention against the reduced maps.
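
The arithmetic is analogous (again my own illustration): CAM builds a C×C channel affinity, while CCAM reduces the gathering side to K channels with the 1x1 convolution; the ccam_enc helper later in this post uses K = C/8:

# Rough channel-affinity comparison, CAM vs. CCAM (illustrative only)
c = 512         # channels of the input feature
k = c // 8      # channel gathering centers after the 1x1 reduction

print(f"CAM  affinity: {c} x {c} = {c*c:,} entries")   # 512 x 512 = 262,144
print(f"CCAM affinity: {c} x {k} = {c*k:,} entries")   # 512 x 64  = 32,768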

3. Results

[Table: computation and memory comparison between the original and compact attention modules]
As the comparison shows, the reduced modules consume far less compute and memory than the original dual attention (on my own network this meant no more OOM, although for my small-sample medical image segmentation task I did not observe an accuracy gain either). CAM and CCAM score about the same, so I won't show those numbers. PAM versus CPAM is more interesting: the overall gains are the same, but the authors found that PAM helps more on small objects while CPAM helps more on large ones (Table 3). My guess is that the pooling loses some pixel-to-pixel relations, which limits the recognition of small objects.
[Table 3: per-class comparison of PAM and CPAM]
The authors also compared ways of connecting CPAM and CCAM, and found it most effective to compute CPAM and CCAM on the input feature in parallel and then concatenate the results (this is exactly what dra_attention does in the Keras code below).

The paper also describes a cross-level gating decoder, which I haven't studied in detail yet.

4. Code

1. PyTorch source (imports omitted)

class CPAMEnc(Module):
    """
    CPAM encoding module
    """
    def __init__(self, in_channels, norm_layer):
        super(CPAMEnc, self).__init__()
        self.pool1 = AdaptiveAvgPool2d(1)
        self.pool2 = AdaptiveAvgPool2d(2)
        self.pool3 = AdaptiveAvgPool2d(3)
        self.pool4 = AdaptiveAvgPool2d(6)

        self.conv1 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels),
                                ReLU(True))
        self.conv2 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels),
                                ReLU(True))
        self.conv3 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels),
                                ReLU(True))
        self.conv4 = Sequential(Conv2d(in_channels, in_channels, 1, bias=False),
                                norm_layer(in_channels),
                                ReLU(True))

    def forward(self, x):
        b, c, h, w = x.size()
        
        feat1 = self.conv1(self.pool1(x)).view(b,c,-1)
        feat2 = self.conv2(self.pool2(x)).view(b,c,-1)
        feat3 = self.conv3(self.pool3(x)).view(b,c,-1)
        feat4 = self.conv4(self.pool4(x)).view(b,c,-1)
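        # flattened pooled maps: (B, C, 1), (B, C, 4), (B, C, 9), (B, C, 36);
        # concatenated below into 1 + 4 + 9 + 36 = 50 gathering centers per channel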
        
        return torch.cat((feat1, feat2, feat3, feat4), 2)


class CPAMDec(Module):
    """
    CPAM decoding module
    """
    def __init__(self,in_channels):
        super(CPAMDec,self).__init__()
        self.softmax  = Softmax(dim=-1)
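        # learnable scale for the residual branch, initialized to zero (as in DANet)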
        self.scale = Parameter(torch.zeros(1))

        self.conv_query = Conv2d(in_channels = in_channels , out_channels = in_channels//4, kernel_size= 1) # query_conv2
        self.conv_key = Linear(in_channels, in_channels//4) # key_conv2
        self.conv_value = Linear(in_channels, in_channels) # value2
    def forward(self, x,y):
        """
            inputs :
                x : input feature(N,C,H,W) y:gathering centers(N,K,M)
            returns :
                out : compact position attention feature
                attention map: (H*W)*M
        """
        m_batchsize,C,width ,height = x.size()
        m_batchsize,K,M = y.size()

        proj_query  = self.conv_query(x).view(m_batchsize,-1,width*height).permute(0,2,1)#BxNxd
        proj_key =  self.conv_key(y).view(m_batchsize,K,-1).permute(0,2,1)#BxdxK
        energy =  torch.bmm(proj_query,proj_key)#BxNxK
        attention = self.softmax(energy) #BxNxk

        proj_value = self.conv_value(y).permute(0,2,1) #BxCxK
        out = torch.bmm(proj_value,attention.permute(0,2,1))#BxCxN
        out = out.view(m_batchsize,C,width,height)
        out = self.scale*out + x
        return out


class CCAMDec(Module):
    """
    CCAM decoding module
    """
    def __init__(self):
        super(CCAMDec,self).__init__()
        self.softmax  = Softmax(dim=-1)
        self.scale = Parameter(torch.zeros(1))

    def forward(self, x,y):
        """
            inputs :
                x : input feature(N,C,H,W) y:gathering centers(N,K,H,W)
            returns :
                out : compact channel attention feature
                attention map: K*C
        """
        m_batchsize,C,width ,height = x.size()
        x_reshape =x.view(m_batchsize,C,-1)

        B,K,W,H = y.size()
        y_reshape =y.view(B,K,-1)
        proj_query  = x_reshape #BXC1XN
        proj_key  = y_reshape.permute(0,2,1) #BX(N)XC
        energy =  torch.bmm(proj_query,proj_key) #BXC1XC
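        # subtract each affinity from its row max before the softmax
        # (the same trick used in DANet's channel attention module)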
        energy_new = torch.max(energy,-1,keepdim=True)[0].expand_as(energy)-energy
        attention = self.softmax(energy_new)
        proj_value = y.view(B,K,-1) #BCN
        
        out = torch.bmm(attention,proj_value) #BC1N
        out = out.view(m_batchsize,C,width ,height)

        out = x + self.scale*out
        return out

2. Keras implementation

I used TensorFlow 2.8.0; I recommend 2.5.0 or later.
Adaptive average pooling (AdaptiveAveragePooling2D) requires tensorflow_addons (alternatively, you can work out the pooling sizes yourself and use plain pooling, as sketched below).
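
For reference, a minimal sketch of that plain-pooling alternative. It assumes the input height and width are static and divisible by the target output size, which tfa's adaptive pooling does not require:

# Emulate adaptive average pooling with plain average pooling
# (assumes static H, W divisible by output_size)
from keras.layers import AveragePooling2D

def adaptive_avg_pool2d(x, output_size):
    _, h, w, _ = x.shape
    pool_h, pool_w = h // output_size, w // output_size
    return AveragePooling2D(pool_size=(pool_h, pool_w),
                            strides=(pool_h, pool_w))(x)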

# DRANet
import tensorflow as tf
import numpy as np
from keras.layers import *
import tensorflow_addons as tfa

def conv_norm_act(input_tensor, filters, kernel_size, dilation=1, norm_type='batch', act_type='relu'):
    '''
    Conv2D + Normalization (norm_type: str) + Activation (act_type: str)
    '''
    output_tensor = Conv2D(filters, kernel_size, padding='same', dilation_rate=(dilation, dilation), use_bias=norm_type is None, kernel_initializer='he_normal')(input_tensor)
    if norm_type == 'batch':    # only batch norm is used in this post
        output_tensor = BatchNormalization()(output_tensor)
    output_tensor = Activation(act_type)(output_tensor)

    return output_tensor

# channels-last only
def cpam_enc(x):
    '''x: input tensor with shape [B, H, W, C]'''
    b, h, w, c = x.shape
    # x = tf.transpose(x, [0, 3, 1, 2])   # must be channel last
    
    feat1 = tfa.layers.AdaptiveAveragePooling2D(output_size=(1, 1))(x)
    feat2 = tfa.layers.AdaptiveAveragePooling2D(output_size=(2, 2))(x)
    feat3 = tfa.layers.AdaptiveAveragePooling2D(output_size=(3, 3))(x)
    feat4 = tfa.layers.AdaptiveAveragePooling2D(output_size=(6, 6))(x)

    feat1 = tf.reshape(tf.transpose(conv_norm_act(feat1, c, 1, norm_type='batch', act_type='relu'), [0, 3, 1, 2]), (-1, c, 1))
    feat2 = tf.reshape(tf.transpose(conv_norm_act(feat2, c, 1, norm_type='batch', act_type='relu'), [0, 3, 1, 2]), (-1, c, 4))
    feat3 = tf.reshape(tf.transpose(conv_norm_act(feat3, c, 1, norm_type='batch', act_type='relu'), [0, 3, 1, 2]), (-1, c, 9))
    feat4 = tf.reshape(tf.transpose(conv_norm_act(feat4, c, 1, norm_type='batch', act_type='relu'), [0, 3, 1, 2]), (-1, c, 36))

    return concatenate([feat1, feat2, feat3, feat4], 2)


def cpam_dec(x, y):
    '''
    inputs :
        x : input feature(N,H,W,C) y:gathering centers(N,K,M)
    returns :
        out : compact position attention feature
        attention map: (H*W)*M
    '''
    b, h, w, c = x.shape
    b, k, m = y.shape

    # the PyTorch source initializes the residual scale to zero; this port uses ones
    # scale = tf.Variable(tf.zeros(1))
    scale = tf.Variable(tf.ones(1))

    proj_query = Conv2D(c//4, 1)(x)
    proj_query = tf.transpose(proj_query, [0, 3, 1, 2])
    proj_query = tf.transpose(tf.reshape(proj_query, (-1, c//4, h*w)), [0, 2, 1])
    proj_key = Dense(c//4)(y)
    proj_key = tf.transpose(tf.reshape(proj_key, (-1, k, c//4)), [0, 2, 1])
    energy = tf.matmul(proj_query, proj_key)
    attention = tf.nn.softmax(energy)

    proj_value = tf.transpose(Dense(c)(y), [0, 2, 1])
    out = tf.matmul(proj_value, tf.transpose(attention, [0, 2, 1]))
    out = tf.reshape(out, (-1, c, h, w))
    out = tf.transpose(out, [0, 2, 3, 1])
    out = out * scale + x

    return out

def ccam_enc(x):
    '''x: input tensor with shape [B, H, W, C]; returns gathering centers with shape [B, K, H, W], K = C//8'''
    b, h, w, c = x.shape
    x = conv_norm_act(x, c//8, 1, norm_type='batch', act_type='relu')
    x = tf.transpose(x, [0, 3, 1, 2])   # to channels-first: (B, K, H, W)
    return x

def ccam_dec(x, y):
    '''
    inputs:
        x : input feature(N,H,W,C), y:gathering centers(N,K,H,W)
    returns :
        out : compact channel attention feature
        attention map: K*C
    '''
    m_batchsize, height, width, c = x.shape
    x = tf.transpose(x, [0, 3, 1, 2])   # to channels-first for the matmuls below
    x_reshape = tf.reshape(x, (-1, c, height*width))

    # the PyTorch source initializes the residual scale to zero; this port uses ones
    # scale = tf.Variable(tf.zeros(1))
    scale = tf.Variable(tf.ones(1))

    b, k, h, w = y.shape
    y_reshape = tf.reshape(y, (-1, k, h*w))
    proj_query = x_reshape
    proj_key = tf.transpose(y_reshape, [0, 2, 1])
    energy = tf.matmul(proj_query, proj_key)
    energy_new = tf.reduce_max(energy, -1, keepdims=True)
    energy_new = tf.repeat(energy_new, energy.shape[-1], -1)
    energy_new = energy_new - energy
    attention = tf.nn.softmax(energy_new)
    proj_value = tf.reshape(y, (-1, k, h*w))
    out = tf.matmul(attention, proj_value)
    out = tf.reshape(out, (-1, c, height, width))
    out = x + scale * out
    out = tf.transpose(out, [0, 2, 3, 1])

    return out


def dra_attention(x, filters):
    # centers-first layout (B, 50, C) for CPAM, matching the permute applied in the PyTorch DRAN head
    y1 = tf.transpose(cpam_enc(x), [0, 2, 1])
    y2 = ccam_enc(x)
    att1 = cpam_dec(x, y1)
    att2 = ccam_dec(x, y2)
    att = concatenate([att1, att2], -1)	# channels-last
    att = conv_norm_act(att, filters, 1, norm_type='batch', act_type='relu')

    return att
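
A minimal usage sketch (my own, only shape-checked on paper), plugging dra_attention into a functional Keras model; the 64x64x256 input is an arbitrary stand-in for a backbone feature map:

from keras import Model, Input

inputs = Input(shape=(64, 64, 256))           # hypothetical backbone feature map
outputs = dra_attention(inputs, filters=256)  # parallel CPAM + CCAM, concat, 1x1 fuse
model = Model(inputs, outputs)
model.summary()

One caveat of this function-style port: the tf.Variable scale created inside cpam_dec/ccam_dec is not tracked by the Model; wrapping these functions in custom Layer subclasses would be the more robust route.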

Corrections and feedback are welcome TAT
