TensorFlow 2 实现 Coordinate Attention

import tensorflow as tf
from tensorflow.keras.layers import (Conv2D,AvgPool2D,Input)


def CoordAtt(x, reduction=32, name='ca'):
    """Apply a coordinate-attention block to a 4-D feature map.

    The input is average-pooled along each spatial axis separately, the two
    pooled descriptors are concatenated and passed through a shared 1x1
    convolution, then split back and turned into two sigmoid gates that
    re-weight the input along its height and width.

    Args:
        x: Rank-4 tensor, assumed to be laid out as
            (batch, height, width, channels). Height, width and channels
            must be statically known.
        reduction: Channel reduction ratio for the shared 1x1 convolution.
            The bottleneck width is ``max(8, channels // reduction)``.
        name: Prefix for the three convolution layer names
            (``<name>_conv1`` .. ``<name>_conv3``). Pass a distinct prefix
            for every call when the block is used more than once in the same
            Keras model, since layer names must be unique there.

    Returns:
        A tensor with the same shape as ``x``: ``x`` multiplied by the
        height gate and the width gate.

    Raises:
        ValueError: If the height, width or channel dimension of ``x`` is
            not statically known.
    """

    def coord_act(t):
        # Hard-swish: t * relu6(t + 3) / 6.
        gate = tf.nn.relu6(t + 3) / 6
        return t * gate

    _, h, w, c = x.get_shape().as_list()
    if h is None or w is None or c is None:
        raise ValueError(
            'CoordAtt requires statically known height, width and channel '
            'dimensions, got shape %s' % (x.get_shape(),)
        )

    # Pool over the full width -> (b, h, 1, c); over the full height -> (b, 1, w, c).
    x_h = AvgPool2D(pool_size=(1, w), strides=1)(x)
    x_w = AvgPool2D(pool_size=(h, 1), strides=1)(x)
    # Move the width axis into position 1 so both descriptors can be stacked
    # along the same axis: (b, 1, w, c) -> (b, w, 1, c).
    x_w = tf.transpose(x_w, [0, 2, 1, 3])

    y = tf.concat([x_h, x_w], axis=1)  # (b, h + w, 1, c)
    mip = max(8, c // reduction)
    y = Conv2D(mip, (1, 1), strides=1, activation=coord_act,
               name=name + '_conv1')(y)

    # Split with explicit sizes: the two parts have lengths h and w, so an
    # even two-way split is only correct for square inputs.
    x_h, x_w = tf.split(y, num_or_size_splits=[h, w], axis=1)
    x_w = tf.transpose(x_w, [0, 2, 1, 3])  # back to (b, 1, w, mip)

    a_h = Conv2D(c, (1, 1), strides=1, activation=tf.nn.sigmoid,
                 name=name + '_conv2')(x_h)
    a_w = Conv2D(c, (1, 1), strides=1, activation=tf.nn.sigmoid,
                 name=name + '_conv3')(x_w)

    # a_h is (b, h, 1, c) and a_w is (b, 1, w, c); both broadcast against x.
    return x * a_h * a_w


if __name__ == '__main__':
    # Smoke check: run a symbolic 224x224, 3-channel input through the block
    # and print the shape of the result.
    demo_input = Input(shape=(224, 224, 3))
    print(CoordAtt(demo_input).shape)

tensorflow2实现coordinate attention_第1张图片

 

你可能感兴趣的:(tensorflow2,深度学习,tensorflow,keras,深度学习)