PyTorch:数据读取机制DataLoader

先明确几个常见的名词含义:batch、epoch、iteration
batch:通常我们将一个数据集分割成若干个小的样本集,然后一小部分一小部分喂给神经网络进行迭代,每一小部分样本称为一个batch。
epoch:将训练集中全部数据在模型中进行过一次完整的训练(包括一次正向传播和一次反向传播),成为一个epoch。
iteration:使用一个batch对模型的参数进行一次更新的过程,成为一个iteration(一次迭代)。

首先导入包。下面介绍的包都位于torch.util.data中。

下面我们手动创建一个数据集,包括1000个样本,每个样本有两个特征。因此样本集是一个1000×2的Tensor,标签集是一个1000×1的Tensor。每个样本的两个特征随机得到,标签符合再加上一个符合正太分布的误差。

import torch
import torch.utils.data as Data

# 生成数据集
feature_num = 2   # 两个输入x1、x2
sample_num = 1000 # 1000个样本
true_w = [2, -5]
true_b = 3
samples = torch.randn(sample_num, feature_num, dtype=torch.float32) # 生成1000*2的张量,作为输入的样本
labels = true_w[0] * samples[:,0] + true_w[1] * samples[:,1] + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)

一、TensorDataset()函数:将数据包装成数据集

对于给定的数据,包括数据的样本和标签,将二者包装成一个Dataset。

注意:1. 输入的参数必须都是Tensor。

  1. 两个参数的第一个维度必须是一致的,即样本和标签可以按行一一对应并组合起来。

我们将上面的数据集生成数据集。

dataset = Data.TensorDataset(samples, labels)

二、DataLoader()函数:加载数据集

加载dataset,根据设置的参数返回迭代器。

函数为

torch.utils.data.DataLoader(dataset, batch_size=1,
        shuffle=False, sampler=None,
        batch_sampler=None, num_workers=0,
        collate_fn=None, pin_memory=False,
        drop_last=False, timeout=0,
        worker_init_fn=None, multiprocessing_context=None,
        generator=None, *, 
        prefetch_factor=2, persistent_workers=False)

常用参数的含义如下。

  1. dataset:Dataset类型,加载数据的数据集。
  2. batch_size:int类型。每个batch中包含的样本数量,默认值为1。
  3. shuffle:bool类型。设置为True时,会在每个epoch都打乱数据。
  4. samlper:Sampler类型。用来指定从数据集中提取样本的策略,如果指定sampler,则shffle必须设置成False。
    几个常见的取值如下。
  • torch.utils.data.sampler.SequentialSampler(dataset):样本元素按顺序采样,始终以相同的顺序。
  • torch.utils.data.sampler.RandomSampler(dataset):样本元素随机采样,没有替换。
  • torch.utils.data.sampler.SubsetRandomSampler(indices):样本元素从指定的索引列表中随机抽取,没有替换。
  1. num_workers:int类型。定义使用多少个子进程加载数据,0表示数据在主进程中加载。
  2. pin_memory:bool类型。如果为True,则在返回Tensor前将它们拷贝到CUDA的pinned memory中。
  3. drop_last:bool类型。当数据集大小不能被batch_size整除时,若设置为True,会删除最后一个不完整的batch。如果设置为False则会保留这个batch。
  4. timeout:数值类型,必须大于等于0。用来设置读取数据的超时时间,如果为正数,则超过这个时间还未读取到数据会报错。

将上文中生成的数据集进行随机的小批量读取。

batch_size = 10
data_iter = Data.DataLoader(dataset, batch_size, shuffle=False, 
        sampler=torch.utils.data.sampler.RandomSampler(dataset))

我们可以读取并打印每一个batch的样本。这里每次循环都会调用一次data_iter,data_iter都会自动向后迭代一个batch。

for X, y in data_iter:
  print(X, "\n", y)

你可能感兴趣的:(PyTorch:数据读取机制DataLoader)