首先跳到函数定义中,看一下有哪些参数。
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
"""
input: tensor数据,四维, Batchsize, channel, height, width
kernel_size: 核大小,决定输出tensor的数目。稍微详细讲
dilation: 输出形式是否有间隔,稍后详细讲。
padding:一般是没有用的必要
stride: 核的滑动步长。稍后详细讲
"""
我觉得没有一张图很难说清楚这个函数想做啥!
假设我们现在有一个张量特征图,其size为[ 1, C, H, W]
我们想将这个特征图连续的在分辨率维度(H和W)维度取出特征。就像下面这样:
就是想把输入tensor数据,按照一定的区域(由核的长宽),不断沿着通道维度取出来,由步长指定核滑动的步长,由dilation指定核内区域哪些被跳过。
这里要说明一下,unfold函数的输入数据是四维,但输出是三维的。假设输入数据是[B, C, H, W], 那么输出数据是 [B, C* kH * kW, L], 其中kH是核的高,kW是核宽。 L则是这个高kH宽kW的核能在H*W区域按照指定stride滑动的次数。
上面公式中第一项是指核高kH的情况下,能在高H的特征图上滑动的次数,后一项则是在宽这个维度上。当然默认stride=1
得到的这三维tensor,还需要reshape一下,才能得到上图右边的形式。
B, C_kh_kw, L = data.size()
data = data.permute(0, 2, 1)
data = data.view(B, L, C, kh, kw)
下面就进入代码实践环节。假设B等于1。
import torch
from torch.nn import functional as f
x = torch.arange(0, 1*3*15*15).float()
x = x.view(1,3,15,15)
print(x)
x1 = f.unfold(x, kernel_size=3, dilation=1, stride=1)
print(x1.shape)
B, C_kh_kw, L = x1.size()
x1 = x1.permute(0, 2, 1)
x1 = x1.view(B, L, -1, 3, 3)
print(x1)
'''
x的打印的一部分
tensor([[[[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.,
11., 12., 13., 14.],
[ 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
26., 27., 28., 29.],
...
[[225., 226., 227., 228., 229., 230., 231., 232., 233., 234., 235.,
236., 237., 238., 239.],
[240., 241., 242., 243., 244., 245., 246., 247., 248., 249., 250.,
251., 252., 253., 254.],
...
[[450., 451., 452., 453., 454., 455., 456., 457., 458., 459., 460.,
461., 462., 463., 464.],
[465., 466., 467., 468., 469., 470., 471., 472., 473., 474., 475.,
476., 477., 478., 479.],
...
]]])
X1 的一部分
tensor([[[[[ 0., 1., 2.],
[ 15., 16., 17.],
[ 30., 31., 32.]],
[[225., 226., 227.],
[240., 241., 242.],
[255., 256., 257.]],
[[450., 451., 452.],
[465., 466., 467.],
[480., 481., 482.]]],
[[[ 1., 2., 3.],
[ 16., 17., 18.],
[ 31., 32., 33.]],
[[226., 227., 228.],
[241., 242., 243.],
[256., 257., 258.]],
[[451., 452., 453.],
[466., 467., 468.],
[481., 482., 483.]]],
'''
首先X就是15*15,通道是3的特征图,同时这些值是从底到高按顺序reshape的。相当于0-15*15-1 是最上面一层,中间那层的数值是从15*15 到15*15*2-1. 最后一层的数值是从 15*15*2 到 15*15*3-1
现在对x1观察。
x1 就像是把x沿着分辨率维度切开了,而且是隔着一个元素单位就切(stride=1))。切出来的大小是3*3的(kernel size=3),和核高宽一致。
大家可以自行测试stride为2和dilation为2的情况。相信大家一定可以更深刻的理解这个函数。