PyTorch入门教学——加载数据(Dataset)

1、简介

  • PyTorch中如何读取数据主要涉及到两个类,分别为Dataset和Dataloader。
    • Dataset:创建可被Pytorch使用的数据集
    • Dataloader:向模型传递数据
  • 本文主要讲解Dataset的使用方法。

2、Dataset

2.1、查看使用方法

  • 打开Anaconda Prompt,进入pytorch虚拟环境(conda activate pytorch),输入下面命令,打开Jupyter。(使用Jupyter输出的结果更加清晰)
  • 新建一个文件(可自行选择创建位置)。
    • PyTorch入门教学——加载数据(Dataset)_第1张图片
  • 输入下面指令,按Shift+回车运行。
    • PyTorch入门教学——加载数据(Dataset)_第2张图片
    • 也可以输入下列指令。
    • PyTorch入门教学——加载数据(Dataset)_第3张图片
  • 有下列描述可知,Dataset是一个抽象类,所有的数据集都需要继承这个类。并且所有子类都需要重写__getitem__方法来获取每一个数据的标签。
  • PyTorch入门教学——加载数据(Dataset)_第4张图片

2.2、应用

  • 使用PyCharm打开pytorch项目。如果没有,请参考:PyTorch入门教学——使用PyCharm创建一个PyTorch项目-CSDN博客,创建一个。
  • 新建一个python文件。
    • PyTorch入门教学——加载数据(Dataset)_第5张图片
  • 数据集下载:https://download.pytorch.org/tutorial/hymenoptera_data.zip,将下好的数据集放入pytorch项目中。
    • 该数据集分为训练数据集和验证数据集。
    • 两个数据集中包含了蚂蚁和蜜蜂的图片,可以用来做二分分类,识别图片为蚂蚁还是蜜蜂。
  • 打开read_data.py,写入下列代码。
    • from torch.utils.data import Dataset
      from PIL import Image  # 获取图片
      import os  # 提供一些方法
      
      
      class MyData(Dataset):  # 继承Dataset
          # 初始化,为整个类提供全局变量
          def __init__(self, root_dir, label_dir):
              self.root_dir = root_dir  # 将这些变量赋给类,供后续函数使用
              self.label_dir = label_dir
              self.path = os.path.join(self.root_dir, self.label_dir)  # 将两个路径拼接
              self.img_path = os.listdir(self.path)  # 将路径中的所有文件名转为列表
      
          # 获取图片和标签
          def __getitem__(self, item):
              img_name = self.img_path[item]  # 获取每一个图片名称
              img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 获取每一张图片的地址
              img = Image.open(img_item_path)  # 打开图片
              label = self.label_dir
              return img, label
      
          # 数量
          def __len__(self):
              return len(self.img_path)
      
      
      root_dir = "Dataset/ReadData/train"  # 图片相对路径
      ants_label_dir = "ants"
      bees_label_dir = "bees"
      ants_dataset = MyData(root_dir, ants_label_dir)
      bees_dataset = MyData(root_dir, bees_label_dir)
      train_dataset = ants_dataset + bees_dataset  # 整个训练数据集
      # 图片展示
      img1, label1 = train_dataset[123]  
      img1.show()  # 展示蚂蚁图片
      img2, label1 = train_dataset[124]
      img2.show()  # 展示蜜蜂图片
    • 分别将蚂蚁和蜜蜂的图片提取并展示出来。 
      • PyTorch入门教学——加载数据(Dataset)_第6张图片 PyTorch入门教学——加载数据(Dataset)_第7张图片

你可能感兴趣的:(PyTorch,pytorch,人工智能,python)