pytorch dataset图片显示

pytorch dataset图片显示

  1. plt参数可以是pytorch的tensor,也可以是numpy的array,但是维度必须是x*x*3;
  2. PIL.Image只能转化为x*x*3的numpy数组,不能转化为torch的tensor;
  3. torchvision的transform出来的是3*x*x的tensor,如果要画图需要转化为x*x*3用plt显示。
import torch
import torchvision
import argparse
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

 dataset = torchvision.datasets.ImageFolder(
     root=args.dataset_dir + "/val",
 )

 fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
 image, label = dataset[0]
 image = np.array(image)  # x*x*3
 ax[0].imshow(image)
 ax[0].set_title(str(label))
 ax[0].axis("off")
 trans_image = TransformsSimCLR(args.image_size).train_transform(Image.fromarray(image))  # 3*x*x
 ax[1].imshow(trans_image.permute(1, 2, 0))
 ax[1].axis("off")
 plt.show()

你可能感兴趣的:(machine,learning,python,深度学习,pytorch)