class SAM(nn.Module):
""" Parallel CBAM """
def __init__(self, in_ch):
super(SAM, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 1),
nn.Sigmoid()
)
def forward(self, x):
""" Spatial Attention Module """
x_attention = self.conv(x)
return x * x_attention
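A minimal smoke test for the module above (the input shape and channel count are illustrative assumptions, not part of the original snippet):
if __name__ == '__main__':
    import torch
    feat = torch.randn(1, 64, 32, 32)  # hypothetical feature map (batch, channels, H, W)
    sam = SAM(in_ch=64)                # in_ch must match the channel dimension of the input
    out = sam(feat)
    print(out.shape)                   # torch.Size([1, 64, 32, 32]); attention rescales values, not shapes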
2021-SimAM
import torch
import torch.nn as nn
# SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks (ICML 2021)
class simam_module(torch.nn.Module):
def __init__(self, channels = None, e_lambda = 1e-4):
super(simam_module, self).__init__()
        self.activation = nn.Sigmoid()
self.e_lambda = e_lambda
def __repr__(self):
s = self.__class__.__name__ + '('
s += ('lambda=%f)' % self.e_lambda)
return s
@staticmethod
def get_module_name():
return "simam"
def forward(self, x):
b, c, h, w = x.size()
n = w * h - 1
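        # inverse of the minimal SimAM energy: y = (x - mu)^2 / (4 * (var + lambda)) + 0.5,
        # where var is the per-channel spatial variance; larger deviation from the mean -> larger weight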
x_minus_mu_square = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activation(y)
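SimAM weights each activation with a closed-form energy term (how far it deviates from its channel mean, normalized by the channel variance), so the module is parameter-free and needs no channel count. A usage sketch with an arbitrary input shape:
if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)      # illustrative (batch, channels, H, W)
    attn = simam_module(e_lambda=1e-4)  # no learnable parameters
    out = attn(x)
    print(attn, out.shape)              # simam_module(lambda=0.000100) torch.Size([2, 64, 32, 32])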
2021-GAM
import torch.nn as nn
import torch
class GAM_Attention(nn.Module):
def __init__(self, in_channels, out_channels, rate=4):
super(GAM_Attention, self).__init__()
self.channel_attention = nn.Sequential(
nn.Linear(in_channels, int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Linear(int(in_channels / rate), in_channels)
)
self.spatial_attention = nn.Sequential(
nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
nn.BatchNorm2d(int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
b, c, h, w = x.shape
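        # flatten spatial positions: (B, C, H, W) -> (B, H*W, C), so the channel MLP acts on every location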
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()  # channel attention map in (0, 1), following the paper's sigmoid gating
x = x * x_channel_att
x_spatial_att = self.spatial_attention(x).sigmoid()
out = x * x_spatial_att
return out
if __name__ == '__main__':
x = torch.randn(1, 64, 32, 48)
b, c, h, w = x.shape
net = GAM_Attention(in_channels=c, out_channels=c)
y = net(x)
SE+CBAM
import torch
import torch.nn as nn
import torch.nn.functional as F
###########################################################################################################
class SEModule(nn.Module):
def __init__(self, channels, reduction=16):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc_1 = nn.Conv2d(
channels, channels // reduction, kernel_size=1, padding=0
)
self.relu = nn.ReLU(inplace=True)
self.fc_2 = nn.Conv2d(
channels // reduction, channels, kernel_size=1, padding=0
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
original = x
x = self.avg_pool(x)
x = self.fc_1(x)
x = self.relu(x)
x = self.fc_2(x)
x = self.sigmoid(x)
return original * x
###########################################################################################################
class BasicConv(nn.Module):
def __init__(
self,
in_planes,
out_planes,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
relu=True,
bn=True,
bias=False,
):
super(BasicConv, self).__init__()
self.out_planes = out_planes
self.conv = nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
self.bn = (
nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
if bn
else None
)
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class ChannelGate(nn.Module):
def __init__(
self, gate_channels, reduction_ratio=16, pool_types=["avg", "max"]
):
super(ChannelGate, self).__init__()
self.gate_channels = gate_channels
self.mlp = nn.Sequential(
Flatten(),
nn.Linear(gate_channels, gate_channels // reduction_ratio),
nn.ReLU(),
nn.Linear(gate_channels // reduction_ratio, gate_channels),
)
self.pool_types = pool_types
def forward(self, x):
channel_att_sum = None
for pool_type in self.pool_types:
if pool_type == "avg":
avg_pool = F.avg_pool2d(
x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
)
channel_att_raw = self.mlp(avg_pool)
elif pool_type == "max":
max_pool = F.max_pool2d(
x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
)
channel_att_raw = self.mlp(max_pool)
elif pool_type == "lp":
lp_pool = F.lp_pool2d(
x,
2,
(x.size(2), x.size(3)),
                    stride=(x.size(2), x.size(3)),
)
channel_att_raw = self.mlp(lp_pool)
elif pool_type == "lse":
# LSE pool
lse_pool = logsumexp_2d(x)
channel_att_raw = self.mlp(lse_pool)
if channel_att_sum is None:
channel_att_sum = channel_att_raw
else:
channel_att_sum = channel_att_sum + channel_att_raw
scale = (
            torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
)
return x * scale
def logsumexp_2d(tensor):
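    # numerically stable log-sum-exp over the flattened spatial dimension; returns shape (B, C, 1)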
tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
return outputs
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat(
(torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)),
dim=1,
)
class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPool()
self.spatial = BasicConv(
2,
1,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
relu=False,
)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)
return x * scale
class CBAM(nn.Module):
def __init__(
self,
gate_channels,
reduction_ratio=16,
pool_types=["avg", "max"],
no_spatial=False,
):
super(CBAM, self).__init__()
self.ChannelGate = ChannelGate(
gate_channels, reduction_ratio, pool_types
)
self.no_spatial = no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
def forward(self, x):
x_out = self.ChannelGate(x)
if not self.no_spatial:
x_out = self.SpatialGate(x_out)
return x_out
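A smoke test for SEModule and the assembled CBAM block (input shape and channel count are illustrative; the default pool_types exercise only the "avg" and "max" branches of ChannelGate):
if __name__ == '__main__':
    x = torch.randn(1, 64, 32, 32)
    se = SEModule(channels=64, reduction=16)  # bottleneck 64 -> 4 -> 64 channels
    cbam = CBAM(gate_channels=64, reduction_ratio=16, pool_types=["avg", "max"])
    print(se(x).shape, cbam(x).shape)         # both torch.Size([1, 64, 32, 32])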
ECA
import math
import torch
import torch.nn as nn

class eca_block(nn.Module):
def __init__(self, channel, b=1, gamma=2):
super(eca_block, self).__init__()
kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
y = self.avg_pool(x)
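        # treat the pooled channel descriptor as a length-C sequence and convolve across neighbouring channels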
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
y = self.sigmoid(y)
return x * y.expand_as(x)
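A usage sketch (input shape is illustrative). With channel=64, b=1 and gamma=2, the adaptive kernel size is int(abs((log2(64) + 1) / 2)) = 3, already odd, so a width-3 1D convolution mixes each channel with its two neighbours:
if __name__ == '__main__':
    x = torch.randn(1, 64, 32, 32)
    eca = eca_block(channel=64)
    print(eca(x).shape)                 # torch.Size([1, 64, 32, 32])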