深入浅出 Dataset 与 DataLoader

文章目录

  • Dataset & DataLoader
    • 一、自定义Dataset
    • 二、使用 DataLoaders 为训练准备数据
    • 三、迭代数据

Dataset & DataLoader

1、官方解释(Google翻译):
处理数据样本的代码可能会变得混乱且难以维护;理想情况下,我们希望我们的数据集代码与我们的模型训练代码分离,以获得更好的可读性和模块化。
PyTorch 提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset 允许我们使用预加载的数据集以及我们自己的数据。 Dataset存储样本及其对应的标签,并DataLoader在 周围包裹一个可迭代对象Dataset,以便轻松访问样本。
2、Dataset
是所有开发人员训练、测试使用的所有数据集的一个模板。
Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。
DataLoader定义了按batch加载数据集的方法,它是一个实现了__iter__方法的可迭代对象,每次迭代输出一个batch的数据。
3、DataLoader
DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。
在绝大部分情况下,我们只需实现Dataset的 __len__方法__getitem__方法 ,就可以轻松构建自己的数据集,并用默认数据管道进行加载。

一、自定义Dataset

自定义 Dataset 类需继承 pytorch官方的DataSet类 还必须实现三个函数:__init____len____getitem__
init:初始化(一般需要传入 数据集文件路径,将文件保存到哪个路径预处理函数)
len:返回数据集的大小
getitem:根据索引,返回样本的特征和标签。

import os.path

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image


class MyImageDataset(Dataset):
    def __init__(self, annotations_file, data_dir, transform=None, target_transform=None):
        # annotations_file:文件路径
        # data_dir: 将文件保存到哪个路径
        self.data_label = pd.read_csv(annotations_file)
        self.data_dir = data_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # 返回数据集总的大小
        return len(self.data_label)

    def __getitem__(self, item):
        data_name = os.path.join(self.data_dir, self.data_label.iloc[item, 0])
        image = read_image(data_name)
        # 对特征进行预处理
        label = self.data_label.iloc[item, 1]
        if self.transform:
            image = self.transform(image)
        # 对标签进行预处理
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

其实我们只需要修改的是annotations_file, data_dir, transform(特征预处理), target_transform(标签预处理) 这四个参数。
Dataset每次只处理一个样本,返回的是一个特征该特征所对应的标签

二、使用 DataLoaders 为训练准备数据

检索我们数据集的Dataset特征并一次标记一个样本。在训练模型时,我们通常希望以“小批量”的形式传递样本,在每个 epoch (每次迭代多少次)重新洗牌以减少模型过度拟合,并使用 Pythonmultiprocessing加速数据检索。

batch_size:一次训练所选取的样本数
shuffle=True: 每个训练周期后对数据进行随机排列

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

三、迭代数据

我们已将该数据集加载到 中,DataLoader并且可以根据需要遍历数据集。下面的每次迭代都会返回一批train_features和train_labels(分别包含batch_size=64特征和标签)。
iter()方法 得到一个迭代器。
next() 方法 依次获得特征和标签。

train_features, train_labels = next(iter(train_dataloader))

你可能感兴趣的:(人工智能+大数据,算法,机器学习,深度学习,人工智能)