1 The idea behind ShuffleNet
ShuffleNet improves on ResNet with two techniques, group convolution and channel shuffle, and can be viewed as a compressed version of ResNet.
- Group convolution
- Channel shuffle
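At its core, ShuffleNet restricts each convolution to a single group of channels, which cuts the cost of the grouped layers by roughly a factor equal to the number of groups. The side effect is that information flow is confined within each group: the groups exchange no information with one another, which hurts the model's representational power. A mechanism for exchanging information across groups is therefore needed, and that is the channel shuffle operation. Channel shuffle is also differentiable, so the network can still be trained end-to-end in a single pass.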
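To make the cost saving concrete, here is a minimal sketch that counts the multiply-accumulate operations of a convolution with and without groups. The function name `conv_macs` and the example sizes are illustrative assumptions, not taken from the ShuffleNet code; the formula is the usual cost of a grouped convolution, ignoring bias.

```python
def conv_macs(in_channels, out_channels, kernel_size, height, width, groups=1):
    """Multiply-accumulates of a 2-D convolution on a height x width output map.

    Hypothetical helper for illustration. Each output channel only sees
    in_channels // groups input channels, so the cost shrinks by `groups`.
    """
    assert in_channels % groups == 0 and out_channels % groups == 0
    per_output_value = (in_channels // groups) * kernel_size * kernel_size
    return per_output_value * out_channels * height * width


# Example sizes are assumptions chosen only for illustration.
dense = conv_macs(240, 240, 1, 28, 28, groups=1)
grouped = conv_macs(240, 240, 1, 28, 28, groups=3)
print(dense, grouped, dense // grouped)
```

With three groups, the 1x1 convolution in this example needs one third of the multiply-accumulates of the ungrouped one.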
2 Core code
Shuffling channels across groups:
import torch
import torch.nn as nn
import torch.nn.functional as F


def shuffle_channels(x, groups):
    """shuffle channels of a 4-D Tensor"""
    batch_size, channels, height, width = x.size()
    assert channels % groups == 0
    channels_per_group = channels // groups
    # split into groups
    x = x.view(batch_size, groups, channels_per_group,
               height, width)
    # transpose 1, 2 axis
    x = x.transpose(1, 2).contiguous()
    # reshape into original
    x = x.view(batch_size, channels, height, width)
    return x
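The effect of the view, transpose, view sequence is easiest to see on the channel indices alone. The sketch below is a pure-Python mirror of that reordering on the channel axis; `shuffled_channel_order` is a hypothetical name introduced here for illustration and is not part of the original code.

```python
def shuffled_channel_order(channels, groups):
    """Return the source channel index for each output position after the shuffle.

    Pure-Python mirror of view -> transpose(1, 2) -> view on the channel axis.
    """
    assert channels % groups == 0
    channels_per_group = channels // groups
    # Lay the indices out as a (groups, channels_per_group) grid ...
    grid = [[g * channels_per_group + c for c in range(channels_per_group)]
            for g in range(groups)]
    # ... then read it column by column, which is the transposed layout.
    return [grid[g][c] for c in range(channels_per_group) for g in range(groups)]


print(shuffled_channel_order(6, 3))  # [0, 2, 4, 1, 3, 5]
```

With 6 channels and 3 groups, the original groups are {0, 1}, {2, 3}, and {4, 5}. After the shuffle, the next grouped convolution sees [0, 2], [4, 1], and [3, 5], so every group now mixes channels that came from different source groups.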
ShuffleNet unit A:
class ShuffleNetUnitA(nn.Module):
    """ShuffleNet unit for stride=1"""

    def __init__(self, in_channels, out_channels, groups=3):
        super(ShuffleNetUnitA, self).__init__()
        # the residual add requires matching input and output channels
        assert in_channels == out_channels
        assert out_channels % 4 == 0
        bottleneck_channels = out_channels // 4
        self.groups = groups
        # 1x1 group convolution that reduces to the bottleneck width
        self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
                                     1, groups=groups, stride=1)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        # 3x3 depthwise convolution (one group per channel)
        self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
                                         bottleneck_channels,
                                         3, padding=1, stride=1,
                                         groups=bottleneck_channels)
        self.bn4 = nn.BatchNorm2d(bottleneck_channels)
        # 1x1 group convolution that expands back to out_channels
        self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
                                     1, stride=1, groups=groups)
        self.bn6 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.group_conv1(x)
        out = F.relu(self.bn2(out))
        out = shuffle_channels(out, groups=self.groups)
        out = self.depthwise_conv3(out)
        out = self.bn4(out)
        out = self.group_conv5(out)
        out = self.bn6(out)
        # identity shortcut: element-wise add, then ReLU
        out = F.relu(x + out)
        return out
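The asserts in `ShuffleNetUnitA` check that `out_channels` is divisible by 4, but `nn.Conv2d` additionally requires both its input and output channel counts to be divisible by `groups`. The sketch below is a hypothetical helper (not part of the original code) that checks these constraints up front, which may be handy when picking channel widths.

```python
def unit_a_channels_ok(in_channels, out_channels, groups=3):
    """Return True if the channel counts suit a stride-1 unit (illustrative helper)."""
    if in_channels != out_channels or out_channels % 4 != 0:
        return False
    bottleneck_channels = out_channels // 4
    # Grouped 1x1 convs need both ends divisible by the number of groups.
    return in_channels % groups == 0 and bottleneck_channels % groups == 0


print(unit_a_channels_ok(240, 240, groups=3))  # True: bottleneck is 60
print(unit_a_channels_ok(16, 16, groups=3))    # False: 16 is not divisible by 3
```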
ShuffleNet unit B:
class ShuffleNetUnitB(nn.Module):
    """ShuffleNet unit for stride=2"""

    def __init__(self, in_channels, out_channels, groups=3):
        super(ShuffleNetUnitB, self).__init__()
        # the shortcut contributes in_channels after concatenation,
        # so the conv branch only produces the remaining channels
        out_channels -= in_channels
        assert out_channels % 4 == 0
        bottleneck_channels = out_channels // 4
        self.groups = groups
        # 1x1 group convolution that reduces to the bottleneck width
        self.group_conv1 = nn.Conv2d(in_channels, bottleneck_channels,
                                     1, groups=groups, stride=1)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        # 3x3 depthwise convolution with stride 2 (halves the spatial size)
        self.depthwise_conv3 = nn.Conv2d(bottleneck_channels,
                                         bottleneck_channels,
                                         3, padding=1, stride=2,
                                         groups=bottleneck_channels)
        self.bn4 = nn.BatchNorm2d(bottleneck_channels)
        # 1x1 group convolution that expands the bottleneck
        self.group_conv5 = nn.Conv2d(bottleneck_channels, out_channels,
                                     1, stride=1, groups=groups)
        self.bn6 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.group_conv1(x)
        out = F.relu(self.bn2(out))
        out = shuffle_channels(out, groups=self.groups)
        out = self.depthwise_conv3(out)
        out = self.bn4(out)
        out = self.group_conv5(out)
        out = self.bn6(out)
        # downsample the shortcut so it matches the conv branch spatially
        x = F.avg_pool2d(x, 3, stride=2, padding=1)
        # concatenate along the channel axis instead of adding
        out = F.relu(torch.cat([x, out], dim=1))
        return out
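Unit B concatenates the pooled shortcut with the conv branch, so both branches must end up with the same spatial size. The sketch below works out the expected output shape with the standard output-size formula for convolution and pooling in floor mode. `unit_b_output_shape` and the example sizes are assumptions introduced here for illustration.

```python
def unit_b_output_shape(in_channels, out_channels, height, width):
    """Output (channels, height, width) of a stride-2 unit (illustrative helper)."""
    def halve(size, kernel=3, stride=2, padding=1):
        # Standard output-size formula for convolution and pooling (floor mode).
        return (size + 2 * padding - kernel) // stride + 1

    branch_channels = out_channels - in_channels   # produced by the conv branch
    shortcut_channels = in_channels                # kept by the avg-pool shortcut
    return (shortcut_channels + branch_channels, halve(height), halve(width))


print(unit_b_output_shape(24, 240, 56, 56))  # (240, 28, 28)
```

Because the 3x3 depthwise convolution and the 3x3 average pooling share the same kernel size, stride, and padding, they produce the same height and width, which is what makes the channel-wise concatenation valid.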
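3 Pros and cons
Drawbacks:
- Implementing channel shuffle requires a large number of pointer jumps and memory-set operations, which are very time-consuming in themselves. It also depends heavily on implementation details, so the actual running speed tends to be less than ideal.
- The channel shuffle rule is hand-designed rather than learned by the network. This goes against the basic principle that a network learns its features automatically through negative feedback, and falls back onto the old path of hand-engineered features (such as SIFT and HOG).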
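The first drawback can be illustrated with a rough count. In a PyTorch implementation that calls `transpose(1, 2).contiguous()`, the shuffle performs no arithmetic, but when the transpose changes the memory layout the tensor is copied, so every element is read and written once. The helper below is a hypothetical back-of-the-envelope estimate, assuming 4-byte floats and exactly one read plus one write per element; the real cost depends on the hardware and the kernel implementation.

```python
def shuffle_memory_traffic(batch_size, channels, height, width, bytes_per_element=4):
    """Rough bytes moved by one channel shuffle (illustrative estimate only)."""
    elements = batch_size * channels * height * width
    # One read and one write per element; the shuffle itself does zero FLOPs.
    return 2 * elements * bytes_per_element


print(shuffle_memory_traffic(1, 240, 28, 28))
```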
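4 Summary
This post first introduced the basic idea of ShuffleNet, then walked through its core code, and finally analyzed the drawbacks of ShuffleNet V1, which point to directions for improvement.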
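One last aside: what you can imagine, you cannot necessarily do; what you can do, you cannot necessarily do well; and what you can do well does not necessarily pay off.
A reminder to myself: once you have thought of something, do it; once you have done it, do your best to do it well; and once you have done it well, find a way to turn it into income.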