PyTorch入门教学——torchvision中数据集的使用

1、torchvision.datasets

  • datasets是torchvision工具集中的一个工具。
  • 可以理解为调用官方数据集的一种方式,其中有很多开源的数据集,可供我们学习使用。
  • datasets官网:Datasets — Torchvision 0.16 documentation (pytorch.org)
  • PyTorch入门教学——torchvision中数据集的使用_第1张图片 

2、使用

  • 这里以使用CIFAR10中的数据为例。
  • PyTorch入门教学——torchvision中数据集的使用_第2张图片
  • 其中有这个数据集的使用方法和具体介绍。
  • PyTorch入门教学——torchvision中数据集的使用_第3张图片
  • 参数:(每个数据集的参数大致相同)
    • root:数据集下载后存放的目录。
    • train:如果为True,则从训练集创建数据集,否则从测试集创建。
    • transform:接收PIL图像的转换方式,并返回转换后的版本。
    • download:如果为True,则从互联网下载数据集,然后将其放在设置的目录中。如果数据集已下载,则不会再次下载。
  • 代码演示——查看数据集中图片的信息
    • import torchvision
      
      train_set = torchvision.datasets.CIFAR10(root="./Dataset/CIFAR10", train=True, download=True)  # root:数据集要存放在什么位置
      test_set = torchvision.datasets.CIFAR10(root="./Dataset/CIFAR10", train=False, download=True)
      
      print(test_set[0])  # 第一张图片的信息,包含格式和标签
      print(test_set.classes)  # 数据集中所包含的图片类别
      
      img, target = test_set[0]
      print(img)
      print(target)  # 标签
      print(test_set.classes[target])  # 第一张图片的标签为猫
      img.show()  # 显示图片
    • PyTorch入门教学——torchvision中数据集的使用_第4张图片
    • PyTorch入门教学——torchvision中数据集的使用_第5张图片
  • 代码演示——将数据集中的前10张图片在tensorboard中展示出来。
    • import torchvision
      from torch.utils.tensorboard import SummaryWriter
      
      test_set = torchvision.datasets.CIFAR10(
          root="./Dataset/CIFAR10",
          transform=torchvision.transforms.ToTensor(),  # 将图片转换为totensor数据类型
          train=False,
          download=True)
      
      writer = SummaryWriter('logs')  # writer把summary内容写在哪个目录下
      for i in range(10):
          img, target = test_set[i]
          writer.add_image('test_set', img, i)
      
      writer.close()
    • PyTorch入门教学——torchvision中数据集的使用_第6张图片
    • 运行程序后,打开终端,输入下列命令打开tensorboard。
    • tensorboard --logdir=logs --port=6007
    • PyTorch入门教学——torchvision中数据集的使用_第7张图片(该数据集的图片像素为32*32,所以比较模糊)

你可能感兴趣的:(PyTorch,pytorch,人工智能,python)