最近看pytorch的一个代码,结果中间有一句还是看不太懂,最后过了一些阵子才看懂,在此Mark一下。
代码如下:
def imshow(img,text,should_save=False):
npimg = img.numpy() # 将torch.FloatTensor 转换为numpy
plt.axis("off") # 不显示坐标尺寸
if text:
plt.text(75, 8, text, style='italic',fontweight='bold',
bbox={'facecolor':'white', 'alpha':0.8, 'pad':10}) # facecolor前景色
# pytorch 图片的显示问题
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
解释这句话:plt.imshow(np.transpose(npimg, (1, 2, 0)))。因为在plt.imshow在现实的时候输入的是(imagesize,imagesize,channels),而def imshow(img,text,should_save=False)中,参数img的格式为(channels,imagesize,imagesize),这两者的格式不一致,我们需要调用一次np.transpose函数,即np.transpose(npimg,(1,2,0)),将npimg的数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),进行格式的转换后方可进行显示。
总结一下,pytorch在载入数据集是元组tuple的形式,里面包括了数据及标签,其中的数据可以转换为torch .Tensor的形式,方便后面计算使用。在显示数据的时候,需要将torchtensor
在pytorch中,读入图片并进行显示的方式有两种。
将读取出来的torch.FloatTensor转换为numpy,然后将其(1 ,imagesize,imagesize)给reshape一下,变成(imagesize,imagesize)的形式,最后进行显示,上代码:
# dataset的格式为:([torch.FloatTensor of size 1x28x28],3) 其中图片的格式为(1x28x28)图片的标签为3
# 这里我们只取这一张图片本身,先不管它的标签。
img=dataset[0]
# First 将 torch.FloatTensor 转换为 numpy的格式
img=img.numpy()
# Second 将shape(1,28,28)转化为(28,28)
img=img.reshape(28,28)
# Third 调用plt 将图片显示出来
plt.imshow(img,cmap='gray')
plt.show()
#然后就可以显示图片了
img=torchvision.utils.make_grid(img).numpy()
plt.imshow(np.transpose(img,(1,2,0)))
plt.show()
这里用np.transpose(img,(1,2,0))将图片的格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),这样plt.show()就可以显示图片了。
我对此处np.transpose(1,2,0)理解参考此处