pytorch数据加载-DataLoader

本篇主要介绍torch.utils.data.DataLoader的作用

	DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
	           batch_sampler=None, num_workers=0, collate_fn=None,
	           pin_memory=False, drop_last=False, timeout=0,
	           worker_init_fn=None, *, prefetch_factor=2,
	           persistent_workers=False)

参数说明:
        dataset(Dataset):以迭代方式加载数据的数据集,具体加载形式需要重载__getitem__函数
        batch_size (int, optional): 每批要加载多少个样本
            (默认值:``1``)。
        shuffle (bool, optional): 设置为 True 让数据重新洗牌
            在每个时期(默认值:``False``)。
        sampler (Sampler or Iterable, optional): 定义采样策略
            数据集中的样本。可以是任何带有 __len__ 的 ``Iterable``
            实施的。如果指定,则不能指定 :attr:`shuffle`。
        batch_sampler(Sampler 或 Iterable,可选):类似于 :attr:`sampler`,但是
            一次返回一批索引。
        num_workers (int, optional): 有多少子进程用于数据
            加载。 ``0`` 表示数据将在主进程中加载​​,容易造成等待加载数据的时间过长
            (默认值:``0``)
        collat​​e_fn(可调用,可选):合并样本列表以形成
            小批量张量。使用批量加载时使用,如何处理batch数据,并返回什么形式的数据
            
        pin_memory (bool, optional): 如果为True,数据加载器将复制张量
            在返回它们之前进入 CUDA 固定内存。如果您的数据元素
            是自定义类型,或者您的 collat​​e_fn 返回一个自定义类型的批次,
            请参见下面的示例。
        drop_last (bool, optional): 设置为 ``True`` 以丢弃最后一个不完整的批次,
            如果数据集大小不能被批量大小整除。如果“假”和
            数据集的大小不能被批大小整除,那么最后一批
            会更小。 (默认:“假”)
        timeout(numeric, optional):如果为正,则收集批次的超时值
            来自worker。应始终为非负数。 (默认值:``0``)
        worker_init_fn (callable, optional): 如果不是``None``,这将在每个
            worker id 的子进程(一个 int in ``[0, num_workers - 1]``)
            输入,在播种之后和数据加载之前。 (默认:“无”)
        prefetch_factor (int, optional, keyword-only arg): 加载的样本数
            由每个worker提前加载。 ``2`` 表示总共有2 * num_workers 样本预取。 (默认:``2``)
        persistent_workers (bool, optional): 如果 ``True``,数据加载器不会关闭
            worker在数据集被使用一次后进行处理。这允许
            保持worker“数据集”实例处于活动状态。 (默认:“False”)

Demo

from torch.utils.data import DataLoader
import numpy as np
import time
if __name__=="__main__":
    dataset=np.arange(10000)
    data=DataLoader(dataset=dataset,batch_size=100,  shuffle=False,num_workers=2,prefetch_factor=3)#
    begin=time.time()
    for i,d in enumerate(data):
        end=time.time()
        print('time:%f ms'%((end-begin)*1000))
        begin=time.time()

'''
num_workers:越大,初始时间越长
prefetch_factor:

num_workers=10,prefetch_factor=3
time:14173.185825 ms #初始化
time:0.000000 ms
time:0.000000 ms
time:0.998497 ms
time:0.000000 ms
time:0.000000 ms
time:0.996590 ms
time:0.000000 ms
time:0.000000 ms
time:2.991915 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.996590 ms


num_workers=5,prefetch_factor=3
time:7062.668800 ms
time:0.999212 ms
time:0.000000 ms
time:0.996351 ms
time:6.023645 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.965357 ms
time:0.000000 ms
time:0.000000 ms
time:0.989676 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.000000 ms
time:0.996590 ms

'''

你可能感兴趣的:(Pytorch,torch.data,pytorch)