本文需要的预备知识:Python中的迭代器与生成器。
很多时候,预处理数据集的代码可能会非常混乱且变的难以维护。理想情况下,我们希望我们的数据集代码与模型的训练代码分离,以实现更好的可读性和模块化。
Pytorch中的 torch.utils.data
提供了两个抽象类:Dataset
和 Dataloader
。Dataset
允许你自定义自己的数据集,用来存储样本及其对应的标签。而 Dataloader
则是在 Dataset
的基础上将其包装为一个可迭代对象,以便我们更方便地(小批量)访问数据集。
import torch
from torch.utils.data import Dataset, Dataloader
一 些 必 备 概 念 : \textcolor{blue}{一些必备概念:} 一些必备概念:
Data Size
:整个数据集的大小;Batch Size
:在训练过程中,我们不可能把所有样本一次性投喂给神经网络,只能分批次投喂。每个小批量的样本个数就是 Batch Size;Iteration
:将一个 Batch 投喂给神经网络称为一次 Iteration;Epoch
:将所有的样本(即所有 Batch)投喂给神经网络后称为一个 Epoch。例如,设 Data Size 为 100 100 100,Batch Size 为 20 20 20,则所有样本需要分 5 5 5 次才能全部投喂给神经网络。每个 Batch 有 20 20 20 个样本,对应了一个 Iteration。每个 Epoch 有 5 5 5 个 Iteration。
需要注意的是,Batch Size 不一定能整除 Data Size。例如,当 Data Size 为 10 10 10,Batch Size 为 3 3 3 时,此时一共会有 4 4 4 个 Batch。前三个 Batch 的样本个数均为 3 3 3,最后一个 Batch 的样本个数为 1 1 1。每个 Epoch 需要 4 4 4 次 Iteration。
通过下图我们可以更直观的了解 Dataset
和 Dataloader
的工作方式:
先看一遍源码:
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
可以看出,Dataset
是一个抽象类,我们自己编写的数据集类必须继承 Dataset
,且需重新改写 __getitem__
和 __len__
方法。
__getitem__
:传入指定的索引 index
后,该方法能够根据索引返回对应的单个样本及其对应的标签(以元组形式)。__len__
:返回整个数据集的大小,即前面所说的 Data Size
。根据源码,若我们自定义的类在继承 Dataset
时没有改写 __getitem__
,则程序会抛出 NotImplementedError
的异常。此外,因为 Dataset
类中提供了 __add__
方法,所以继承之后我们的数据集也会拥有此方法,从而合并数据集只需使用 +
运算即可。
一般而言,我们自定义的数据集的框架如下:
class MyDataset(Dataset):
def __init__(self):
# 初始化数据集的存储路径
# 载入数据集(转化为tensor格式)
# ...
def __getitem__(self, index):
# 返回单个样本及其标签
pass
def __len__(self):
# 返回整个数据集的大小
pass
只看框架可能会觉得有些抽象,接下来我们通过一个具体的例子来进一步理解。
假设当前工作目录下有一个 data.txt
,其内容如下:
1 -14 -15
1 -1 -15
1 -11 -14
1 0 -2
0 -4 2
1 7 -2
1 -7 -17
0 9 12
0 5 -14
1 -13 13
其中每一行都是一个样例。每行中的后两个数字为样本的特征,第一个数字为样本对应的标签。可以看出,我们一共有 10 10 10 个样本,它们均位于二维欧式空间中,且问题为二分类问题。
于是数据集的框架可以这样写:
class MyDataset(Dataset):
def __init__(self, path):
self.data = np.loadtxt(path)
self._X = torch.from_numpy(self.data[:, 1:])
self._y = torch.from_numpy(self.data[:, 0])
def __getitem__(self, index):
return self._X[index], self._y[index]
def __len__(self):
return len(self._X)
使用时只需创建实例即可
path = './data.txt'
data = MyDataset(path)
我们可以调用各个方法来观察一下
len(data)
# 10
data[1]
# (tensor([ -1., -15.], dtype=torch.float64), tensor(1., dtype=torch.float64))
事实上,data
是一个可迭代对象,我们可以直接使用 for 循环来输出整个数据集:
for feature, label in data:
print(feature, label)
# tensor([-14., -15.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([ -1., -15.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([-11., -14.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([ 0., -2.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([-4., 2.], dtype=torch.float64) tensor(0., dtype=torch.float64)
# tensor([ 7., -2.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([ -7., -17.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([ 9., 12.], dtype=torch.float64) tensor(0., dtype=torch.float64)
# tensor([ 5., -14.], dtype=torch.float64) tensor(0., dtype=torch.float64)
# tensor([-13., 13.], dtype=torch.float64) tensor(1., dtype=torch.float64)
前面我们提到过,绝大多数时候我们需要以 batch 的形式访问数据集。Dataloader
这个接口提供了这样的功能,它能够基于我们自定义的数据集将其转换成一个可迭代对象以便我们批量访问。
Dataloader
的常用参数列在下方,更多请参考官方文档
Dataloader( dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False) \begin{aligned} \text{Dataloader(}\text{dataset,\; batch\_size=1,\;shuffle=False,\;} \text{num\_workers=0,\;drop\_last=False)} \end{aligned} Dataloader(dataset,batch_size=1,shuffle=False,num_workers=0,drop_last=False)
dataset
:自定义数据集类的实例。batch_size
:每个 batch 的大小。shuffle
:用来控制是否在每个 epoch 开始时打乱数据集。num_workers
:决定加载数据集时使用多少子进程。默认只使用主进程加载。drop_last
:当 batch size 不能整除 data size 时,最后一个 batch 会变得不完整。该参数决定是否扔掉最后一个不完整的 batch。我们继续使用 1.3 中的例子:
dataloader = DataLoader(data, batch_size=3, shuffle=False, drop_last=False)
该代码将创建一个可迭代对象,我们将其列表化来观察一下:
list(dataloader)
# [[tensor([[-14., -15.],
# [ -1., -15.],
# [-11., -14.]], dtype=torch.float64),
# tensor([1., 1., 1.], dtype=torch.float64)],
# [tensor([[ 0., -2.],
# [-4., 2.],
# [ 7., -2.]], dtype=torch.float64),
# tensor([1., 0., 1.], dtype=torch.float64)],
# [tensor([[ -7., -17.],
# [ 9., 12.],
# [ 5., -14.]], dtype=torch.float64),
# tensor([1., 0., 0.], dtype=torch.float64)],
# [tensor([[-13., 13.]], dtype=torch.float64),
# tensor([1.], dtype=torch.float64)]]
可以看出,列表化后,每一个 batch 均以列表的形式存储。这说明我们可以通过 for 循环来遍历所有的 batch,具体做法如下:
for inputs, labels in dataloader:
print(inputs, labels)
# tensor([[-14., -15.],
# [ -1., -15.],
# [-11., -14.]], dtype=torch.float64) tensor([1., 1., 1.], dtype=torch.float64)
# tensor([[ 0., -2.],
# [-4., 2.],
# [ 7., -2.]], dtype=torch.float64) tensor([1., 0., 1.], dtype=torch.float64)
# tensor([[ -7., -17.],
# [ 9., 12.],
# [ 5., -14.]], dtype=torch.float64) tensor([1., 0., 0.], dtype=torch.float64)
# tensor([[-13., 13.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
若要进行打乱并丢掉最后一个不完整的 batch,可以这样设置参数
dataloader = DataLoader(data, batch_size=3, shuffle=True, drop_last=True)
for inputs, labels in dataloader:
print(inputs, labels)
# tensor([[ 5., -14.],
# [ -4., 2.],
# [ 7., -2.]], dtype=torch.float64) tensor([0., 0., 1.], dtype=torch.float64)
# tensor([[-11., -14.],
# [-13., 13.],
# [ -1., -15.]], dtype=torch.float64) tensor([1., 1., 1.], dtype=torch.float64)
# tensor([[ 9., 12.],
# [-14., -15.],
# [ 0., -2.]], dtype=torch.float64) tensor([0., 1., 1.], dtype=torch.float64)
有些时候,我们需要记录每个 batch 的索引(即 iteration),则需要用到 enumerate
函数(这里为了方便展示将 batch_size
设为了1):
dataloader = DataLoader(data, batch_size=1, shuffle=True, drop_last=True)
for batch_idx, (inputs, labels) in enumerate(dataloader):
print(batch_idx, end=' ')
print(inputs, labels)
# 0 tensor([[-4., 2.]], dtype=torch.float64) tensor([0.], dtype=torch.float64)
# 1 tensor([[ -1., -15.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 2 tensor([[ 0., -2.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 3 tensor([[ 7., -2.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 4 tensor([[ 9., 12.]], dtype=torch.float64) tensor([0.], dtype=torch.float64)
# 5 tensor([[ 5., -14.]], dtype=torch.float64) tensor([0.], dtype=torch.float64)
# 6 tensor([[-11., -14.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 7 tensor([[-14., -15.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 8 tensor([[ -7., -17.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 9 tensor([[-13., 13.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)