【Pytorch学习】-- 读取数据 -- torch.utils.data -- Dataset

学习视频:https://www.bilibili.com/video/BV1hE411t7RN?p=1,内含环境搭建

Pytorch有两个读取数据的方式:

  1. 使用Dataset
  2. 使用DataLoader

本文先介绍第一种——Dataset

Dataset与DataLoader区别

  1. Dataset:提供一种方法,去获取数据及其对应的label值
  2. DataLoader:提供一种方法,可以以特定的形式打包数据

数据集

接下来使用的数据集下载地址:https://download.pytorch.org/tutorial/hymenoptera_data.zip
文件结构:(本人将文件夹重新命名为"dataset")

dataset
├── train
│   ├── ants
│   └── bees
└── val
    ├── ants
    └── bees

使用torch.utils.data下Dataset读取数据

在处理数据前,首先要做的就是读取数据,torch提供了对应读取数据方法来适配其他torch的处理数据方法。 代码如下:

from torch.utils.data import Dataset  # 导入Dataset后可以使用“help(Dataset)查看官方文档”
from PIL import Image                 # 借助PIL库导入数据图片
import os                             # 借助os库来用路径读入数据

class Mydata(Dataset):				  							# 根据官方文档,自己创建的类必须继承Dataset
	def __init__(self,root_dir,label_dir):						# 初始化操作,传入图片所在的根目录路径(root_dir)和label的路径(label_dir)获得一个路径列表(img_path)
        self.root_dir = root_dir			
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)	# 用join把路径拼接一起可以避免一些因“/”引发的错误
        self.img_path = os.listdir(self.path) 					# 将该路径下的所有文件变成一个列表
        		
	def __getitem__(self,idx)		  										# 使用index(简写为idx)获取某个数据
        img_name = self.img_path[idx]										# img_path列表里每个元素就是对应图片文件名
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)	# 获得对应图片路径
        img = Image.open(img_item_path)										# 使用PIL库下Image工具,打开对应路径图片
        label = self.label_dir												# 本数据集label就是文件名,如“ants”(虽然命名为dir看似路径,实则视作字符串会更容易理解)
        return img,label													# 返回对应图片和图片的label


# 调用
root_dir = "/content/drive/MyDrive/Pytorch学习/dataset/train"
ants_label_dir = "ants"
ants_dataset = Mydata(root_dir,label_dir)
ants_dataset[0]

结果:返回的一个元组,元组中有两个数据,一个是集合<…>部分,一个是字符串"ants"

(,
 'ants')

因此可以这样赋值,即可显示图片

img,label = ants_dataset[0]
img.show()

小技巧:

train_dataset = ants_dataset + bees_dataset # 将两个数据集拼接起来

你可能感兴趣的:(pytorch,深度学习,机器学习)