Pytorch显示图像

借鉴:http://www.manongjc.com/article/4333.html

pytorch 载入的数据集是元组tuple 形式,里面包括了数据及标签(train_data,label),其中的train_data数据可以转换为torch.Tensor形式,方便后面计算使用。

img = torchvision.utils.make_grid(dataset[1][0]).numpy()

plt.imshow(np.transpose(img,(1,2,0)))

plt.show()

np.transpose 是因为plt.imshow在显示 时候输入的是(imgsize,imgsieze,channels),而这里得到的img是(3,200,200)的格式,所以进行了转换,才能显示,如cifar-10

你可能感兴趣的:(Pytorch显示图像)