# 论文中有两种结构: 上方的为 PSA_p (并行), 下方的为 PSA_s (串行)。
import tensorflow as tf
def kaiming_init(module, distribution='normal'):
    """Configure Kaiming (He) initialisation on a Keras layer.

    The layer's ``kernel_initializer`` is replaced by ``he_uniform`` or
    ``he_normal``; if the layer has a bias, its ``bias_initializer`` is set to
    zeros. Only the initializer attributes are changed.
    NOTE(review): Keras reads these attributes when it creates the weights, so
    presumably this must run before the layer is built; an already-built
    layer is not re-initialised.

    :param module: a Keras layer exposing ``kernel_initializer`` (e.g.
        ``tf.keras.layers.Conv2D``) and optionally ``bias_initializer``.
    :param distribution: ``'normal'`` (default) or ``'uniform'``.
    :raises AssertionError: if ``distribution`` is not one of those values.
    """
    assert distribution in ['uniform', 'normal'], \
        "distribution must be 'uniform' or 'normal', got {!r}".format(distribution)
    if distribution == 'uniform':
        module.kernel_initializer = tf.keras.initializers.he_uniform()
    else:
        module.kernel_initializer = tf.keras.initializers.he_normal()
    # A bias exists if the layer is configured with one (unbuilt layer) or
    # already owns a bias variable (built layer). The zero initializer must go
    # to ``bias_initializer``; assigning it to ``kernel_initializer`` would
    # overwrite the Kaiming kernel initialisation with all zeros.
    has_bias = (getattr(module, 'use_bias', False)
                or getattr(module, 'bias', None) is not None)
    if has_bias:
        module.bias_initializer = tf.keras.initializers.Zeros()
class PSA_p(tf.keras.Model):
    """Polarized Self-Attention, parallel layout (PSA_p).

    Two attention branches are computed from the same input, and their
    re-weighted outputs are summed:

    * channel-only branch (``*_left`` layers, :meth:`channel_pool`): one
      sigmoid weight per channel, shape ``(B, 1, 1, planes)``;
    * spatial-only branch (``*_right`` layers, :meth:`spatial_pool`): one
      sigmoid weight per position, shape ``(B, H, W, 1)``.

    All computation happens in channels_last layout. With
    ``data_format='channels_first'`` the input is transposed in and the result
    is transposed back out, so the output has the same layout and shape as the
    input.
    """

    def __init__(self, planes, data_format='channels_last', **kwargs):
        """
        :param planes: number of input channels (the channel-mask width, so it
            must match the input's channel dimension).
        :param data_format: input layout, ``'channels_last'`` (default,
            ``(B, H, W, C)``) or ``'channels_first'`` (``(B, C, H, W)``).
        :param kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
        :raises ValueError: if ``data_format`` is not one of the values above.
        """
        super().__init__(**kwargs)
        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError(
                "data_format must be 'channels_last' or 'channels_first', "
                "got {!r}".format(data_format))
        self.data_format = data_format
        self.planes = planes
        self.out_planes = planes // 2
        # Sub-layers always receive channels_last tensors (see _to_nhwc), so
        # pin their layout instead of relying on the global Keras setting.
        conv_opts = dict(kernel_size=1, strides=1, padding='valid',
                         data_format='channels_last', use_bias=False)

        # Channel-only branch.
        self.conv_q_left = tf.keras.layers.Conv2D(filters=1, **conv_opts)
        self.conv_v_left = tf.keras.layers.Conv2D(filters=self.out_planes, **conv_opts)
        self.conv_up_left = tf.keras.layers.Conv2D(filters=self.planes, **conv_opts)
        # Applied to a (B, 1, H*W) tensor, so the default last-axis softmax
        # normalises over the H*W spatial positions.
        self.softmax_left = tf.keras.layers.Activation(activation=tf.keras.activations.softmax)
        self.sigmoid_left = tf.keras.layers.Activation(activation=tf.keras.activations.sigmoid)

        # Spatial-only branch.
        self.conv_q_right = tf.keras.layers.Conv2D(filters=self.out_planes, **conv_opts)
        self.conv_v_right = tf.keras.layers.Conv2D(filters=self.out_planes, **conv_opts)
        self.global_pool = tf.keras.layers.GlobalAvgPool2D(data_format='channels_last',
                                                           keepdims=True)
        self.softmax_right = tf.keras.layers.Activation(activation=tf.keras.activations.softmax)
        self.sigmoid_right = tf.keras.layers.Activation(activation=tf.keras.activations.sigmoid)

        self.reset_parameters()

    def reset_parameters(self):
        """Apply Kaiming initialisation to the four query/value projections."""
        for conv in (self.conv_v_right, self.conv_q_right,
                     self.conv_v_left, self.conv_q_left):
            kaiming_init(conv)

    def _to_nhwc(self, inputs):
        """Return ``inputs`` in channels_last layout."""
        if self.data_format == 'channels_first':
            # (B, C, H, W) -> (B, H, W, C)
            return tf.transpose(inputs, perm=[0, 2, 3, 1])
        return inputs

    def _from_nhwc(self, outputs):
        """Convert a channels_last result back to ``self.data_format``."""
        if self.data_format == 'channels_first':
            # (B, H, W, C) -> (B, C, H, W)
            return tf.transpose(outputs, perm=[0, 3, 1, 2])
        return outputs

    def _channel_attention(self, x):
        """Re-weight each channel of the channels_last tensor ``x``."""
        # Index tf.shape instead of unpacking it: iterating a symbolic tensor
        # is not allowed in graph mode (tf.function / model.fit).
        batch = tf.shape(x)[0]
        # (B, H, W, IC) -> (B, H*W, IC/2)
        value = tf.reshape(self.conv_v_left(x), (batch, -1, self.out_planes))
        # (B, H, W, IC) -> (B, 1, H*W): one score per position, softmax over H*W.
        query = tf.reshape(self.conv_q_left(x), (batch, 1, -1))
        query = self.softmax_left(query)
        # (B, 1, H*W) @ (B, H*W, IC/2) -> (B, 1, IC/2)
        context = tf.matmul(query, value)
        # (B, 1, IC/2) -> (B, 1, 1, IC/2) -> (B, 1, 1, IC): back to input width.
        context = tf.reshape(context, (batch, 1, 1, self.out_planes))
        mask_ch = self.sigmoid_left(self.conv_up_left(context))
        return x * mask_ch

    def _spatial_attention(self, x):
        """Re-weight each spatial position of the channels_last tensor ``x``."""
        shape = tf.shape(x)
        batch, height, width = shape[0], shape[1], shape[2]
        # (B, H, W, IC) -> (B, 1, 1, IC/2) -> (B, 1, IC/2): globally pooled query.
        query = self.global_pool(self.conv_q_right(x))
        query = tf.reshape(query, (batch, 1, self.out_planes))
        # (B, H, W, IC) -> (B, H*W, IC/2)
        value = tf.reshape(self.conv_v_right(x), (batch, -1, self.out_planes))
        # (B, 1, IC/2) @ (B, IC/2, H*W) -> (B, 1, H*W)
        context = tf.matmul(query, value, transpose_b=True)
        # The paper's figure applies softmax to the pooled query before the
        # product; the reference code applies it after the product, over H*W.
        context = self.softmax_right(context)
        # (B, 1, H*W) -> (B, H, W, 1)
        context = tf.reshape(context, (batch, height, width, 1))
        mask_sp = self.sigmoid_right(context)
        return x * mask_sp

    def channel_pool(self, inputs):
        """Channel-only attention on ``inputs`` (layout given by ``data_format``)."""
        return self._from_nhwc(self._channel_attention(self._to_nhwc(inputs)))

    def spatial_pool(self, inputs):
        """Spatial-only attention on ``inputs`` (layout given by ``data_format``)."""
        return self._from_nhwc(self._spatial_attention(self._to_nhwc(inputs)))

    def call(self, inputs):
        """Parallel layout: sum of the spatial-only and channel-only outputs."""
        x = self._to_nhwc(inputs)
        out = self._spatial_attention(x) + self._channel_attention(x)
        return self._from_nhwc(out)
class PSA_s(tf.keras.Model):
    """Polarized Self-Attention, sequential layout (PSA_s).

    The channel-only branch (``*_left`` layers, :meth:`channel_pool`)
    re-weights the input first. The spatial-only branch (``*_right`` layers,
    :meth:`spatial_pool`) is then computed from that result and applied to it.
    Compared with the parallel layout, the channel branch's up-projection is
    a bottleneck: Conv -> LayerNorm -> ReLU -> Conv.

    All computation happens in channels_last layout. With
    ``data_format='channels_first'`` the input is transposed in and the result
    is transposed back out, so the output has the same layout and shape as the
    input.
    """

    def __init__(self, planes, data_format='channels_last', ratio=4, **kwargs):
        """
        :param planes: number of input channels (the channel-mask width, so it
            must match the input's channel dimension).
        :param data_format: input layout, ``'channels_last'`` (default,
            ``(B, H, W, C)``) or ``'channels_first'`` (``(B, C, H, W)``).
        :param ratio: bottleneck reduction of the channel-branch
            up-projection; its hidden width is ``max(1, (planes // 2) // ratio)``.
        :param kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
        :raises ValueError: if ``data_format`` is not one of the values above.
        """
        super().__init__(**kwargs)
        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError(
                "data_format must be 'channels_last' or 'channels_first', "
                "got {!r}".format(data_format))
        self.data_format = data_format
        self.planes = planes
        self.out_planes = planes // 2
        # Sub-layers always receive channels_last tensors (see _to_nhwc), so
        # pin their layout instead of relying on the global Keras setting.
        conv_opts = dict(kernel_size=1, strides=1, padding='valid',
                         data_format='channels_last')
        # Guard against a zero-width bottleneck for small `planes`.
        hidden = max(1, self.out_planes // ratio)

        # Channel-only branch.
        self.conv_q_left = tf.keras.layers.Conv2D(filters=1, use_bias=False, **conv_opts)
        self.conv_v_left = tf.keras.layers.Conv2D(filters=self.out_planes, use_bias=False,
                                                  **conv_opts)
        self.conv_up_left = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filters=hidden, **conv_opts),
            tf.keras.layers.LayerNormalization(),  # over the last (channel) axis
            tf.keras.layers.Activation(tf.keras.activations.relu),
            tf.keras.layers.Conv2D(filters=self.planes, **conv_opts),
        ])
        # Applied to a (B, 1, H*W) tensor, so the default last-axis softmax
        # normalises over the H*W spatial positions.
        self.softmax_left = tf.keras.layers.Activation(activation=tf.keras.activations.softmax)
        self.sigmoid_left = tf.keras.layers.Activation(activation=tf.keras.activations.sigmoid)

        # Spatial-only branch.
        self.conv_q_right = tf.keras.layers.Conv2D(filters=self.out_planes, use_bias=False,
                                                   **conv_opts)
        self.conv_v_right = tf.keras.layers.Conv2D(filters=self.out_planes, use_bias=False,
                                                   **conv_opts)
        self.global_pool = tf.keras.layers.GlobalAvgPool2D(data_format='channels_last',
                                                           keepdims=True)
        self.softmax_right = tf.keras.layers.Activation(activation=tf.keras.activations.softmax)
        self.sigmoid_right = tf.keras.layers.Activation(activation=tf.keras.activations.sigmoid)

        # Same as PSA_p: the initialisation was defined but never applied here.
        self.reset_parameters()

    def reset_parameters(self):
        """Apply Kaiming initialisation to the four query/value projections."""
        for conv in (self.conv_v_right, self.conv_q_right,
                     self.conv_v_left, self.conv_q_left):
            kaiming_init(conv)

    def _to_nhwc(self, inputs):
        """Return ``inputs`` in channels_last layout."""
        if self.data_format == 'channels_first':
            # (B, C, H, W) -> (B, H, W, C)
            return tf.transpose(inputs, perm=[0, 2, 3, 1])
        return inputs

    def _from_nhwc(self, outputs):
        """Convert a channels_last result back to ``self.data_format``."""
        if self.data_format == 'channels_first':
            # (B, H, W, C) -> (B, C, H, W)
            return tf.transpose(outputs, perm=[0, 3, 1, 2])
        return outputs

    def _channel_attention(self, x):
        """Re-weight each channel of the channels_last tensor ``x``."""
        # Index tf.shape instead of unpacking it: iterating a symbolic tensor
        # is not allowed in graph mode (tf.function / model.fit).
        batch = tf.shape(x)[0]
        # (B, H, W, IC) -> (B, H*W, IC/2)
        value = tf.reshape(self.conv_v_left(x), (batch, -1, self.out_planes))
        # (B, H, W, IC) -> (B, 1, H*W): one score per position, softmax over H*W.
        query = tf.reshape(self.conv_q_left(x), (batch, 1, -1))
        query = self.softmax_left(query)
        # (B, 1, H*W) @ (B, H*W, IC/2) -> (B, 1, IC/2)
        context = tf.matmul(query, value)
        # (B, 1, IC/2) -> (B, 1, 1, IC/2) -> (B, 1, 1, IC): back to input width.
        context = tf.reshape(context, (batch, 1, 1, self.out_planes))
        mask_ch = self.sigmoid_left(self.conv_up_left(context))
        return x * mask_ch

    def _spatial_attention(self, x):
        """Re-weight each spatial position of the channels_last tensor ``x``."""
        shape = tf.shape(x)
        batch, height, width = shape[0], shape[1], shape[2]
        # (B, H, W, IC) -> (B, 1, 1, IC/2) -> (B, 1, IC/2): globally pooled query.
        query = self.global_pool(self.conv_q_right(x))
        query = tf.reshape(query, (batch, 1, self.out_planes))
        # (B, H, W, IC) -> (B, H*W, IC/2)
        value = tf.reshape(self.conv_v_right(x), (batch, -1, self.out_planes))
        # (B, 1, IC/2) @ (B, IC/2, H*W) -> (B, 1, H*W)
        context = tf.matmul(query, value, transpose_b=True)
        # The paper's figure applies softmax to the pooled query before the
        # product; the reference code applies it after the product, over H*W.
        context = self.softmax_right(context)
        # (B, 1, H*W) -> (B, H, W, 1)
        context = tf.reshape(context, (batch, height, width, 1))
        mask_sp = self.sigmoid_right(context)
        return x * mask_sp

    def channel_pool(self, inputs):
        """Channel-only attention on ``inputs`` (layout given by ``data_format``)."""
        return self._from_nhwc(self._channel_attention(self._to_nhwc(inputs)))

    def spatial_pool(self, inputs):
        """Spatial-only attention on ``inputs`` (layout given by ``data_format``)."""
        return self._from_nhwc(self._spatial_attention(self._to_nhwc(inputs)))

    def call(self, inputs):
        """Sequential layout: channel-only attention, then spatial-only attention."""
        x = self._to_nhwc(inputs)
        # The paper's sequential layout applies the channel-only branch first
        # and the spatial-only branch to its output. Transposing once here
        # avoids a round trip between the two branches.
        out = self._spatial_attention(self._channel_attention(x))
        return self._from_nhwc(out)
if __name__ == '__main__':
    # Smoke test: run both PSA variants on an all-zeros channels_last batch
    # and print the resulting output shapes.
    dummy_batch = tf.zeros(shape=(4, 224, 224, 16))

    parallel_out = PSA_p(planes=16)(dummy_batch)
    print(f"PSA_p output shape = {tf.shape(parallel_out).numpy()}")

    sequential_out = PSA_s(planes=16)(dummy_batch)
    print(f"PSA_s output shape = {tf.shape(sequential_out).numpy()}")