PyTorch中torch.nn.functional.unfold函数使用详解

首先跳到函数定义中,看一下有哪些参数。

 

def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
    """
    input: tensor数据,四维, Batchsize, channel, height, width
    kernel_size: 核大小,决定输出tensor的数目。稍微详细讲
    dilation: 输出形式是否有间隔,稍后详细讲。
    padding:一般是没有用的必要
    stride: 核的滑动步长。稍后详细讲

"""

我觉得没有一张图很难说清楚这个函数想做啥!

假设我们现在有一个张量特征图,其size为[ 1, C, H, W]

PyTorch中torch.nn.functional.unfold函数使用详解_第1张图片

我们想将这个特征图连续的在分辨率维度(H和W)维度取出特征。就像下面这样:

PyTorch中torch.nn.functional.unfold函数使用详解_第2张图片

就是想把输入tensor数据,按照一定的区域(由核的长宽),不断沿着通道维度取出来,由步长指定核滑动的步长,由dilation指定核内区域哪些被跳过。

这里要说明一下,unfold函数的输入数据是四维,但输出是三维的。假设输入数据是[B, C, H, W], 那么输出数据是 [B, C* kH * kW, L], 其中kH是核的高,kW是核宽。 L则是这个高kH宽kW的核能在H*W区域按照指定stride滑动的次数。

L = (H - kH +1) \times (W - kW +1)

上面公式中第一项是指核高kH的情况下,能在高H的特征图上滑动的次数,后一项则是在宽这个维度上。当然默认stride=1

得到的这三维tensor,还需要reshape一下,才能得到上图右边的形式。

B, C_kh_kw, L = data.size()
data = data.permute(0, 2, 1)
data = data.view(B, L, C, kh, kw)

 

下面就进入代码实践环节。假设B等于1。

import torch
from torch.nn import functional as f

x = torch.arange(0, 1*3*15*15).float()
x = x.view(1,3,15,15)
print(x)
x1 = f.unfold(x, kernel_size=3, dilation=1, stride=1)
print(x1.shape)
B, C_kh_kw, L = x1.size()
x1 = x1.permute(0, 2, 1)
x1 = x1.view(B, L, -1, 3, 3)
print(x1)


'''
x的打印的一部分
tensor([[[[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
            11.,  12.,  13.,  14.],
          [ 15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,
            26.,  27.,  28.,  29.],
        ...
           [[225., 226., 227., 228., 229., 230., 231., 232., 233., 234., 235.,
           236., 237., 238., 239.],
          [240., 241., 242., 243., 244., 245., 246., 247., 248., 249., 250.,
           251., 252., 253., 254.],
        ...

           [[450., 451., 452., 453., 454., 455., 456., 457., 458., 459., 460.,
           461., 462., 463., 464.],
          [465., 466., 467., 468., 469., 470., 471., 472., 473., 474., 475.,
           476., 477., 478., 479.],
        ...
       ]]])

X1 的一部分

tensor([[[[[  0.,   1.,   2.],
           [ 15.,  16.,  17.],
           [ 30.,  31.,  32.]],

          [[225., 226., 227.],
           [240., 241., 242.],
           [255., 256., 257.]],

          [[450., 451., 452.],
           [465., 466., 467.],
           [480., 481., 482.]]],


         [[[  1.,   2.,   3.],
           [ 16.,  17.,  18.],
           [ 31.,  32.,  33.]],

          [[226., 227., 228.],
           [241., 242., 243.],
           [256., 257., 258.]],

          [[451., 452., 453.],
           [466., 467., 468.],
           [481., 482., 483.]]],
'''

 

首先X就是15*15,通道是3的特征图,同时这些值是从底到高按顺序reshape的。相当于0-15*15-1 是最上面一层,中间那层的数值是从15*15 到15*15*2-1. 最后一层的数值是从 15*15*2 到 15*15*3-1

现在对x1观察。

x1 就像是把x沿着分辨率维度切开了,而且是隔着一个元素单位就切(stride=1))。切出来的大小是3*3的(kernel size=3),和核高宽一致。

大家可以自行测试stride为2和dilation为2的情况。相信大家一定可以更深刻的理解这个函数。

你可能感兴趣的:(Pytorch)