最近使用pytorch的时候遇到nn.Conv2d和F.conv2d中的groups参数设置问题,查阅了一些相关资料发现网上的回答并不是很清晰明朗,所以自己写一篇关于pytorch分组卷积的见解。
关于普通卷积的知识可以直接点击链接多通道RGB卷积。
分组卷积和普通卷积最大的不同就是卷积核在不同通道上卷积后的操作,在生成一个FeatureMap的前提下,普通卷积是在声明与input_channel相同的数量的卷积核,在各个通道上进行卷积过后求和操作,如图操作。
分组卷积操作如图,图中的input_channel=groups,因此卷积操作则是一个卷积核对应一个channel进行卷积,然后在channel维度上进行concat。
上面所述,可能会看的很懵,下面我们来举个例子,假设我们有一个batch=10的分辨率为32x32的RGB图片,那么在pytorch中输入的数据维度则为[10,3,32,32],我们假设卷积核大小为3,输出12个featuremap(必须设置为input_channel的倍数,后面会继续讲到),即input_channel=3,output_channel=12。
如果我们直接进行卷积不分组,那么我们需要声明一个大小为[12,3,3,3]大小的weight(pytorch中[output_channel,input_channel, k_size,k_size]),3通道数量的kernel与输入图进行卷积操作求和生成一个featuremap,那么12个featuremap则对应的12个kernel的参数量是[3,3,3]。
那么如果进行groups=3的分组卷积呢? 那么这个weight的维度为[12,1,3,3]。也就是说我们只需要声明12个大小为[1,3,3]的参数就可以了,如代码所示。
conv = nn.Conv2d(3,12,kernel_size=3,groups=3)
conv.weight.size()
#torch.Size([12, 1, 3, 3])
如果我们设置的output_channel不能整除input_channel会发生什么呢?
conv = nn.Conv2d(3,10,kernel_size=3,groups=3)
conv.weight.size()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in
----> 1 conv = nn.Conv2d(3,10,kernel_size=3,groups=3)
2 conv.weight.size()
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/conv.py in __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode)
325 super(Conv2d, self).__init__(
326 in_channels, out_channels, kernel_size, stride, padding, dilation,
--> 327 False, _pair(0), groups, bias, padding_mode)
328
329 @weak_script_method
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/conv.py in __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
22 raise ValueError('in_channels must be divisible by groups')
23 if out_channels % groups != 0:
---> 24 raise ValueError('out_channels must be divisible by groups')
25 self.in_channels = in_channels
26 self.out_channels = out_channels
ValueError: out_channels must be divisible by groups
会报错。这是因为每个输入通道上的图片不能平均分配给相同数量的kernel进行卷积操作。
那么我们就很清楚了,普通卷积是在用卷积核在各个通道上进行卷积求和,那么每一张featuremap都会包含之前各个通道上的特征信息;而分组卷积则是按照分组来进行卷积融合操作,在各个分组之间进行普通卷积然后进行融合,融合生成的featuremap仅仅包含其对应分组所有通道的特征信息。