The code in this chapter comes from https://github.com/abhi4ssj/squeeze_and_excitation
This blog post explains the idea well: https://blog.csdn.net/jiachen0212/article/details/80542516
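The strength of SE-Net is that it uses global average pooling to squeeze the spatial information of each channel into a single value, then feeds those values through two fully connected layers followed by a sigmoid to produce a weight in (0, 1) for every channel. Finally, the weights are reshaped so that they broadcast over the original feature-map size and are multiplied with the input, which rescales each channel.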
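Written out for a feature map with C channels of size H x W, the three steps look roughly as follows. The symbols ($z$, $s$, $W_1$, $W_2$, reduction ratio $r$) are introduced here for illustration, and the bias terms used in the code are omitted for brevity:

```latex
\begin{aligned}
z_c &= \frac{1}{H W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_c(i, j) && \text{(squeeze)} \\
s &= \sigma\bigl(W_2\, \mathrm{ReLU}(W_1 z)\bigr), \quad
     W_1 \in \mathbb{R}^{\frac{C}{r} \times C},\;
     W_2 \in \mathbb{R}^{C \times \frac{C}{r}} && \text{(excitation)} \\
\tilde{x}_c &= s_c \cdot x_c && \text{(scale)}
\end{aligned}
```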
Key code:
import torch
import torch.nn as nn


class ChannelSELayer(nn.Module):
    """
    Re-implementation of Squeeze-and-Excitation (SE) block described in:
        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: Number of input channels
        :param reduction_ratio: Factor by which num_channels is reduced in the bottleneck
        """
        super(ChannelSELayer, self).__init__()
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: output tensor, same shape as input_tensor
        """
        batch_size, num_channels, H, W = input_tensor.size()
        # spatial squeeze: global average pooling over H*W -> (batch_size, num_channels)
        squeeze_tensor = input_tensor.view(batch_size, num_channels, -1).mean(dim=2)

        # channel excitation: bottleneck MLP + sigmoid -> per-channel weights in (0, 1)
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        # reshape the weights to (batch_size, num_channels, 1, 1) so they broadcast over H and W
        a, b = squeeze_tensor.size()
        output_tensor = torch.mul(input_tensor, fc_out_2.view(a, b, 1, 1))
        return output_tensor
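To make the shapes at each step concrete, here is a minimal sketch that repeats the squeeze, excitation and scale steps with plain PyTorch ops. The tensor sizes, the seed and the names `z`, `excite`, `gates` and `y` are chosen only for illustration and do not come from the repository.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_channels, reduction_ratio = 8, 2
x = torch.randn(4, num_channels, 16, 16)    # (batch, C, H, W)

# Squeeze: global average pooling over H and W -> (batch, C)
z = x.mean(dim=(2, 3))

# Excitation: bottleneck MLP ending in a sigmoid -> one gate per channel
excite = nn.Sequential(
    nn.Linear(num_channels, num_channels // reduction_ratio),
    nn.ReLU(),
    nn.Linear(num_channels // reduction_ratio, num_channels),
    nn.Sigmoid(),
)
gates = excite(z)                            # (batch, C)

# Scale: add two singleton dims so the gates broadcast over H and W
y = x * gates[:, :, None, None]              # (batch, C, H, W)

print(z.shape, gates.shape, y.shape)
# torch.Size([4, 8]) torch.Size([4, 8]) torch.Size([4, 8, 16, 16])
```

Because every gate lies between 0 and 1, the block can only attenuate a channel, never amplify it; the output always has the same shape as the input, so the layer can be dropped in after any convolutional block.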
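I came across the following paper while looking for SE-Net code; its code is extremely simple, so I took a look. Paper: Roy, A.G., Navab, N. and Wachinger, C., 2018. Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks. In Proc. MICCAI 2018.
What it does is carry the channel-wise procedure above over to the spatial dimensions: instead of two fully connected layers that reweight the channels, a single 1x1 convolution collapses the channels into one map, and a sigmoid turns that map into one weight per spatial position.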
class SpatialSELayer(nn.Module):
    """
    Re-implementation of SE block -- squeezing channel-wise and exciting spatially described in:
        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
    """

    def __init__(self, num_channels):
        """
        :param num_channels: Number of input channels
        """
        super(SpatialSELayer, self).__init__()
        self.conv = nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: output_tensor, same shape as input_tensor
        """
        # channel squeeze: 1x1 conv maps num_channels -> 1, sigmoid gives weights in (0, 1)
        batch_size, _, a, b = input_tensor.size()
        squeeze_tensor = self.sigmoid(self.conv(input_tensor))

        # spatial excitation: the (batch_size, 1, H, W) map broadcasts over all channels
        output_tensor = torch.mul(input_tensor, squeeze_tensor.view(batch_size, 1, a, b))
        return output_tensor
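The following sketch illustrates what the 1x1 convolution does in the spatial variant: at every pixel it takes a weighted sum over the channels, so the resulting gate has a single channel and is shared by all input channels. The tensor sizes, the seed and the names `conv`, `manual`, `spatial_gate` and `y` are assumptions made for this illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(2, 8, 5, 5)                  # (batch, C, H, W)

# Channel squeeze: a 1x1 convolution collapses the C channels into one map
conv = nn.Conv2d(8, 1, kernel_size=1)
logits = conv(x)                             # (batch, 1, H, W)

# The same result computed by hand: a per-pixel weighted sum over channels plus bias
manual = (x * conv.weight.view(1, 8, 1, 1)).sum(dim=1, keepdim=True) + conv.bias
print(torch.allclose(logits, manual, atol=1e-5))   # True

# Spatial excitation: the sigmoid gate broadcasts across all C channels
spatial_gate = torch.sigmoid(logits)         # (batch, 1, H, W)
y = x * spatial_gate                         # (batch, C, H, W)

print(spatial_gate.shape, y.shape)
# torch.Size([2, 1, 5, 5]) torch.Size([2, 8, 5, 5])
```

Compared with the channel variant, there is no pooling step and no bottleneck: the only learnable parameters are the C weights and one bias of the 1x1 convolution.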