一图说清ShuffleNet中的通道混洗(附两种pytorch实现)

0.看ShuffleNet的通道混洗没看明白,后来在大神博客:https://blog.csdn.net/u011974639/article/details/79200559上看明白了。把自己理解写在下面方便像我这样的小白更清晰理解。

1.上图:通道混洗就是打乱原特征图通道顺序。

一图说清ShuffleNet中的通道混洗(附两种pytorch实现)_第1张图片

上图说的够明白了吧。首先确定自己的特征图通道数多少,再确定组数多少,然后将通道分组后作为输入(input)就可以了。

最终的输出还是组的形式,再将其拼接就是和原来输入一样shape的特征图了。只不过通道被打乱了。

2.pytorch代码:代码是和上面图通道数一一对应的,比对着看更香

代码1:严格按论文的感觉,自己打印下输入和输出对比上面图看 一目了然。

#第一种  严格按论文的感觉
a = torch.randn(1,15,3,3)
batchsize, channels, height, width = a.size()
groups = 3
channels_per_group = int(channels /groups)
x = a.view(batchsize, groups, channels_per_group, height, width)
x = x.transpose(1, 2).contiguous()
x = x.view(batchsize, -1, height, width)

代码2:写法不同,但输出和上面一模一样

#第二种  
x = torch.randn(1,15,3,3)
N, C, H, W = x.size()
groups = 3
out = x.view(N, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)

参考文献:1. https://blog.csdn.net/u011974639/article/details/79200559 

你可能感兴趣的:(机器学习)