Pytorch之DataLoader参数说明

Pytorch之DataLoader

1. 导入及功能

from torch.utlis.data import DataLoader

功能:组合数据集和采样器(规定提取样本的方法),并提供对给定数据集的可迭代对象。
通俗一点,就是把输进来的数据集,按照一个想要的规则(采样器)把数据划分好,同时让它是一个可迭代对象(可以循环提取数据,方便后面程序使用)。

2. 全部参数

DataLoader(
    dataset: torch.utils.data.dataset.Dataset[+T_co],
    batch_size: Optional[int] = 1,
    shuffle: Optional[bool] = None,
    sampler: Union[torch.utils.data.sampler.Sampler, Iterable, NoneType] = None,
    batch_sampler: Union[torch.utils.data.sampler.Sampler[Sequence], Iterable[Sequence], NoneType] = None,
    num_workers: int = 0,
    collate_fn: Optional[Callable[[List[~T]], Any]] = None,
    pin_memory: bool = False,
    drop_last: bool = False,
    timeout: float = 0,
    worker_init_fn: Optional[Callable[[int], NoneType]] = None,
    multiprocessing_context=None,
    generator=None,
    *,
    prefetch_factor: int = 2,
    persistent_workers: bool = False,
    pin_memory_device: str = '',
)
# 在jupyter的cell中输入`DataLoader??` 即可看到其源码


3. 参数说明

dataset:要载入的数据集
batch_size:批大小,每个批中的样本数
shuffle:是否载入数据集时是否要随机选取(打乱顺序),True为打乱顺序,False为不打乱。布尔型,只能取NoneTrueFalse

samper:定义从数据集中提取样本的策略。需要是可迭代的。如果自定义了它,那shuffle就得是False,(默认为None)。源码中有

if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

batch_sampler:和samper类似,但是一次只返回一个批batch的。如果自定义了batch_samper,那参数batch_size、shuffle、samper、drop_last得是默认值。源码中

if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
  • Pytorch已经实现的采样器有:SequentialSampler(shuffle设为False时就用的这个)、RandomSampler(shuffle设为True时就用的这个)、WeightedSampler、SubsetRandomSampler

num_workers:线程数。用来实现并行化加载数据。
collate_fn:将一个list的sample组成一个mini-batch。可以自己定义函数来实现想实现的功能。

一文弄懂Pytorch的DataLoader, DataSet, Samper之间的关系:https://zhuanlan.zhihu.com/p/76893455

  • 对上面这篇文章中稍作总结。首先,DataLoader, DataSet, Sampler之间的关系是:一个Dataloader中包含数据的索引indices和具体数据data。 那采样器Sampler针对indices操作,sampler生成一系列(整个数据集)的索引index,而batch_samper则将这一些列inex按照batch_size分组,每组batch_size个index。自定义采样器需要重新定义方法__iter__(self),它的返回需要可迭代。 对于Dataset也可以通过自定义实现对数据的加载,不过定义时需要重新定义方法__getitem__(self),就是通过它来规定选取数据的方式。

collate_fn参数使用详解: https://zhuanlan.zhihu.com/p/361830892

drop_last:一般数据集不会是批大小的整数倍,所以最后一批样本数可能会小于批大小。所以这个参数为True就舍弃这些样本,False则保留。默认False。
timeout:如果为正,则为从一个线程worker中收集一批样本的超时值(等待的时间),超过这个超时值就不收集这个样本了。一般都非负,默认为0。
worker_init_fn:每个线程worker初始化函数。默认None
pin_memory:返回数据之前,将复制Tensor(数据)到device/cuda固定内存中(cuda pinned memory)。可以提高数据从cpu到gpu传输效率,即加速
generator:如果非空,使用RandomSampler采样去生成随机indexs,并且多线程的生成base_seed
prefetch_factor:每个线程提前加载的批数。默认为2
persistent_workers:如果为“True”,则数据加载程序在使用数据集一次后不会关闭工作进程。这允许维护工作线程“数据集”实例处于活动状态。默认False
pin_memory_device:如果为 true,数据加载器会在返回之前将Tensor复制到device固定内存中,然后再返回它们pin_memory。

4. 参考文献

https://zhuanlan.zhihu.com/p/76893455
https://zhuanlan.zhihu.com/p/361830892

你可能感兴趣的:(Pytorch,pytorch,深度学习,python)