深度学习论文: LEDNet: A Lightweight Encoder-Decoder Network for Real-Time Semantic Segmentation 及其 PyTorch 实现

class HalfSplit(nn.Module):
    """Split a tensor into two chunks along a chosen dimension.

    Args:
        dim: Dimension along which the input is chunked (default 1).
    """

    def __init__(self, dim=1):
        super(HalfSplit, self).__init__()
        self.dim = dim

    def forward(self, input):
        """Return the first and second chunk of ``input`` as a 2-tuple."""
        # torch.chunk yields the pieces in order along ``self.dim``.
        pieces = torch.chunk(input, chunks=2, dim=self.dim)
        leading = pieces[0]
        trailing = pieces[1]
        return leading, trailing

class ChannelShuffle(nn.Module):
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel axis is viewed as ``(groups, C // groups)``, the two factors
    are swapped, and the result is flattened back to ``C`` channels.

    Args:
        groups: Number of groups the channel dimension is divided into.
            ``C`` must be a multiple of ``groups``; otherwise the reshape
            in ``forward`` is rejected by ``Tensor.view``.
    """

    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
        N, C, H, W = x.size()
        g = self.groups
        # Floor division keeps the channel arithmetic in integers instead of
        # going through a float quotient and truncating it.
        return x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)


class SS_nbt(nn.Module):
    """Two-branch residual unit: split channels, convolve each half, merge, shuffle.

    The input is split in two along the channel axis. Each half goes through
    four 1-D convolutions (kernel [3, 1] / [1, 3]); the second pair in each
    branch is dilated along the kernel's long axis. The branch outputs are
    concatenated, the input is added back, and the channels are shuffled.

    NOTE(review): ``ConvReLU`` and ``ConvBNReLU`` are defined outside this
    block; presumably conv + ReLU and conv + BatchNorm + ReLU -- confirm.

    Args:
        channels: Channel count of the input; each branch uses ``channels // 2``.
        dilation: Dilation rate of the second convolution pair in each branch.
        groups: Group count handed to ``ChannelShuffle``.
    """

    def __init__(self, channels, dilation=1, groups=4):
        super(SS_nbt, self).__init__()

        half = channels // 2
        self.half_split = HalfSplit(dim=1)

        # Branch one: 3x1 then 1x3, followed by the same pair with dilation.
        branch_one = [
            ConvReLU(
                in_channels=half, out_channels=half,
                kernel_size=[3, 1], stride=1, padding=[1, 0],
            ),
            ConvBNReLU(
                in_channels=half, out_channels=half,
                kernel_size=[1, 3], stride=1, padding=[0, 1],
            ),
            ConvReLU(
                in_channels=half, out_channels=half,
                kernel_size=[3, 1], stride=1,
                dilation=[dilation, 1], padding=[dilation, 0],
            ),
            ConvBNReLU(
                in_channels=half, out_channels=half,
                kernel_size=[1, 3], stride=1,
                dilation=[1, dilation], padding=[0, dilation],
            ),
        ]
        self.first_bottleneck = nn.Sequential(*branch_one)

        # Branch two mirrors branch one with the kernel orientations swapped.
        branch_two = [
            ConvReLU(
                in_channels=half, out_channels=half,
                kernel_size=[1, 3], stride=1, padding=[0, 1],
            ),
            ConvBNReLU(
                in_channels=half, out_channels=half,
                kernel_size=[3, 1], stride=1, padding=[1, 0],
            ),
            ConvReLU(
                in_channels=half, out_channels=half,
                kernel_size=[1, 3], stride=1,
                dilation=[1, dilation], padding=[0, dilation],
            ),
            ConvBNReLU(
                in_channels=half, out_channels=half,
                kernel_size=[3, 1], stride=1,
                dilation=[dilation, 1], padding=[dilation, 0],
            ),
        ]
        self.second_bottleneck = nn.Sequential(*branch_two)

        self.channelShuffle = ChannelShuffle(groups)

    def forward(self, x):
        """Run both branches on the channel halves, add ``x`` back, then shuffle."""
        left, right = self.half_split(x)
        left = self.first_bottleneck(left)
        right = self.second_bottleneck(right)
        merged = torch.cat([left, right], dim=1)
        # Residual connection is applied before the channel shuffle.
        residual_sum = merged + x
        return self.channelShuffle(residual_sum)

你可能感兴趣的:(pytorch,深度学习,人工智能,python)