import torch
import torch.nn as nn


class HalfSplit(nn.Module):
    """Split the input tensor into two equal halves along `dim`."""
    def __init__(self, dim=1):
        super(HalfSplit, self).__init__()
        self.dim = dim

    def forward(self, x):
        splits = torch.chunk(x, 2, dim=self.dim)
        return splits[0], splits[1]
class ChannelShuffle(nn.Module):
    """Channel shuffle: [N, C, H, W] -> [N, g, C/g, H, W] -> [N, C/g, g, H, W] -> [N, C, H, W]."""
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        N, C, H, W = x.size()
        g = self.groups
        # Reshape into groups, swap the group and per-group channel axes,
        # then flatten back so channels from different groups are interleaved.
        return x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)
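# Quick sanity check (illustrative only): channel shuffle is a pure reordering,
# so shapes are preserved; for C=8 channels and groups=4 the channel order
# becomes [0, 2, 4, 6, 1, 3, 5, 7]:
#   >>> x = torch.arange(8.0).view(1, 8, 1, 1)
#   >>> ChannelShuffle(groups=4)(x).flatten().tolist()
#   [0.0, 2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0]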
class SS_nbt(nn.Module):
    """Split-Shuffle non-bottleneck block: split the channels in half, pass each
    half through a factorized (1-D) convolution branch, concatenate, add the
    residual, and shuffle channels. (ConvReLU / ConvBNReLU are assumed to be
    the conv blocks defined earlier.)"""
    def __init__(self, channels, dilation=1, groups=4):
        super(SS_nbt, self).__init__()
        mid_channels = channels // 2
        self.half_split = HalfSplit(dim=1)
        # Branch 1: 3x1 -> 1x3 factorized convs, then the same pair with dilation.
        self.first_bottleneck = nn.Sequential(
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[3, 1], stride=1, padding=[1, 0]),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[1, 3], stride=1, padding=[0, 1]),
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[3, 1], stride=1, dilation=[dilation, 1], padding=[dilation, 0]),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[1, 3], stride=1, dilation=[1, dilation], padding=[0, dilation]),
        )
        # Branch 2: the mirrored order, 1x3 -> 3x1.
        self.second_bottleneck = nn.Sequential(
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[1, 3], stride=1, padding=[0, 1]),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[3, 1], stride=1, padding=[1, 0]),
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[1, 3], stride=1, dilation=[1, dilation], padding=[0, dilation]),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[3, 1], stride=1, dilation=[dilation, 1], padding=[dilation, 0]),
        )
        self.channelShuffle = ChannelShuffle(groups)
    def forward(self, x):
        # Split the channels into two halves and process each half independently.
        x1, x2 = self.half_split(x)
        x1 = self.first_bottleneck(x1)
        x2 = self.second_bottleneck(x2)
        # Recombine, add the residual, and shuffle channels so information
        # flows between the two halves in the next block.
        out = torch.cat([x1, x2], dim=1)
        return self.channelShuffle(out + x)
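
if __name__ == "__main__":
    # Minimal smoke test (a sketch: it assumes ConvReLU / ConvBNReLU from the
    # earlier section are in scope). SS_nbt preserves both the channel count
    # and the spatial resolution, so the blocks can be stacked freely.
    block = SS_nbt(channels=64, dilation=2, groups=4)
    x = torch.randn(1, 64, 56, 56)
    print(block(x).shape)  # expected: torch.Size([1, 64, 56, 56])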