[Pytorch] Pixel Shuffle

Pixel Shuffle在Pytorch中的实现

Pixel shuffle的原理这里不着重探讨,详细原理请参考论文:Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network。

这里先给出pytorch中pixel shuffle的代码。可以看到,代码的核心部分是permute操作,下面主要讨论pytorch如何通过该操作完成pixel shuffle的。(默认读者已经知道view/permute/contiguous等基础操作)

def pixel_shuffle(input, upscale_factor):
    r"""Rearranges elements in a tensor of shape ``[*, C*r^2, H, W]`` to a
    tensor of shape ``[C, H*r, W*r]``.

    See :class:`~torch.nn.PixelShuffle` for details.

    Args:
        input (Variable): Input
        upscale_factor (int): factor to increase spatial resolution by

    Examples::

        >>> ps = nn.PixelShuffle(3)
        >>> input = autograd.Variable(torch.Tensor(1, 9, 4, 4))
        >>> output = ps(input)
        >>> print(output.size())
        torch.Size([1, 1, 12, 12])
    """
    batch_size, channels, in_height, in_width = input.size()
    channels //= upscale_factor ** 2

    out_height = in_height * upscale_factor
    out_width = in_width * upscale_factor

    input_view = input.contiguous().view(
        batch_size, channels, upscale_factor, upscale_factor,
        in_height, in_width)

    shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
    return shuffle_out.view(batch_size, channels, out_height, out_width)

为了方便理解,下面给出一个放大尺度为3,输入大小为 2 × 2 2\times2 2×2 ,输出通道数为1的示例,即batch_size=1,channels=1。我们将通道自上而下分别编号为0~8。代码中首先将通道分组,将9个通道组成[[0, 1, 2], [3, 4, 5], [6, 7, 8]] 的形式:

input_view = input.contiguous().view(
        batch_size, channels, upscale_factor, upscale_factor,
        in_height, in_width)

[Pytorch] Pixel Shuffle_第1张图片
然后,进行permute将通道按我们需要的方式重排,我们想让通道按上述分组顺序依次出现,步长为放大倍数。对于行即按3号通道(通道矩阵行顺序)遍历1行,共需遍历in_width(5号通道)次;对于列即按2号通道(通道矩阵列顺序)遍历,共需遍历in_height(4号通道)次。由于上述操作是行优先的。故整体遍历方式为(4,2,5,3),与代码中permute的方式相同。

shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()

注意这里需要contiguous操作保证permute之后张量的连续性,来进行之后的view操作。

你可能感兴趣的:(Pytorch)