注意力机制之 BAM（Bottleneck Attention Module）

先附上完整程序（PyTorch 实现）：

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) one into a single axis."""

    def forward(self, x):
        # ``view`` (not ``reshape``): non-contiguous inputs raise instead of being copied.
        n = x.size(0)
        return x.view(n, -1)
class ChannelGate(nn.Module):
    """Attention branch that predicts a full (B, C, H, W) map from a pooled input.

    The input is 2x2 average-pooled, flattened per sample, passed through an
    MLP (``num_layers`` hidden Linear-BN-ReLU stages followed by a final
    Linear), and reshaped back to the input's (C, H, W) shape. Each sample
    gets its own attention map, and any batch size is accepted.

    Args:
        gate_channel: number of input channels C.
        reduction_ratio: hidden MLP width is ``gate_channel // reduction_ratio``.
        num_layers: number of hidden Linear-BN-ReLU stages (0 is allowed).
        in_height: spatial height H of the inputs. It fixes the Linear layer
            sizes. Defaults to the module-level ``hight`` global.
        in_width: spatial width W of the inputs. Defaults to the module-level
            ``width`` global.
    """

    def __init__(self, gate_channel, reduction_ratio=2, num_layers=1,
                 in_height=None, in_width=None):
        super(ChannelGate, self).__init__()
        # Fall back to the script-level globals so existing callers keep working.
        self.in_height = hight if in_height is None else in_height
        self.in_width = width if in_width is None else in_width
        self.gate_channel = gate_channel

        # avg_pool2d(kernel_size=2, stride=2) floors each spatial dim by 2.
        pooled_features = gate_channel * (self.in_height // 2) * (self.in_width // 2)
        # One attention value per element of a single sample. The previous
        # version multiplied this by the batch size, which only reshaped
        # correctly when the batch size happened to be 4.
        out_features = gate_channel * self.in_height * self.in_width
        hidden = gate_channel // reduction_ratio

        # Layer widths: pooled input -> hidden (x num_layers) -> output. Each
        # Linear consumes the previous layer's width, so num_layers >= 2 works.
        dims = [pooled_features] + [hidden] * num_layers + [out_features]

        self.gate_c = nn.Sequential()
        self.gate_c.add_module('flatten', nn.Flatten())
        # Submodule names are unchanged so existing state_dicts still load.
        for i in range(num_layers):
            self.gate_c.add_module('gate_c_fc_%d' % i, nn.Linear(dims[i], dims[i + 1]))
            self.gate_c.add_module('gate_c_bn_%d' % (i + 1), nn.BatchNorm1d(dims[i + 1]))
            self.gate_c.add_module('gate_c_relu_%d' % (i + 1), nn.ReLU())
        # dims[-2] is valid even when num_layers == 0 (no leaked loop variable).
        self.gate_c.add_module('gate_c_fc_final', nn.Linear(dims[-2], dims[-1]))

    def forward(self, in_tensor):
        """Return a (B, C, H, W) attention map for ``in_tensor`` of shape (B, C, H, W).

        C, H and W must match the sizes the gate was built for. B is read from
        the input.
        """
        avg_pool = F.avg_pool2d(in_tensor, 2, stride=2)
        return self.gate_c(avg_pool).reshape(
            in_tensor.size(0), self.gate_channel, self.in_height, self.in_width)

class SpatialGate(nn.Module):
    """Convolutional attention branch producing a (B, C, H, W) map.

    A 3x3 conv reduces ``gate_channel`` to ``gate_channel // reduction_ratio``
    channels (followed by BN and ReLU). Then ``dilation_conv_num`` more 3x3
    conv-BN-ReLU stages run at that width. A final 3x3 conv restores
    ``gate_channel`` channels. Every conv uses stride 1 and padding 1, so the
    spatial size is preserved.

    Args:
        gate_channel: number of input (and output) channels C.
        reduction_ratio: divisor for the hidden channel count.
        dilation_conv_num: number of intermediate conv-BN-ReLU stages.
        dilation_val: accepted for signature compatibility but currently
            unused. The intermediate convs are not dilated.
    """

    def __init__(self, gate_channel, reduction_ratio=4, dilation_conv_num=2, dilation_val=4):
        super(SpatialGate, self).__init__()
        # Sized from the constructor argument rather than the module-global
        # `wei`, so the gate works when imported and for any channel count.
        reduced = gate_channel // reduction_ratio
        self.gate_s = nn.Sequential()
        self.gate_s.add_module('gate_s_conv_reduce0', nn.Conv2d(gate_channel, reduced, 3, 1, 1))
        self.gate_s.add_module('gate_s_bn_reduce0', nn.BatchNorm2d(reduced))
        self.gate_s.add_module('gate_s_relu_reduce0', nn.ReLU())
        for i in range(dilation_conv_num):
            self.gate_s.add_module('gate_s_conv_di_%d' % i, nn.Conv2d(reduced, reduced, 3, 1, 1))
            self.gate_s.add_module('gate_s_bn_di_%d' % i, nn.BatchNorm2d(reduced))
            self.gate_s.add_module('gate_s_relu_di_%d' % i, nn.ReLU())
        self.gate_s.add_module('gate_s_conv_final', nn.Conv2d(reduced, gate_channel, 3, 1, 1))

    def forward(self, in_tensor):
        """Return the spatial attention map; same shape as ``in_tensor``."""
        return self.gate_s(in_tensor)
class BAM(nn.Module):
    """Attention block returning ``x * (1 + sigmoid(channel_att(x) * spatial_att(x)))``.

    Args:
        gate_channel: number of channels C of the (B, C, H, W) inputs.
    """

    def __init__(self, gate_channel):
        super(BAM, self).__init__()
        self.channel_att = ChannelGate(gate_channel)
        self.spatial_att = SpatialGate(gate_channel)

    def forward(self, in_tensor):
        # Evaluate each branch exactly once. The old debug print ran
        # channel_att a second time, doubling its cost and updating its
        # BatchNorm running statistics twice per training step.
        channel = self.channel_att(in_tensor)
        spatial = self.spatial_att(in_tensor)
        # NOTE(review): the BAM paper combines the branches additively,
        # sigmoid(Mc + Ms). This variant multiplies them; confirm intentional.
        att = 1 + torch.sigmoid(channel * spatial)
        return att * in_tensor
# Demo input: random integers in [1, 9] shaped (batch, channels, height, width).
inp = np.random.randint(1, 10, (4, 64, 32, 32))
# The classes above may read these module-level names (e.g. for layer sizes),
# so they are defined before BAM is constructed:
#   batch - batch size, wei - number of channels, hight / width - spatial size.
batch, wei, hight, width = inp.shape
bam = BAM(wei)
inputs = torch.as_tensor(inp, dtype=torch.float32)
print(bam(inputs))

（图 1：注意力机制之 BAM）

你可能感兴趣的：CNN，深度学习