PyTorch-1.10(十四)--torch.utils.data基本用法

目录

数据加载器

数据集类型

映射类型数据集

迭代类型数据集

​数据加载顺序和采样器​

加载批处理和非批处理数据

自动批次化(默认)

禁用自动批次化

使用collate_fn

单进程和多进程数据加载

单进程数据加载(默认)

多进程数据加载

内存固定

DataLoader综合应用

数据集抽象类

Dataset

IterableDataset

TensorDataset

ConcatDataset

ChainDataset

Subset

采样器抽象类

SequentialSampler

RandomSampler

SubsetRandomSampler

WeightedRandomSampler

BatchSampler

DistributedSampler


数据加载器

torch.utils.data.DataLoader 类是Pytorch数据加载的核心. 它表示数据集上的Python iterable,并支持下面这些功能,这些选项由 DataLoader进行设置:

  • map-style and iterable-style datasets,

  • customizing data loading order,

  • automatic batching,

  • single- and multi-process data loading,

  • automatic memory pinning.

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

以下各节详细描述了这些选项的效果和用法。

数据集类型

​ DataLoader构造函数最重要的参数是dataset,它指示要从中加载数据的dataset对象。PyTorch支持两种不同类型的数据集:​

  • map-style datasets,

  • iterable-style datasets.

映射类型数据集

映射样式的数据集实现了__getitem__()和__len__()方法,并表示从(可能是非整数的)索引/键到数据样本的映射。

使用dataset[idx]访问数据集时,可以从磁盘上的文件夹中读取第idx个样本及其相应的标签。

迭代类型数据集

​iterable样式数据集是IterableDataset子类的实例,该子类实现了__iter__()方法,并表示数据样本上的iterable。这种类型的数据集特别适合于随机读取代价高昂甚至不太可能的情况,以及批量大小取决于获取的数据的情况。

这种数据集(调用iter(dataset))以返回从数据库、远程服务器甚至实时生成的日志读取的数据流。

NOTE

​ 使用具有多进程数据加载的IterableDataset时。在每个工作进程上复制相同的dataset对象,因此必须对副本进行不同的配置,以避免重复数据​. 详情见 IterableDataset

​数据加载顺序和采样器​

​ 对于iterable样式的数据集,数据加载顺序完全由用户定义的iterable控制。这使得区块读取和动态批量大小的实现更加容易(例如,通过每次生成一个批量样本)。

对于 map-style datasets. torch.utils.data.Sampler ​ 类用于指定数据加载中使用的索引/键的顺序。它们表示数据集索引上的可迭代对象。如在随机梯度下降(SGD)的常见情况下,采样器可以随机排列一系列索引,并一次生成每个索引,或者为小批量SGD生成少量索引。

​ 将根据数据加载器的shuffle参数自动构造顺序或无序取样器。或者,用户可以使用sampler参数指定一个自定义的sampler对象,该对象每次生成下一个要获取的索引/键。

​ 一次生成批次索引列表的自定义采样器可以作为batch_sampler参数传递。还可以通过batch_size和drop_last参数启用自动批处理。sampler和batch_sampler都与iterable样式的数据集不兼容,因为此类数据集没有键或索引的概念。

加载批处理和非批处理数据

DataLoader 支持通过参数batch_size, drop_last, batch_sampler自动将各个提取的数据样本整理成批次。

自动批次化(默认)

这是最常见的情况,对应于获取一小批数据并将其整理成批样本,即包含一个维度为批维度(通常是第一个维度)的张量。当batch_size(默认值1)不是None时,数据加载器将生成批处理的样本,而不是单个样本。batch_size和drop_last参数用于指定数据加载器如何获取数据集的批次。对于map样式的数据集,用户也可以指定batch_sampler,它一次生成一个键列表。

NOTE

batch_size和drop_last参数基本上用于从sampler构造batch_sampler。对于map样式的数据集,采样器要么由用户提供,要么基于shuffle参数构造。对于iterable样式的数据集,采样器是一个虚拟的无限采样器。

NOTE

​从具有多线程处理的iterable样式的数据集提取时,drop_last参数会删除每个worker数据集副本的最后一批非完整数据。

从采样器中使用索引获取样本列表后,作为collate_fn参数传递的函数用于将样本列表整理成批。

在这种情况下,从map样式数据集加载大致相当于:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

从iterable样式数据集加载大致相当于:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

自定义collate_fn可用于自定义排序,例如,将序列数据填充到批次的最大长度。

禁用自动批次化

在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者只需加载单个样本。例如,直接加载批处理数据(例如,从数据库进行批量读取或读取连续的内存块)可能更方便,或者批处理大小取决于数据,或者程序设计用于处理单个样本。在这些情况下,最好不要使用自动批处理(其中collate_fn用于整理样本),而是让数据加载器直接返回dataset对象的每个成员。

当batch_size和batch_sampler均为None(batch_sampler的默认值已为None)时,将禁用自动批处理。从数据集中获取的每个样本都将使用作为collate_fn参数传递的函数进行处理。

禁用自动批处理时,默认的collate_fn只是将NumPy数组转换为PyTorch张量,并保持其他所有内容不变。

在这种情况下,从map样式数据集加载大致相当于:

for index in sampler:
    yield collate_fn(dataset[index])

从iterable样式数据集加载大致相当于:

for data in iter(dataset):
    yield collate_fn(data)

使用collate_fn

启用或禁用自动批次化时,collate_fn的使用略有不同。

禁用自动批处理时,将使用每个单独的数据样本调用collate_fn,并从数据加载程序迭代器生成输出。在这种情况下,默认的collate_fn只是转换PyTorch张量中的NumPy数组。

启用自动批处理时,每次调用collate_fn时都会显示数据样本列表。它希望将输入样本整理成一个批,以便从数据加载器迭代器中生成。

单进程和多进程数据加载

​ 在Python进程中,全局解释器锁(GIL)阻止跨线程真正完全并行化Python代码。为了避免在数据加载时阻塞计算代码,PyTorch提供了一个简单的切换来执行多进程数据加载,只需将参数num_workers设置为正整数。

单进程数据加载(默认)

​ 在这种模式下,数据提取是在初始化数据加载器的同一过程中完成的。因此,数据加载可能会阻塞计算。然而,当用于在进程之间共享数据的资源(例如,共享内存、文件描述符)有限时,或者当整个数据集很小并且可以完全加载到内存中时,可以首选此模式。此外,单进程加载通常显示更可读的错误跟踪,因此对于调试很有用。

多进程数据加载

将参数num_workers设置为正整数将启用具有指定数量的加载器工作进程的多进程数据加载。

在多次迭代之后,对于从工作进程访问的父进程中的所有Python对象,加载程序工作进程将消耗与父进程相同的CPU内存量。如果数据集包含大量数据(例如,在数据集构建时加载了一个非常大的文件名列表)和/或使用了大量工作线程(总体内存使用量是工作线程数*父进程大小),则这可能会有问题。最简单的解决方法是将Python对象替换为非引用表示,如Pandas、Numpy或PyArrow对象,详情查看参考手册。

内存固定

​当主机到GPU的拷贝来自固定(页面锁定)内存时,它们的速度要快得多。

​对于数据加载,将pin_memory=True传递给数据加载程序将自动将获取的数据张量放入固定内存中,从而能够更快地将数据传输到支持CUDA的GPU。
默认内存固定逻辑仅识别张量、映射和包含张量的可重用项。默认情况下,如果固定逻辑看到的批是自定义类型(如果有一个collate_fn返回自定义批类型,则会发生这种情况),或者如果批的每个元素都是自定义类型,则固定逻辑将无法识别它们,并且它将返回该批(或这些元素),而不固定内存。要为自定义批处理或数据类型启用内存固定,在自定义类型上定义pin_memory()方法。

请参见下面的示例。

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

DataLoader综合应用

其组合数据集和采样器,并在给定数据集上提供iterable,​ DataLoader支持映射样式和iterable样式的数据集,支持单进程或多进程加载、自定义加载顺序以及可选的自动批处理(排序)和内存固定。参数用法如下:

CLASStorch.utils.data.DataLoader(datasetbatch_size=1shuffle=Falsesampler=Nonebatch_sampler=Nonenum_workers=0collate_fn=Nonepin_memory=Falsedrop_last=Falsetimeout=0worker_init_fn=Nonemultiprocessing_context=Nonegenerator=None*prefetch_factor=2persistent_workers=False)[SOURCE]

参数:

  • dataset (Dataset) – 从中加载数据的数据集.

  • batch_size (intoptional) – 每个批次要加载的样本数(默认值:1)。

  • shuffle (booloptional) – 设置为True可在每个epoch重新排列数据(默认值:False)。

  • sampler (Sampler or Iterableoptional) – 定义从数据集提取样本的策略。

  • batch_sampler (Sampler or Iterableoptional) – 与sampler类似,但一次返回一批索引。与batch_size、shuffle、sampler和drop_last互斥。

  • num_workers (intoptional) –要用于数据加载的子进程数。0表示将在主进程中加载数据。(默认值:0)

  • collate_fn (callableoptional) – 合并样本列表以形成一小批张量。使用从map样式数据集批量加载时使用。

  • pin_memory (booloptional) – 如果为True,数据加载器将在返回张量之前将其复制到CUDA固定内存中。如果数据元素是自定义类型,或者collate_fn返回的批次是自定义类型,请参见使用手册。

  • drop_last (booloptional) – 如果数据集大小不能被批大小整除,则设置为True以删除最后一个不完整的批。如果为False,并且数据集的大小不能被批大小整除,则最后一批将变小。(默认值:False)

  • timeout (numericoptional) –如果为正,则为从workers收集批次的超时值。应始终为非负。(默认值:0)

  • worker_init_fn (callableoptional) – 如果不是None,则在种子设定之后和数据加载之前,将对每个工作子进程调用此函数,并将worker id作为输入。(默认值:无)

  • generator (torch.Generatoroptional) – 如果不是None,RandomSampler将使用此RNG生成随机索引,并通过多处理为工作人员生成base_seed。(默认值:无)

  • prefetch_factor (intoptionalkeyword-only arg) – 每个worker提前装载的数据样本大小。2表示将在所有工人中预取总共2*num_workers样本。(默认值:2)

  • persistent_workers (booloptional) – 如果为True,则数据集使用一次后,数据加载器不会关闭工作进程。这允许保持workers数据集实例处于活动状态。(默认值:False)

数据集抽象类

Dataset

CLASStorch.utils.data.Dataset(*args**kwds)[SOURCE]

​表示数据集的抽象类。所有表示从键到数据样本的映射的数据集都应该对其进行继承。所有子类都应该重写__getitem__()方法,支持获取给定键的数据样本。子类还可以选择性地覆盖__len__(),许多采样器实现和DataLoader的默认选项都会返回数据集的大小。

NOTE

DataLoader 默认情况下,构造生成整数索引的索引采样器。要使其与具有非整数索引/键的map样式数据集一起工作,必须提供自定义采样器。

IterableDataset

CLASStorch.utils.data.IterableDataset(*args**kwds)[SOURCE]

iterable数据集。所有表示数据样本iterable的数据集都应该对其进行继承。当数据来自流时,这种形式的数据集特别有用。所有子类都应覆盖__iter__(),这将返回此数据集中样本的迭代器。

Example 1: splitting workload across all workers in __iter__():

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]
Example 2: splitting workload across all workers using worker_init_fn:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

TensorDataset

CLASStorch.utils.data.TensorDataset(*tensors)[SOURCE]

数据集包装张量。每个样本将通过沿第一维度索引张量来检索。

*tensors (Tensor)参数是与第一维度大小相同的张量。

ConcatDataset

CLASStorch.utils.data.ConcatDataset(datasets)[SOURCE]

数据集作为多个数据集的串联。此类用于组装不同的现有数据集。

datasets (sequence)参数为要连接的数据集列表

ChainDataset

CLASStorch.utils.data.ChainDataset(datasets)[SOURCE]

用于链接多个 IterableDataset 数据集.

此类可用于组装不同的现有数据集流。链接操作是动态完成的,因此将大规模数据集与此类连接起来将非常有效。

datasets (iterable of IterableDataset) 参数是要链接在一起的数据集

Subset

CLASStorch.utils.data.Subset(datasetindices)[SOURCE]

指定索引处的数据集子集。

参数

  • dataset (Dataset) – 整个数据集

  • indices (sequence) – 为子集选择的全集索引

采样器抽象类

CLASStorch.utils.data.Sampler(data_source)[SOURCE]

所有采样器的基类。每个采样器子类都必须提供一个__iter__()方法,提供一种遍历数据集元素索引的方法,以及一个__len__()方法,该方法返回返回的迭代器的长度。

NOTE

​DataLoader并不严格要求使用__len__()方法,但在涉及DataLoader长度的任何计算中都需要使用该方法。

SequentialSampler

CLASStorch.utils.data.SequentialSampler(data_source)[SOURCE]

按顺序对元素进行采样,始终以相同的顺序进行。

data_source (Dataset) 参数是要从中采样的数据集

RandomSampler

CLASStorch.utils.data.RandomSampler(data_sourcereplacement=Falsenum_samples=Nonegenerator=None)[SOURCE]

随机采样元素。

参数

  • data_source (Dataset) –要从中采样的数据集

  • replacement (bool) – 样本按需抽取,如果为True,则用户可以指定要抽取的样本,如果为False,则从无序数据集中采样。默认值=`` False``

  • num_samples (int) – 要抽取的样本数,默认值=`len(dataset)`。仅当replacement为True时才应指定此参数。

  • generator (Generator) – 取样用生成器.

SubsetRandomSampler

CLASStorch.utils.data.SubsetRandomSampler(indicesgenerator=None)[SOURCE]

从给定的索引列表中随机抽取元素,不进行替换。

参数

  • indices (sequence) – 一系列索引

  • generator (Generator) – 采样中的生成器.

WeightedRandomSampler

CLASStorch.utils.data.WeightedRandomSampler(weightsnum_samplesreplacement=Truegenerator=None)[SOURCE]

使用给定的概率(权重)对[0,…,len(权重)-1]中的元素进行采样。

参数

  • weights (sequence) – 权重序列,不必求和为一

  • num_samples (int) – 要抽取的样本数

  • replacement (bool) –如果为True,则有放回抽取样本。否则,将无放回的抽取样本,这意味着当为一行绘制样本索引时,将无法为该行再次抽取该索引。

  • generator (Generator) – 采样中的生成器

Example

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]

BatchSampler

CLASStorch.utils.data.BatchSampler(samplerbatch_sizedrop_last)[SOURCE]

封装另一个采样器以生成一小批索引。

参数

  • sampler (Sampler or Iterable) – 基础采样器。可以是任何iterable对象

  • batch_size (int) – 小批量的大小。

  • drop_last (bool) – 如果为True,则如果最后一批的大小小于batch_size,采样器将丢弃最后一批

Example

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

DistributedSampler

CLASStorch.utils.data.distributed.DistributedSampler(datasetnum_replicas=Nonerank=Noneshuffle=Trueseed=0drop_last=False)[SOURCE]

将数据加载限制到数据集子集的采样器。在torch.nn.parallel.DistributedDataParallel有用,​在这种情况下,每个进程都可以将DistributedSampler实例作为DataLoader采样器传递,并加载其专用的原始数据集的子集。

参数

  • dataset – 用于采样的数据集。

  • num_replicas (intoptional) – 参与分布式训练的进程数。默认情况下,从当前分布式组检索world_size。

  • rank (intoptional) – num_replicas中当前进程的排名。默认情况下,从当前分布式组检索等级。

  • shuffle (booloptional) – 如果为True(默认),sampler将重排索引。

  • seed (intoptional) – 如果shuffle=True,则使用andom种子来洗牌采样器。此数字在分布式组中的所有进程中都应相同。默认值:0。

  • drop_last (booloptional) –如果为True,则如果最后一批的大小小于batch_size,采样器将丢弃最后一批,默认False

WARNING

在分布式模式下,在创建DataLoader迭代器之前,需要在每个epoch的开头调用set_epoch()方法,以使重排序在多个epoch之间正常工作。否则,将始终使用相同的顺序。

Example:

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)

你可能感兴趣的:(深度学习框架,pytorch,人工智能,python)