pytorch中图片显示问题

最近看pytorch的一个代码,结果中间有一句还是看不太懂,最后过了一些阵子才看懂,在此Mark一下。

代码如下:

def imshow(img,text,should_save=False):
    npimg = img.numpy()  # 将torch.FloatTensor 转换为numpy
    plt.axis("off")  # 不显示坐标尺寸
    if text:
        plt.text(75, 8, text, style='italic',fontweight='bold',
            bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})  # facecolor前景色
    # pytorch 图片的显示问题
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

解释这句话:plt.imshow(np.transpose(npimg, (1, 2, 0)))。因为在plt.imshow在现实的时候输入的是(imagesize,imagesize,channels),而def imshow(img,text,should_save=False)中,参数img的格式为(channels,imagesize,imagesize),这两者的格式不一致,我们需要调用一次np.transpose函数,即np.transpose(npimg,(1,2,0)),将npimg的数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),进行格式的转换后方可进行显示。


总结一下,pytorch在载入数据集是元组tuple的形式,里面包括了数据及标签,其中的数据可以转换为torch .Tensor的形式,方便后面计算使用。在显示数据的时候,需要将torchtensor
在pytorch中,读入图片并进行显示的方式有两种。

方式一

将读取出来的torch.FloatTensor转换为numpy,然后将其(1 ,imagesize,imagesize)给reshape一下,变成(imagesize,imagesize)的形式,最后进行显示,上代码:

# dataset的格式为:([torch.FloatTensor of size 1x28x28],3)  其中图片的格式为(1x28x28)图片的标签为3
# 这里我们只取这一张图片本身,先不管它的标签。
img=dataset[0]
# First 将 torch.FloatTensor 转换为 numpy的格式
img=img.numpy()
# Second 将shape(1,28,28)转化为(28,28)
img=img.reshape(28,28)
# Third 调用plt 将图片显示出来
plt.imshow(img,cmap='gray')
plt.show()
#然后就可以显示图片了
方式二:调用torch的接口
img=torchvision.utils.make_grid(img).numpy()
plt.imshow(np.transpose(img,(1,2,0)))
plt.show()

这里用np.transpose(img,(1,2,0))将图片的格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),这样plt.show()就可以显示图片了。


我对此处np.transpose(1,2,0)理解参考此处

你可能感兴趣的:(python)