在Pytorch1.13的官方文档中,关于nn.Conv2d中的groups的作用是这么描述的:
简单来说就是将输入和输出的通道(channel)进行分组,每一组单独进行卷积操作,然后再把结果拼接(concat)起来。
比如输入大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5),输出大小为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5), g r o u p s = 2 groups=2 groups=2。就是将输入的4个channel分成2个2的channel,输出的8个channel分成2个4的channel,每个输入的2个channel和输出的4个channel组成一组,每组做完卷积后的输出大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)。然后把得到的两组输出在channel这个维度上进行concat,得到最后的输出维度为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5)。
但其实这么描述理解起来不够直观,下面我举个例子,先从语言上进行详细的解释,然后再进行代码验证。
符号 | 数值 | 含义 |
---|---|---|
i n p u t _ c h a n n e l input\_channel input_channel | 4 | 输入通道数量 |
o n p u t _ c h a n n e l onput\_channel onput_channel | 8 | 输出通道数量,其实就是卷积核的个数,我们将其看作卷积核的个数会更容易理解 |
b a t c h _ s i z e batch\_size batch_size | 1 | 批量大小为1 |
H , W H, W H,W | 5 | 输入输出的feature大小为5x5 |
i n p u t _ s h a p e input\_shape input_shape | ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5) | 输入的shape,注意我们这里设置输入的所有元素都为1,即输入是一个全1的tensor |
o u t p u t _ s h a p e output\_shape output_shape | ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5) | 输出的shape |
k e r n e l _ s i z e kernel\_size kernel_size | 3 | 卷积核的大小为3x3 |
p a d d i n g padding padding | 1 | 填充长度为1,这里我们使用1填充(即周围补一圈1),而不是0填充 |
s t r i d e stride stride | 1 | 步长为1 |
我们假设输入tensor的shape为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5),输出tensor的shape为: ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5),即我们的卷积核有8个。下面的图由于 b a t c h _ s i z e = 1 batch\_size=1 batch_size=1,所以省略的 b a t c h _ s i z e batch\_size batch_size的维度。
值得注意的是,这里我们手动设置卷积核中元素的值,前4个卷积核的值都设置为1,后4个卷积核的值都设置为2,如下图所示:
这里解释一下为什么 g r o u p s = 1 groups=1 groups=1时 k e r n e l _ s i z e = ( 4 , 3 , 3 ) kernel\_size=(4, 3, 3) kernel_size=(4,3,3), g r o u p s = 2 groups=2 groups=2时 k e r n e l _ s i z e = ( 2 , 3 , 3 ) kernel\_size=(2, 3, 3) kernel_size=(2,3,3):因为 g r o u p s = 2 groups=2 groups=2时,输入和输出都被分成了两组,输入的shape原来为: ( 4 , 5 , 5 ) (4, 5, 5) (4,5,5),被分成了两个 ( 2 , 5 , 5 ) (2, 5, 5) (2,5,5),所以每个 k e r n e l _ s i z e kernel\_size kernel_size也由 ( 4 , 3 , 3 ) (4, 3, 3) (4,3,3)变为 ( 2 , 3 , 3 ) (2, 3, 3) (2,3,3)。
下面我们来看一下 g r o u p s = 1 groups=1 groups=1和 g r o u p s = 2 groups=2 groups=2时计算过程的不同:
这里解释一下:output的前4个channel的每个feature map的所有元素都为36,后4个channel的每个feature map的所有元素都为72,这是因为:
每个输入的 H , W H,W H,W是5x5,加上padding之后是6x6,具体过程如下:
【情况1:groups=2】
此时应当这么算:
为什么output的前4个channel的每个feature map的所有元素都为18,后4个channel的每个feature map的所有元素都为36呢?看了下面的图应该就能理解这个过程了:
实验环境:Python3.7,torch1.10.2
代码:
import os
import torch
import torch.nn as nn
if __name__ == '__main__':
input_dim, output_dim = 4, 8
X = torch.ones(1, input_dim, 5, 5)
# groups = 1
conv1 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=1, bias=False, padding_mode='replicate')
print(f'groups=1时,卷积核的形状为:{conv1.weight.shape}')
with torch.no_grad():
conv1.weight[:4, :, :, :] = torch.ones(4, 4, 3, 3)
conv1.weight[4:, :, :, :] = torch.ones(4, 4, 3, 3) * 2
Y1 = conv1(X)
print(f'结果为:\n{Y1}')
# groups = 2
conv2 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=2, bias=False, padding_mode='replicate')
print(f'groups=2时,卷积核的形状为:{conv2.weight.shape}')
with torch.no_grad():
conv2.weight[:4, :, :, :] = torch.ones(4, 2, 3, 3)
conv2.weight[4:, :, :, :] = torch.ones(4, 2, 3, 3) * 2
Y2 = conv2(X)
print(f'结果为:\n{Y2}')
结果:
groups=1时,卷积核的形状为:torch.Size([8, 4, 3, 3])
结果为:
tensor([[[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.]],
[[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.]],
[[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.]],
[[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.],
[72., 72., 72., 72., 72.]]]])
groups=2时,卷积核的形状为:torch.Size([8, 2, 3, 3])
结果为:
tensor([[[[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.]],
[[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.]],
[[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.]],
[[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]],
[[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.],
[36., 36., 36., 36., 36.]]]])
Process finished with exit code 0
整体流程我手画了个图,我感觉比PPT画的还清楚,可以更好地理解过程
END:)
p.s.:没想到写个博客写了一上午,画图太费时间了!本来上午还有别的事情的。。。只能推到下午再做了0.0