torch.utils.data.DataLoader学习

官方文档

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

数据加载器。组合数据集和采样器,并提供给定数据集的可迭代对象。
该同时支持地图风格和迭代式的数据集与单或多进程加载,定制加载顺序和可选的自动配料(对照)和存储器的钉扎。DataLoader

torch.utils.data.DataLoader学习_第1张图片

参数

  • dataset(数据集) -数据从其中加载所述数据集。
  • batch_size ( int , optional ) – 每批要加载多少样本(默认值:)1。
  • shuffle(布尔,可选) -设置为True有数据在每个时间段改组(默认:False)。
  • sampler ( Sampler or Iterable , optional ) – 定义从数据集中抽取样本的策略。可以是任何Iterable与__len__ 实施。如果指定,则shuffle不得指定。
  • batch_sampler ( Sampler or Iterable , optional ) – 类似sampler,但一次返回一批索引。互斥有 batch_size,shuffle,sampler,和drop_last。
  • num_workers ( int , optional ) – 用于数据 加载的子进程数。0意味着数据将在主进程中加载​​。(默认值:0)
  • collat​​e_fn ( callable , optional ) – 合并一个样本列表以形成一个小批量的 Tensor(s)。在使用地图样式数据集的批量加载时使用。
  • pin_memory ( bool , optional ) – 如果True,数据加载器将在返回之前将张量复制到 CUDA 固定内存中。如果您的数据元素是自定义类型,或者您collate_fn返回的批次是自定义类型,请参见下面的示例。
  • drop_last ( bool , optional ) –True如果数据集大小不能被批处理大小整除,则设置为删除最后一个不完整的批处理。如果False并且数据集的大小不能被批大小整除,那么最后一批将更小。(默认值:False)
  • timeout ( numeric , optional ) – 如果为正,则为从工作人员收集批次的超时值。应该总是非负的。(默认值:0)
  • worker_init_fn ( callable , optional ) – 如果不是None,这将在每个工作子进程上调用,在播种之后和数据加载之前,将工作人员 id(一个 int in )作为输入。(默认值:)[0, num_workers - 1]None
  • generator ( torch .Generator , optional ) – 如果没有,RandomSampler 将使用这个 RNG 来生成随机索引和多处理来为 worker生成 base_seed。(默认值:)NoneNone
  • prefetch_factor ( int , optional , keyword-only arg ) – 每个工作人员提前加载的样本数。2意味着将在所有工作人员中预取总共 2 * num_workers 个样本。(默认值:2)
  • persistent_workers ( bool , optional ) – 如果True,数据加载器在数据集被消费一次后不会关闭工作进程。这允许保持工作人员数据集实例处于活动状态。(默认值:False)

简单示例

"""
    批训练,把数据变成一小批一小批数据进行训练。
    DataLoader就是用来包装所使用的数据,每次抛出一批数据
"""
import torch

BATCH_SIZE = 3

x = torch.linspace(1, 9, 9)
y = torch.linspace(9, 1, 9)
# 把数据放在数据库中
torch_dataset = torch.utils.data.TensorDataset(x, y)
loader = torch.utils.data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)


def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training


            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))


if __name__ == '__main__':
    show_batch()

你可能感兴趣的:(pytorch学习,python)