以scale=2为例子,采用了像素点全为0.4,0.5,0.6,0.7的size=55四张图片,观察pixelshuffle后1010图片的像素点排列。
import torch
import torch.nn as nn
####################################
#验证pixelshuffle函数是否正常排列像素点。
####################################
a0 = torch.full([1,1,5,5],1)
a1 = torch.full([1,1,5,5],0.4)
print(a1.shape)
a2 = torch.full([1,1,5,5],0.5)
print(a2.shape)
a3 = torch.full([1,1,5,5],0.6)
print(a3.shape)
a4 = torch.full([1,1,5,5],0.7)
print(a4.shape)
in_12 = torch.cat((a1, a2), 1)
in_34 = torch.cat((a3, a4), 1)
input0 = torch.cat((in_12, in_34), 1)
print(input0.shape)
ps = nn.PixelShuffle(2)
out = ps(input0)
print(out.shape)
print(out)
输出结果为
可以发现,像素点的排列顺序与论文中图示相同。scale>=3的验证方法相同。如果需要按照自己的想法排列像素点,一定要验证输出!!!