在Pytorch中实现im2col操作 Implementing im2col in Pytorch

Pytorch中可以用torch.unfold, torch.cattorch.transpose的组合实现im2col操作.

TAKE AWAY:

stride = (1, 1)
kernel_size = (3, 3)

x = torch.arange(0, 25).resize_(5, 5)

y = torch.cat(torch.cat(x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1]).transpose(0, 2), dim=2).transpose(0, 1), dim=0)

下面以一个简单小矩阵举例详细说明单通道im2col操作:

x = torch.arange(0, 25).resize_(5, 5)
print(x)

  0   1   2   3   4
  5   6   7   8   9
 10  11  12  13  14
 15  16  17  18  19
 20  21  22  23  24
[torch.FloatTensor of size 5x5]

定义卷积核大小和步长

kernel_size = (3, 3)
stride = (1, 1)

首先使用unfold将其切片成小矩阵, 先横着切:

x = x.unfold(0, kernel, 1)
print(x)

(0 ,.,.) = 
   0   5  10
   1   6  11
   2   7  12
   3   8  13
   4   9  14

(1 ,.,.) = 
   5  10  15
   6  11  16
   7  12  17
   8  13  18
   9  14  19

(2 ,.,.) = 
  10  15  20
  11  16  21
  12  17  22
  13  18  23
  14  19  24
[torch.FloatTensor of size 3x5x3]

再竖着切:

x = x.unfold(1, kernel_size[1], stride[1])
print(x)

(0 ,0 ,.,.) = 
   0   1   2
   5   6   7
  10  11  12

(0 ,1 ,.,.) = 
   1   2   3
   6   7   8
  11  12  13

(0 ,2 ,.,.) = 
   2   3   4
   7   8   9
  12  13  14

(1 ,0 ,.,.) = 
   5   6   7
  10  11  12
  15  16  17

(1 ,1 ,.,.) = 
   6   7   8
  11  12  13
  16  17  18

(1 ,2 ,.,.) = 
   7   8   9
  12  13  14
  17  18  19

(2 ,0 ,.,.) = 
  10  11  12
  15  16  17
  20  21  22

(2 ,1 ,.,.) = 
  11  12  13
  16  17  18
  21  22  23

(2 ,2 ,.,.) = 
  12  13  14
  17  18  19
  22  23  24
[torch.FloatTensor of size 3x3x3x3]

这里要注意, 因为接下来要使用torch.cat做拼接, 但是因为cat操作的一些特点, 需要先用transpose对维度顺序做一下调整, 注意在我这个例子里维度都是3所以可能看不出来, 可以自己做实验试一下维度不相同的情况:

x = x.transpose(0, 2)
(0 ,0 ,.,.) = 
   0   1   2
   5   6   7
  10  11  12

(0 ,1 ,.,.) = 
   1   2   3
   6   7   8
  11  12  13

(0 ,2 ,.,.) = 
   2   3   4
   7   8   9
  12  13  14

(1 ,0 ,.,.) = 
   5   6   7
  10  11  12
  15  16  17

(1 ,1 ,.,.) = 
   6   7   8
  11  12  13
  16  17  18

(1 ,2 ,.,.) = 
   7   8   9
  12  13  14
  17  18  19

(2 ,0 ,.,.) = 
  10  11  12
  15  16  17
  20  21  22

(2 ,1 ,.,.) = 
  11  12  13
  16  17  18
  21  22  23

(2 ,2 ,.,.) = 
  12  13  14
  17  18  19
  22  23  24
[torch.FloatTensor of size 3x3x3x3]

然后用cat拼接一下:

x = torch.cat(x, dim=2)
print(x)

(0 ,.,.) = 
   0   1   2   5   6   7  10  11  12
   5   6   7  10  11  12  15  16  17
  10  11  12  15  16  17  20  21  22

(1 ,.,.) = 
   1   2   3   6   7   8  11  12  13
   6   7   8  11  12  13  16  17  18
  11  12  13  16  17  18  21  22  23

(2 ,.,.) = 
   2   3   4   7   8   9  12  13  14
   7   8   9  12  13  14  17  18  19
  12  13  14  17  18  19  22  23  24
[torch.FloatTensor of size 3x3x9]

这时再用transpose先转置一下:

x = x.transpose(0, 1)
print(x)

(0 ,.,.) = 
   0   1   2   5   6   7  10  11  12
   1   2   3   6   7   8  11  12  13
   2   3   4   7   8   9  12  13  14

(1 ,.,.) = 
   5   6   7  10  11  12  15  16  17
   6   7   8  11  12  13  16  17  18
   7   8   9  12  13  14  17  18  19

(2 ,.,.) = 
  10  11  12  15  16  17  20  21  22
  11  12  13  16  17  18  21  22  23
  12  13  14  17  18  19  22  23  24
[torch.FloatTensor of size 3x3x9]

最后cat一次就完成啦:

x = torch.cat(x, dim=2)
print(x)

    0     1     2     5     6     7    10    11    12
    1     2     3     6     7     8    11    12    13
    2     3     4     7     8     9    12    13    14
    5     6     7    10    11    12    15    16    17
    6     7     8    11    12    13    16    17    18
    7     8     9    12    13    14    17    18    19
   10    11    12    15    16    17    20    21    22
   11    12    13    16    17    18    21    22    23
   12    13    14    17    18    19    22    23    24
[torch.FloatTensor of size 9x9]

你可能感兴趣的:(DL)