First, the code: a variant of BAM (Bottleneck Attention Module) in which both gates are sized from a fixed, globally known input shape.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Flatten(nn.Module):
    def forward(self, x):
        # Collapse everything except the batch dimension.
        return x.view(x.size(0), -1)
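# ChannelGate: unlike the original BAM channel gate, which global-average-pools
# each feature map down to a (batch, channels) vector, this variant applies a
# 2x2 average pooling and feeds the flattened result through the FC stack, so
# the layer sizes depend on the globals wei, height and width set at the
# bottom of the file.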
class ChannelGate(nn.Module):
    def __init__(self, gate_channel, reduction_ratio=2, num_layers=1):
        super(ChannelGate, self).__init__()
        self.gate_c = nn.Sequential()
        self.gate_c.add_module('flatten', Flatten())
        gate_channels = [gate_channel]
        gate_channels += [gate_channel // reduction_ratio] * num_layers
        gate_channels += [gate_channel]
        for i in range(len(gate_channels) - 2):
            # The first FC layer sees the flattened 2x2-pooled map:
            # wei * (height // 2) * (width // 2) features per sample.
            in_features = wei * (height // 2) * (width // 2) if i == 0 else gate_channels[i]
            self.gate_c.add_module('gate_c_fc_%d' % i, nn.Linear(in_features, gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_bn_%d' % (i + 1), nn.BatchNorm1d(gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_relu_%d' % (i + 1), nn.ReLU())
        # The final FC layer expands back to one value per input element.
        self.gate_c.add_module('gate_c_fc_final', nn.Linear(gate_channels[-2], wei * height * width))
    def forward(self, in_tensor):
        # 2x2 average pooling halves each spatial dimension.
        avg_pool = F.avg_pool2d(in_tensor, 2, stride=2)
        # The FC stack maps the pooled features to a full-size channel
        # attention map of shape (batch, wei, height, width).
        return self.gate_c(avg_pool).reshape(-1, wei, height, width)
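# SpatialGate: reduce the channel count, apply a stack of 3x3 convolutions
# (the original BAM uses dilated convolutions at this point), then restore
# the channel count so the map can be combined with the channel gate.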
class SpatialGate(nn.Module):
    def __init__(self, gate_channel, reduction_ratio=4, dilation_conv_num=2):
        super(SpatialGate, self).__init__()
        self.gate_s = nn.Sequential()
        self.gate_s.add_module('gate_s_conv_reduce0', nn.Conv2d(wei, wei // reduction_ratio, 3, 1, 1))
        self.gate_s.add_module('gate_s_bn_reduce0', nn.BatchNorm2d(wei // reduction_ratio))
        self.gate_s.add_module('gate_s_relu_reduce0', nn.ReLU())
        for i in range(dilation_conv_num):
            self.gate_s.add_module('gate_s_conv_di_%d' % i,
                                   nn.Conv2d(wei // reduction_ratio, wei // reduction_ratio, 3, 1, 1))
            self.gate_s.add_module('gate_s_bn_di_%d' % i, nn.BatchNorm2d(wei // reduction_ratio))
            self.gate_s.add_module('gate_s_relu_di_%d' % i, nn.ReLU())
        self.gate_s.add_module('gate_s_conv_final', nn.Conv2d(wei // reduction_ratio, wei, 3, 1, 1))
    def forward(self, in_tensor):
        return self.gate_s(in_tensor)
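# BAM: combine the two gates into a single attention factor in (1, 2) and
# rescale the input feature map with it.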
class BAM(nn.Module):
    def __init__(self, gate_channel):
        super(BAM, self).__init__()
        self.channel_att = ChannelGate(gate_channel)
        self.spatial_att = SpatialGate(gate_channel)
    def forward(self, in_tensor):
        # Multiply the two attention maps (both already full-size here, so no
        # expand_as is needed), squash with a sigmoid, and add 1 so the input
        # is rescaled rather than suppressed.
        att = 1 + torch.sigmoid(self.channel_att(in_tensor) * self.spatial_att(in_tensor))
        return att * in_tensor
# Global input geometry used by both gates above:
# batch size, channel count (wei), and spatial size (height, width).
inp = np.random.randint(1, 10, (4, 64, 32, 32))
batch, wei, height, width = inp.shape

bam = BAM(wei)
inputs = torch.tensor(inp, dtype=torch.float32)
print(bam(inputs))
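As a quick sanity check (a minimal sketch, assuming the definitions above have already been run in the same session): the output should keep the input shape, and because the attention factor is 1 + sigmoid(...), every element of the output should lie between 1x and 2x its input value.

# Sanity check: shape is preserved and the attention factor stays in (1, 2).
out = bam(inputs)
assert out.shape == inputs.shape
att = out / inputs  # element-wise factor; safe since inp >= 1 everywhere
print(att.min().item(), att.max().item())  # both should fall inside (1.0, 2.0)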