假设我们希望提取一个矩阵每一个点上3*3窗口的均值或其他特征,直接使用循环的方法速度太慢,在pytorch中可以利用torch.unfold函数大大简化这个流程。
首先简单说明torch.unfold函数,其作用是按照选定的尺寸与步长来切分矩阵。unfold函数的参数为(dim,size,step),dim代表想要切分的维度,size代表切分块的尺寸,step代表切分的步长。举个例子,原始输入为H=4,W=5的矩阵: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 \begin{matrix} 1&2&3&4&5\\ 6&7&8&9&10\\ 11&12&13&14&15\\ 16&17&18&19&20\\ \end{matrix} 1611162712173813184914195101520,对其执行x.unfold(0,3,1)操作,代表沿着行维度对其做尺寸为3、步长为1的切分,则能得到:
[[[ 1., 6., 11.],
[ 2., 7., 12.],
[ 3., 8., 13.],
[ 4., 9., 14.],
[ 5., 10., 15.]],
[[ 6., 11., 16.],
[ 7., 12., 17.],
[ 8., 13., 18.],
[ 9., 14., 19.],
[10., 15., 20.]]]
假设我们在神经网络中需要提取一个图像样本的3窗口均值,样本维度为(N=1,C=1,H=4,W=5),值与上个示例相同,那么我们进行如下操作就可以:
N, C, H, W = x.size() ##shape=(1,1,4,5)
ksize = 3
padvalue = ksize // 2
Ex = F.pad(x, (padvalue, padvalue, padvalue, padvalue), mode='replicate') ##shape=(1,1,6,7)
Ex = Ex.unfold(2, ksize, 1) ##shape=(1,1,4,7,3)
Ex = Ex.unfold(3, ksize, 1) ##shape=(1,1,4,5,3,3)
Ex = Ex.permute(0, 4, 5, 1, 2, 3).contiguous()
Ex = Ex.view(N,ksize*ksize,C,H,W) ##shape=(1,9,1,4,5)
meanX = Ex.mean(dim=1,keepdim=True).squeeze(1) ##shape=(1,1,4,5)
padding是图像边界的基本操作,在padding后矩阵如下:
[[[[ 1., 1., 2., 3., 4., 5., 5.],
[ 1., 1., 2., 3., 4., 5., 5.],
[ 6., 6., 7., 8., 9., 10., 10.],
[11., 11., 12., 13., 14., 15., 15.],
[16., 16., 17., 18., 19., 20., 20.],
[16., 16., 17., 18., 19., 20., 20.]]]]
Ex在调整维度后最终输出如下:
[[[[[ 1., 1., 2., 3., 4.],
[ 1., 1., 2., 3., 4.],
[ 6., 6., 7., 8., 9.],
[11., 11., 12., 13., 14.]]],
[[[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15.]]],
[[[ 2., 3., 4., 5., 5.],
[ 2., 3., 4., 5., 5.],
[ 7., 8., 9., 10., 10.],
[12., 13., 14., 15., 15.]]],
[[[ 1., 1., 2., 3., 4.],
[ 6., 6., 7., 8., 9.],
[11., 11., 12., 13., 14.],
[16., 16., 17., 18., 19.]]],
[[[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15.],
[16., 17., 18., 19., 20.]]],
[[[ 2., 3., 4., 5., 5.],
[ 7., 8., 9., 10., 10.],
[12., 13., 14., 15., 15.],
[17., 18., 19., 20., 20.]]],
[[[ 6., 6., 7., 8., 9.],
[11., 11., 12., 13., 14.],
[16., 16., 17., 18., 19.],
[16., 16., 17., 18., 19.]]],
[[[ 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15.],
[16., 17., 18., 19., 20.],
[16., 17., 18., 19., 20.]]],
[[[ 7., 8., 9., 10., 10.],
[12., 13., 14., 15., 15.],
[17., 18., 19., 20., 20.],
[17., 18., 19., 20., 20.]]]]]
因此我们只需要沿着第一个维度求取均值,得到的就会是原始输入3窗口的均值。