In PyTorch, a single-channel im2col operation can be built by combining Tensor.unfold, torch.cat and torch.transpose.
Takeaway:
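import torch
stride = (1, 1)
kernel_size = (3, 3)
x = torch.arange(0, 25, dtype=torch.float32).view(5, 5)
patches = x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1])
# torch.cat expects a sequence of tensors, so tuple() splits a tensor along dim 0
y = torch.cat(tuple(torch.cat(tuple(patches.transpose(0, 2)), dim=2).transpose(0, 1)), dim=0)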
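The same pipeline can be wrapped into a reusable function for arbitrary input sizes, kernel sizes and strides. This is a minimal sketch, assuming a 2-D single-channel input and no padding; the name im2col_2d is hypothetical. The cross-check uses torch.nn.functional.unfold, PyTorch's built-in im2col, which returns one patch per column, hence the final transpose.

```python
import torch
import torch.nn.functional as F


def im2col_2d(x, kernel_size, stride):
    """Single-channel im2col via unfold/transpose/cat (hypothetical helper).

    x: 2-D tensor of shape (H, W), no padding applied.
    Returns (out_h * out_w, kh * kw): each row is one flattened patch,
    patches ordered row-major by their top-left corner.
    """
    kh, kw = kernel_size
    sh, sw = stride
    # (out_h, out_w, kh, kw): element [i, j, k, l] == x[i*sh + k, j*sw + l]
    patches = x.unfold(0, kh, sh).unfold(1, kw, sw)
    # move the kernel-row dim to the front so cat merges it into the patch dim
    t = patches.transpose(0, 2)          # (kh, out_w, out_h, kw)
    y = torch.cat(tuple(t), dim=2)       # (out_w, out_h, kh*kw)
    y = y.transpose(0, 1)                # (out_h, out_w, kh*kw)
    return torch.cat(tuple(y), dim=0)    # (out_h*out_w, kh*kw)


# Non-square input with unequal kernel sides and strides
x = torch.arange(24, dtype=torch.float32).view(4, 6)
cols = im2col_2d(x, kernel_size=(2, 3), stride=(2, 1))
# F.unfold returns (N, C*kh*kw, L); for N = C = 1 its transpose matches our layout
ref = F.unfold(x[None, None], kernel_size=(2, 3), stride=(2, 1))[0].t()
print(cols.shape)               # torch.Size([8, 6])
print(torch.equal(cols, ref))   # True
```

Converting the input to float32 keeps the comparison with F.unfold straightforward, since the built-in is aimed at floating-point tensors.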
The following walks through single-channel im2col step by step on a small matrix:
import torch
x = torch.arange(0, 25, dtype=torch.float32).view(5, 5)
print(x)
0 1 2 3 4
5 6 7 8 9
10 11 12 13 14
15 16 17 18 19
20 21 22 23 24
[torch.FloatTensor of size 5x5]
Define the kernel size and the stride:
kernel_size = (3, 3)
stride = (1, 1)
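These two settings fix the shape of the final matrix. A small arithmetic sketch, assuming no padding and no dilation (which is what Tensor.unfold does; trailing elements that do not fill a whole window are dropped); out_size is a hypothetical helper:

```python
def out_size(in_size, kernel, stride):
    """Number of window positions along one axis (hypothetical helper;
    no padding, no dilation, same counting rule as Tensor.unfold)."""
    return (in_size - kernel) // stride + 1


kernel_size = (3, 3)
stride = (1, 1)
out_h = out_size(5, kernel_size[0], stride[0])
out_w = out_size(5, kernel_size[1], stride[1])
# im2col result: one row per window position, one column per kernel element
print(out_h, out_w)                                        # 3 3
print((out_h * out_w, kernel_size[0] * kernel_size[1]))    # (9, 9)
```

So the 5x5 input with a 3x3 kernel and stride 1 yields 9 patches of 9 elements each, i.e. a 9x9 matrix.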
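First, use unfold to cut the matrix into strips, starting with a horizontal cut: a window of kernel_size[0] rows slides along dim 0 with step stride[0]. unfold appends the window as a new last dimension, so each of the 3 window positions yields a 5x3 slice:
x = x.unfold(0, kernel_size[0], stride[0])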
print(x)
(0 ,.,.) =
0 5 10
1 6 11
2 7 12
3 8 13
4 9 14
(1 ,.,.) =
5 10 15
6 11 16
7 12 17
8 13 18
9 14 19
(2 ,.,.) =
10 15 20
11 16 21
12 17 22
13 18 23
14 19 24
[torch.FloatTensor of size 3x5x3]
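The layout of this 3x5x3 tensor can be checked directly. A minimal sketch, assuming the same x, kernel_size and stride as above: slice i is simply rows i*stride[0] through i*stride[0] + kernel_size[0] - 1 of x, transposed because unfold puts the window dimension last.

```python
import torch

kernel_size = (3, 3)
stride = (1, 1)
x = torch.arange(0, 25, dtype=torch.float32).view(5, 5)
rows = x.unfold(0, kernel_size[0], stride[0])  # shape (3, 5, 3)

# Window i covers rows i*stride .. i*stride + kh - 1 of x; unfold appends the
# window as the LAST dimension, so rows[i] is that 3x5 slab transposed (5x3).
for i in range(rows.shape[0]):
    start = i * stride[0]
    assert torch.equal(rows[i], x[start:start + kernel_size[0]].t())
print(rows.shape)  # torch.Size([3, 5, 3])
```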
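Then cut vertically: unfold dim 1 with a window of kernel_size[1] columns. With stride 1, element [i, j] of the result is the 3x3 patch whose top-left corner is at row i, column j of the input: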
x = x.unfold(1, kernel_size[1], stride[1])
print(x)
(0 ,0 ,.,.) =
0 1 2
5 6 7
10 11 12
(0 ,1 ,.,.) =
1 2 3
6 7 8
11 12 13
(0 ,2 ,.,.) =
2 3 4
7 8 9
12 13 14
(1 ,0 ,.,.) =
5 6 7
10 11 12
15 16 17
(1 ,1 ,.,.) =
6 7 8
11 12 13
16 17 18
(1 ,2 ,.,.) =
7 8 9
12 13 14
17 18 19
(2 ,0 ,.,.) =
10 11 12
15 16 17
20 21 22
(2 ,1 ,.,.) =
11 12 13
16 17 18
21 22 23
(2 ,2 ,.,.) =
12 13 14
17 18 19
22 23 24
[torch.FloatTensor of size 3x3x3x3]
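At this point every patch is already present. A quick check, assuming the same x, kernel_size and stride as above: patches[i, j] equals the window sliced directly from x. As an alternative to the transpose/cat steps that follow, a single reshape of this (out_h, out_w, kh, kw) tensor already yields the 9x9 im2col matrix, with rows ordered by patch position and columns by kernel element.

```python
import torch

kernel_size = (3, 3)
stride = (1, 1)
x = torch.arange(0, 25, dtype=torch.float32).view(5, 5)
patches = x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1])

# patches[i, j] is the kh x kw window whose top-left corner is x[i*sh, j*sw]
kh, kw = kernel_size
sh, sw = stride
for i in range(patches.shape[0]):
    for j in range(patches.shape[1]):
        window = x[i * sh:i * sh + kh, j * sw:j * sw + kw]
        assert torch.equal(patches[i, j], window)

# Alternative: flatten (out_h, out_w) into rows and (kh, kw) into columns.
# The unfolded view is not contiguous, so reshape will generally copy here.
cols = patches.reshape(-1, kh * kw)
print(cols.shape)         # torch.Size([9, 9])
print(cols[1].tolist())   # [1.0, 2.0, 3.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0]
```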
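Note that the next step uses torch.cat to concatenate. torch.cat takes a sequence of tensors; passing tuple(x) splits x along dim 0, so dim 0 is the dimension that gets merged away. That is why transpose first reorders the dimensions so that dim 0 becomes the kernel-row index. In this example the printout looks unchanged, because the shape is 3x3x3x3 and element [i, j, k, l] equals x[i + k, j + l], which is symmetric in i and k. Try an input where the number of window positions differs from the kernel size to see the difference:
x = x.transpose(0, 2)
print(x)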
(0 ,0 ,.,.) =
0 1 2
5 6 7
10 11 12
(0 ,1 ,.,.) =
1 2 3
6 7 8
11 12 13
(0 ,2 ,.,.) =
2 3 4
7 8 9
12 13 14
(1 ,0 ,.,.) =
5 6 7
10 11 12
15 16 17
(1 ,1 ,.,.) =
6 7 8
11 12 13
16 17 18
(1 ,2 ,.,.) =
7 8 9
12 13 14
17 18 19
(2 ,0 ,.,.) =
10 11 12
15 16 17
20 21 22
(2 ,1 ,.,.) =
11 12 13
16 17 18
21 22 23
(2 ,2 ,.,.) =
12 13 14
17 18 19
22 23 24
[torch.FloatTensor of size 3x3x3x3]
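Following the suggestion above, here is a sketch with a non-square input (a hypothetical 4x6 matrix and 2x3 kernel, stride 1) where the shapes no longer coincide. With the transpose, concatenating along dim 2 joins the kernel rows of each patch; without it, cat merges different window positions instead, producing a 3x3 block that is not a patch.

```python
import torch

# Hypothetical non-square example: 4x6 input, 2x3 kernel, stride (1, 1)
x = torch.arange(24, dtype=torch.float32).view(4, 6)
patches = x.unfold(0, 2, 1).unfold(1, 3, 1)   # (out_h, out_w, kh, kw) = (3, 4, 2, 3)
t = patches.transpose(0, 2)                    # (kh, out_w, out_h, kw) = (2, 4, 3, 3)

# tuple(t) splits along dim 0 (the kernel row k); concatenating along dim 2
# (which holds kw) lays kernel row 0 then kernel row 1 side by side,
# i.e. a row-major flattened patch.
y = torch.cat(tuple(t), dim=2)                 # (out_w, out_h, kh*kw) = (4, 3, 6)
print(y[0, 0].tolist())    # [0.0, 1.0, 2.0, 6.0, 7.0, 8.0] == x[0:2, 0:3] flattened

# Skipping the transpose splits along the vertical window position instead,
# so rows from different patches end up glued together.
bad = torch.cat(tuple(patches), dim=2)         # (4, 2, 9)
print(bad[0, 0].tolist())  # [0.0, 1.0, 2.0, 6.0, 7.0, 8.0, 12.0, 13.0, 14.0]
```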
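Then concatenate with cat. tuple(x) splits x into 3 tensors of shape 3x3x3, and joining them along dim 2 gives a 3x3x9 tensor:
x = torch.cat(tuple(x), dim=2)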
print(x)
(0 ,.,.) =
0 1 2 5 6 7 10 11 12
5 6 7 10 11 12 15 16 17
10 11 12 15 16 17 20 21 22
(1 ,.,.) =
1 2 3 6 7 8 11 12 13
6 7 8 11 12 13 16 17 18
11 12 13 16 17 18 21 22 23
(2 ,.,.) =
2 3 4 7 8 9 12 13 14
7 8 9 12 13 14 17 18 19
12 13 14 17 18 19 22 23 24
[torch.FloatTensor of size 3x3x9]
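What this cat does to the indices can be stated precisely. A sketch on a generic 4-D tensor (arbitrary illustrative shape 2x3x4x5): concatenating the dim-0 slices along dim 2 is the same as moving dim 0 next to the last dim and flattening the two, i.e. out[b, c, a*D + d] == t[a, b, c, d].

```python
import torch

# A generic 4-D tensor standing in for the transposed patches, shape (A, B, C, D)
t = torch.arange(2 * 3 * 4 * 5).view(2, 3, 4, 5)

# tuple(t) gives A tensors of shape (B, C, D); joining them along dim 2
# yields (B, C, A*D) with out[b, c, a*D + d] == t[a, b, c, d].
via_cat = torch.cat(tuple(t), dim=2)

# The same result expressed as permute + reshape
via_permute = t.permute(1, 2, 0, 3).reshape(3, 4, 2 * 5)

print(via_cat.shape)                      # torch.Size([3, 4, 10])
print(torch.equal(via_cat, via_permute))  # True
```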
Now transpose dims 0 and 1:
x = x.transpose(0, 1)
print(x)
(0 ,.,.) =
0 1 2 5 6 7 10 11 12
1 2 3 6 7 8 11 12 13
2 3 4 7 8 9 12 13 14
(1 ,.,.) =
5 6 7 10 11 12 15 16 17
6 7 8 11 12 13 16 17 18
7 8 9 12 13 14 17 18 19
(2 ,.,.) =
10 11 12 15 16 17 20 21 22
11 12 13 16 17 18 21 22 23
12 13 14 17 18 19 22 23 24
[torch.FloatTensor of size 3x3x9]
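After this transpose the tensor has shape (out_h, out_w, kh*kw), which is why one more cat along dim 0 finishes the job: it just stacks the out_h slices. A sketch, assuming the same x, kernel_size and stride as above, confirming that [i, j] is the patch with top-left corner (i, j), flattened row by row:

```python
import torch

kernel_size = (3, 3)
stride = (1, 1)
x = torch.arange(0, 25, dtype=torch.float32).view(5, 5)
patches = x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1])
z = torch.cat(tuple(patches.transpose(0, 2)), dim=2).transpose(0, 1)  # (3, 3, 9)

# z[i, j] is the patch with top-left corner (i, j), flattened row by row,
# so z[i] holds all patches that start on row i of the input.
print(z[0, 1].tolist())  # [1.0, 2.0, 3.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0]
print(torch.equal(z[2, 0], x[2:5, 0:3].reshape(-1)))  # True
```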
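One final cat, this time along dim 0, and we are done. Each row of the 9x9 result is one flattened 3x3 patch:
x = torch.cat(tuple(x), dim=0)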
print(x)
0 1 2 5 6 7 10 11 12
1 2 3 6 7 8 11 12 13
2 3 4 7 8 9 12 13 14
5 6 7 10 11 12 15 16 17
6 7 8 11 12 13 16 17 18
7 8 9 12 13 14 17 18 19
10 11 12 15 16 17 20 21 22
11 12 13 16 17 18 21 22 23
12 13 14 17 18 19 22 23 24
[torch.FloatTensor of size 9x9]
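As a final sanity check, a sketch comparing the result with torch.nn.functional.unfold, PyTorch's built-in im2col, which lays patches out as columns rather than rows. It also illustrates the usual reason for im2col: with one patch per row, a convolution (cross-correlation, as computed by F.conv2d) reduces to a matrix-vector product. The weight w is an arbitrary illustrative kernel, not part of the original example.

```python
import torch
import torch.nn.functional as F

kernel_size = (3, 3)
stride = (1, 1)
x = torch.arange(0, 25, dtype=torch.float32).view(5, 5)
patches = x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1])
cols = torch.cat(tuple(torch.cat(tuple(patches.transpose(0, 2)), dim=2).transpose(0, 1)), dim=0)

# F.unfold expects (N, C, H, W) and returns (N, C*kh*kw, L): one patch per
# COLUMN, so transpose it to compare with the row-per-patch layout.
ref = F.unfold(x[None, None], kernel_size=kernel_size, stride=stride)[0].t()
print(torch.equal(cols, ref))  # True

# Convolution as a matmul; the weight values are arbitrary, for illustration.
w = torch.arange(9, dtype=torch.float32).view(3, 3)
out_h = (x.shape[0] - kernel_size[0]) // stride[0] + 1
out_w = (x.shape[1] - kernel_size[1]) // stride[1] + 1
out = (cols @ w.reshape(-1)).view(out_h, out_w)
expected = F.conv2d(x[None, None], w[None, None], stride=stride)[0, 0]
print(torch.allclose(out, expected))  # True
```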