Pixel shuffle的原理这里不着重探讨,详细原理请参考论文:Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network。
这里先给出pytorch中pixel shuffle的代码。可以看到,代码的核心部分是permute操作,下面主要讨论pytorch如何通过该操作完成pixel shuffle的。(默认读者已经知道view/permute/contiguous等基础操作)
def pixel_shuffle(input, upscale_factor):
r"""Rearranges elements in a tensor of shape ``[*, C*r^2, H, W]`` to a
tensor of shape ``[C, H*r, W*r]``.
See :class:`~torch.nn.PixelShuffle` for details.
Args:
input (Variable): Input
upscale_factor (int): factor to increase spatial resolution by
Examples::
>>> ps = nn.PixelShuffle(3)
>>> input = autograd.Variable(torch.Tensor(1, 9, 4, 4))
>>> output = ps(input)
>>> print(output.size())
torch.Size([1, 1, 12, 12])
"""
batch_size, channels, in_height, in_width = input.size()
channels //= upscale_factor ** 2
out_height = in_height * upscale_factor
out_width = in_width * upscale_factor
input_view = input.contiguous().view(
batch_size, channels, upscale_factor, upscale_factor,
in_height, in_width)
shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
return shuffle_out.view(batch_size, channels, out_height, out_width)
为了方便理解,下面给出一个放大尺度为3,输入大小为 2 × 2 2\times2 2×2 ,输出通道数为1的示例,即batch_size=1,channels=1。我们将通道自上而下分别编号为0~8。代码中首先将通道分组,将9个通道组成[[0, 1, 2], [3, 4, 5], [6, 7, 8]] 的形式:
input_view = input.contiguous().view(
batch_size, channels, upscale_factor, upscale_factor,
in_height, in_width)
然后,进行permute将通道按我们需要的方式重排,我们想让通道按上述分组顺序依次出现,步长为放大倍数。对于行即按3号通道(通道矩阵行顺序)遍历1行,共需遍历in_width(5号通道)次;对于列即按2号通道(通道矩阵列顺序)遍历,共需遍历in_height(4号通道)次。由于上述操作是行优先的。故整体遍历方式为(4,2,5,3),与代码中permute的方式相同。
shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
注意这里需要contiguous操作保证permute之后张量的连续性,来进行之后的view操作。