滑动窗口提取特征-torch.unfold的应用

        假设我们希望提取一个矩阵每一个点上3*3窗口的均值或其他特征,直接使用循环的方法速度太慢,在pytorch中可以利用torch.unfold函数大大简化这个流程。
        首先简单说明torch.unfold函数,其作用是按照选定的尺寸与步长来切分矩阵。unfold函数的参数为(dim,size,step),dim代表想要切分的维度,size代表切分块的尺寸,step代表切分的步长。举个例子,原始输入为H=4,W=5的矩阵: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 \begin{matrix} 1&2&3&4&5\\ 6&7&8&9&10\\ 11&12&13&14&15\\ 16&17&18&19&20\\ \end{matrix} 1611162712173813184914195101520,对其执行x.unfold(0,3,1)操作,代表沿着行维度对其做尺寸为3、步长为1的切分,则能得到:

   [[[ 1.,  6., 11.],
     [ 2.,  7., 12.],
     [ 3.,  8., 13.],
     [ 4.,  9., 14.],
     [ 5., 10., 15.]],
     
    [[ 6., 11., 16.],
     [ 7., 12., 17.],
     [ 8., 13., 18.],
     [ 9., 14., 19.],
     [10., 15., 20.]]]

        假设我们在神经网络中需要提取一个图像样本的3窗口均值,样本维度为(N=1,C=1,H=4,W=5),值与上个示例相同,那么我们进行如下操作就可以:

N, C, H, W = x.size()	##shape=(1,1,4,5)
ksize = 3
padvalue = ksize // 2
Ex = F.pad(x, (padvalue, padvalue, padvalue, padvalue), mode='replicate')	##shape=(1,1,6,7)
Ex = Ex.unfold(2, ksize, 1)	##shape=(1,1,4,7,3)
Ex = Ex.unfold(3, ksize, 1)	##shape=(1,1,4,5,3,3)
Ex = Ex.permute(0, 4, 5, 1, 2, 3).contiguous()
Ex = Ex.view(N,ksize*ksize,C,H,W)	##shape=(1,9,1,4,5)
meanX = Ex.mean(dim=1,keepdim=True).squeeze(1)	##shape=(1,1,4,5)

        padding是图像边界的基本操作,在padding后矩阵如下:

   [[[[ 1.,  1.,  2.,  3.,  4.,  5.,  5.],
      [ 1.,  1.,  2.,  3.,  4.,  5.,  5.],
      [ 6.,  6.,  7.,  8.,  9., 10., 10.],
      [11., 11., 12., 13., 14., 15., 15.],
      [16., 16., 17., 18., 19., 20., 20.],
      [16., 16., 17., 18., 19., 20., 20.]]]]

        Ex在调整维度后最终输出如下:

   [[[[[ 1.,  1.,  2.,  3.,  4.],
       [ 1.,  1.,  2.,  3.,  4.],
       [ 6.,  6.,  7.,  8.,  9.],
       [11., 11., 12., 13., 14.]]],


     [[[ 1.,  2.,  3.,  4.,  5.],
       [ 1.,  2.,  3.,  4.,  5.],
       [ 6.,  7.,  8.,  9., 10.],
       [11., 12., 13., 14., 15.]]],


     [[[ 2.,  3.,  4.,  5.,  5.],
       [ 2.,  3.,  4.,  5.,  5.],
       [ 7.,  8.,  9., 10., 10.],
       [12., 13., 14., 15., 15.]]],


     [[[ 1.,  1.,  2.,  3.,  4.],
       [ 6.,  6.,  7.,  8.,  9.],
       [11., 11., 12., 13., 14.],
       [16., 16., 17., 18., 19.]]],


     [[[ 1.,  2.,  3.,  4.,  5.],
       [ 6.,  7.,  8.,  9., 10.],
       [11., 12., 13., 14., 15.],
       [16., 17., 18., 19., 20.]]],


     [[[ 2.,  3.,  4.,  5.,  5.],
       [ 7.,  8.,  9., 10., 10.],
       [12., 13., 14., 15., 15.],
       [17., 18., 19., 20., 20.]]],


     [[[ 6.,  6.,  7.,  8.,  9.],
       [11., 11., 12., 13., 14.],
       [16., 16., 17., 18., 19.],
       [16., 16., 17., 18., 19.]]],


     [[[ 6.,  7.,  8.,  9., 10.],
       [11., 12., 13., 14., 15.],
       [16., 17., 18., 19., 20.],
       [16., 17., 18., 19., 20.]]],


     [[[ 7.,  8.,  9., 10., 10.],
       [12., 13., 14., 15., 15.],
       [17., 18., 19., 20., 20.],
       [17., 18., 19., 20., 20.]]]]]

        因此我们只需要沿着第一个维度求取均值,得到的就会是原始输入3窗口的均值。

你可能感兴趣的:(pytorch,python,图像处理)