SE-Net and Spatial SE-Net

The code in this post comes from https://github.com/abhi4ssj/squeeze_and_excitation
This blog post explains the idea well: https://blog.csdn.net/jiachen0212/article/details/80542516

SE-NET

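The strength of SE-Net is that it uses global average pooling to squeeze the spatial information of each channel into a single value, then passes that vector through two fully connected layers followed by a sigmoid to produce one weight per channel. Finally, the weights are multiplied back onto the input feature map, so the output has the same shape as the input.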
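As a compact summary of the three steps just described, the block can be written roughly as follows. This is a sketch in the notation of Hu et al.: $x_c$ is channel $c$ of an $H \times W$ feature map, $\delta$ is ReLU, $\sigma$ is the sigmoid, $r$ is the reduction ratio, and bias terms are omitted for brevity (the code in this post does use biases).

```latex
z_c = \frac{1}{H W} \sum_{i=1}^{H} \sum_{j=1}^{W} x_c(i, j)
\qquad \text{(squeeze)}

s = \sigma\bigl( W_2 \, \delta( W_1 z ) \bigr),
\quad W_1 \in \mathbb{R}^{\frac{C}{r} \times C},\;
      W_2 \in \mathbb{R}^{C \times \frac{C}{r}}
\qquad \text{(excitation)}

\tilde{x}_c = s_c \cdot x_c
\qquad \text{(rescale)}
```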
Key code


import torch
import torch.nn as nn


class ChannelSELayer(nn.Module):
    """
    Re-implementation of Squeeze-and-Excitation (SE) block described in:
        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: Number of input channels
        :param reduction_ratio: Factor by which num_channels is reduced in the bottleneck
        """
        super(ChannelSELayer, self).__init__()
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: output tensor
        """
        batch_size, num_channels, H, W = input_tensor.size()
        squeeze_tensor = input_tensor.view(batch_size, num_channels, -1).mean(dim=2)

        # channel excitation
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        a, b = squeeze_tensor.size()
        output_tensor = torch.mul(input_tensor, fc_out_2.view(a, b, 1, 1))
        return output_tensor
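To make the shape flow of the layer above easier to follow, here is a minimal functional sketch of the same squeeze, excitation, and rescale steps. The tensor sizes are made up for illustration, and the randomly initialised `w1`, `b1`, `w2`, `b2` are only stand-ins for the learned `fc1` and `fc2` parameters; nothing here comes from the original repository.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, num_channels, height, width = 2, 8, 5, 5
reduction_ratio = 2
reduced = num_channels // reduction_ratio

x = torch.randn(batch_size, num_channels, height, width)

# Random weights standing in for the learned fc1 / fc2 parameters.
w1, b1 = torch.randn(reduced, num_channels), torch.zeros(reduced)
w2, b2 = torch.randn(num_channels, reduced), torch.zeros(num_channels)

# Squeeze: global average pooling, (B, C, H, W) -> (B, C).
squeezed = x.mean(dim=(2, 3))

# Excitation: bottleneck MLP, then a sigmoid to get one gate per channel.
hidden = F.relu(F.linear(squeezed, w1, b1))        # (B, C // r)
weights = torch.sigmoid(F.linear(hidden, w2, b2))  # (B, C)

# Rescale: the (B, C, 1, 1) gates broadcast over H and W.
out = x * weights.view(batch_size, num_channels, 1, 1)

print(squeezed.shape, hidden.shape, weights.shape, out.shape)
```

Every pixel within a channel is multiplied by the same scalar, which is why the layer can reweight channels but cannot emphasise one spatial location over another.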

I came across the next paper while looking for SE-Net code. Its code is extremely simple, so I took a look. Paper: Roy, A.G., Navab, N. and Wachinger, C., 2018. Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks. In Proc. MICCAI 2018.

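What it does is take the channel-wise recipe above and apply it along the spatial dimensions instead: the step that reassigns per-channel weights is replaced by a single 1×1 convolution, which collapses the channels into one map and, after a sigmoid, gives one weight per pixel.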

class SpatialSELayer(nn.Module):
    """
    Re-implementation of SE block -- squeezing channel-wise and exciting spatially described in:
        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
    """

    def __init__(self, num_channels):
        """
        :param num_channels: Number of input channels
        """
        super(SpatialSELayer, self).__init__()
        self.conv = nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: output_tensor
        """
        # channel squeeze: the 1x1 conv collapses C channels into one (B, 1, H, W) map
        batch_size, _, a, b = input_tensor.size()
        squeeze_tensor = self.sigmoid(self.conv(input_tensor))

        # spatial excitation
        output_tensor = torch.mul(input_tensor, squeeze_tensor.view(batch_size, 1, a, b))

        return output_tensor
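The 1×1 convolution in the spatial layer is just a weighted sum over channels computed independently at each pixel. The sketch below checks this numerically by comparing `nn.Conv2d` against an explicit `torch.einsum`. The tensor sizes are arbitrary, the convolution is freshly initialised rather than trained, and the variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, num_channels, height, width = 2, 4, 3, 3
x = torch.randn(batch_size, num_channels, height, width)

conv = nn.Conv2d(num_channels, 1, kernel_size=1)

with torch.no_grad():
    # Path 1: the 1x1 convolution used by the layer, (B, C, H, W) -> (B, 1, H, W).
    gate_conv = torch.sigmoid(conv(x))

    # Path 2: the same computation as an explicit weighted sum over channels.
    w = conv.weight.view(num_channels)  # conv.weight has shape (1, C, 1, 1)
    logits = torch.einsum("bchw,c->bhw", x, w) + conv.bias
    gate_manual = torch.sigmoid(logits).unsqueeze(1)

    # The single-channel gate broadcasts over all C channels.
    out = x * gate_conv

print(gate_conv.shape, out.shape)
print(torch.allclose(gate_conv, gate_manual, atol=1e-5))
```

Here every channel at a given pixel is multiplied by the same scalar, the mirror image of the channel layer, where every pixel in a given channel shares one scalar.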
