Pytorch 源码分析 torch.utils.data.DataLoader

今天来分析一下,在看代码中遇到的问题,先看源码torch.utils.data.DataLoader。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False)

  这是一个数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。

参数:

  • datasetDataset) – 要加载数据的数据集。
  • batch_sizeint, 可选) – 每一批要加载多少数据(默认:1)。
  • shufflebool, 可选) – 如果每一个epoch内要打乱数据,就设置为True(默认:False)。
  • samplerSampler, 可选)– 定义了从数据集采数据的策略。如果这一选项指定了,shuffle必须是False。
  • batch_samplerSampler, 可选)– 类似于sampler,但是每次返回一批索引。和batch_sizeshufflesamplerdrop_last互相冲突。
  • num_workersint, 可选) – 加载数据的子进程数量。0表示主进程加载数据(默认:0)。
  • collate_fn可调用 , 可选)– 归并样例列表来组成小批。
  • pin_memorybool, 可选)– 如果设置为True,数据加载器会在返回前将张量拷贝到CUDA锁页内存。
  • drop_lastbool, 可选)– 如果数据集的大小不能不能被批大小整除,该选项设为True后不会把最后的残缺批作为输入;如果设置为False,最后一个批将会稍微小一点。(默认:False
  • timeout数值 , 可选) – 如果是正数,即为收集一个批数据的时间限制。必须非负。(默认:0
  • worker_init_fn可调用 , 可选)– 如果不是None,每个worker子进程都会使用worker id(在[0, num_workers - 1]内的整数)进行调用作为输入,这一过程发生在设置种子之后、加载数据之前。(默认:None

注意:

默认地,每个worker都会有各自的PyTorch种子,设置方法是base_seed + worker_id,其中base_seed是主进程通过随机数生成器生成的long型数。而其它库(如NumPy)的种子可能由初始worker复制得到, 使得每一个worker返回相同的种子。(见FAQ中的My data loader workers return identical random numbers部分。)你可以用torch.initial_seed()查看worker_init_fn中每个worker的PyTorch种子,也可以在加载数据之前设置其他种子。

 

而在我看代码的时候遇到了这个函数collate_fn(可调用 , 可选)– 归并样例列表来组成小批。下面来举个例子:

from torch.utils.data.dataloader import default_collate


    a = [(1, 2), (3, 4)]
    print(a)
    a = default_collate(a)
    print(a)

运行结果如下:

[(1, 2), (3, 4)]
[tensor([1, 3]), tensor([2, 4])]

由此可见在我们使用torch.utils.data.DataLoder()数据后,它还有另外一种功能假如当数据中有损坏的文件,我们则需要剔除,剔除试用的是filter(function,iteration)来处理,返回一个迭代器对象,我们可以通过list转换为列表的形式。但是比如送到神经网络中,我们还是需要是batch的形式,原来的DataLoader处理后的图像就是tensor的形式。

如有不足,欢迎交流。

你可能感兴趣的:(Pytorch)