torch.nn.functional.interpolate
官方文档: interpolate
Down/up samples the input to either the given size or the given scale_factor
作用: 给定一个feature,将其插值为给定的size大小,或者根据scale factor参数进行缩放
torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False)
nearest插值例子
a = torch.tensor([[1.,2.],[3.,4.]])
a = torch.unsqueeze(a, dim=0)
a = torch.unsqueeze(a, dim=0)
b = nn.functional.interpolate(a, size=(4, 4))
'''
a
tensor([[[[1., 2.],
[3., 4.]]]])
torch.Size([1, 1, 2, 2])
b
tensor([[[[1., 1., 2., 2.],
[1., 1., 2., 2.],
[3., 3., 4., 4.],
[3., 3., 4., 4.]]]])
'''
除了将size=(4,4)
外,原来size为2,这里将scale_factor
设为2,也可以实现上述效果
b = nn.functional.interpolate(a, scale_factor=2)
'''
b
tensor([[[[1., 1., 2., 2.],
[1., 1., 2., 2.],
[3., 3., 4., 4.],
[3., 3., 4., 4.]]]])
'''
上面例子,默认的插值为nearest插值,接下来我们看一个bilinear插值的例子
bilinear插值例子
b = nn.functional.interpolate(a, size=(4, 4), mode="bilinear")
'''
b
tensor([[[[1.0000, 1.2500, 1.7500, 2.0000],
[1.5000, 1.7500, 2.2500, 2.5000],
[2.5000, 2.7500, 3.2500, 3.5000],
[3.0000, 3.2500, 3.7500, 4.0000]]]])
/Users/harry/miniconda3/envs/torch/lib/python3.7/site-packages/torch/nn/functional.
py:3455: UserWarning: Default upsampling behavior when mode=bilinear is changed to
align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior
is desired. See the documentation of nn.Upsample for details.
"See the documentation of nn.Upsample for details.".format(mode)
'''
看到双线性bilinear插值的结果和nearest结果不同,这里还报了错,大体意思是说在双线性插值的情况下,如果设置align_corners=True
,输出可能取决于输入大小,并且不会按照比例将输出和输入像素进行对齐,因此默认的将align_corners=False
,可以参考这两个文档: upsample, interpolate
下面给出align_corners=True
的例子
b = nn.functional.interpolate(a, size=(4, 4), mode="bilinear", align_corners=True)
'''
b
tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
[1.6667, 2.0000, 2.3333, 2.6667],
[2.3333, 2.6667, 3.0000, 3.3333],
[3.0000, 3.3333, 3.6667, 4.0000]]]])
'''
观察上面两个例子,如果设置align_corners=False
,也就是默认设置,1,2,3,4最后会到四个角落,还是比较规则的比例,但是如果设置align_corners=True
,1,2,3,4不会按照比例,位置会发生变化变得不规则,但是此时的1,2,3,4四周均被元素环绕,并且算上环绕部分,他们是对齐的,即(1, 1.3333, 1.6667, 2.000), (1.6667, 2.0000, 2.3333, 2.6667), (2.3333, 2.6667, 3.0000, 3.3333), (3.0000, 3.3333, 3.6667, 4.0000)这四部分分别看做一个整体,他们是对齐的,下面介绍下align_corners
参数的运作
图来自pytorch论坛,此参数主要解释为原来的输入是否会被插值元素环绕