Pixelshuffle会将shape为(B, r^2*C, H, W)的tensor变成shape为(B, C, rxH, rxW)的tensor。
0-r^2的通道映射为输出的第一个通道,以此类推。如下例子中,前4个通道映射为输出的第一个通道;中间4个通道映射为输出的第二个通道;最后4个通道映射为输出的第三个通道。
>>> import torch
>>> import torch.nn as nn
>>> ps = nn.PixelShuffle(2)
>>> input = torch.randn(1, 12, 2, 2)
>>> input
tensor([[[[ 0.3157, 2.6184],
[-0.5110, -1.6559]],
[[ 1.5152, -0.3441],
[-0.0120, -1.4397]],
[[ 2.2176, 0.9618],
[ 0.6247, 0.0955]],
[[ 0.2620, 1.3558],
[ 0.9197, 0.8397]],
[[ 0.5672, -0.0619],
[ 0.9506, 0.1088]],
[[ 0.7284, -0.8414],
[ 0.0192, 0.5332]],
[[-0.1117, 0.7233],
[ 0.5228, 0.3788]],
[[ 1.2299, 0.1291],
[ 1.4859, 0.5856]],
[[ 0.8725, 0.4704],
[ 2.0029, 0.6330]],
[[-0.3081, -1.5928],
[ 1.7993, 0.6195]],
[[ 0.5230, 1.8387],
[-0.3246, -1.1609]],
[[-0.6185, -0.0394],
[ 1.1148, -0.3396]]]])
>>> out = ps(input)
>>> out.shape
torch.Size([1, 3, 4, 4])
>>> out[0,0]
tensor([[ 0.3157, 1.5152, 2.6184, -0.3441],
[ 2.2176, 0.2620, 0.9618, 1.3558],
[-0.5110, -0.0120, -1.6559, -1.4397],
[ 0.6247, 0.9197, 0.0955, 0.8397]])
>>> input[0,0:4]
tensor([[[ 0.3157, 2.6184],
[-0.5110, -1.6559]],
[[ 1.5152, -0.3441],
[-0.0120, -1.4397]],
[[ 2.2176, 0.9618],
[ 0.6247, 0.0955]],
[[ 0.2620, 1.3558],
[ 0.9197, 0.8397]]])