from torch.utlis.data import DataLoader
功能:组合数据集和采样器(规定提取样本的方法),并提供对给定数据集的可迭代对象。
通俗一点,就是把输进来的数据集,按照一个想要的规则(采样器)把数据划分好,同时让它是一个可迭代对象(可以循环提取数据,方便后面程序使用)。
DataLoader(
dataset: torch.utils.data.dataset.Dataset[+T_co],
batch_size: Optional[int] = 1,
shuffle: Optional[bool] = None,
sampler: Union[torch.utils.data.sampler.Sampler, Iterable, NoneType] = None,
batch_sampler: Union[torch.utils.data.sampler.Sampler[Sequence], Iterable[Sequence], NoneType] = None,
num_workers: int = 0,
collate_fn: Optional[Callable[[List[~T]], Any]] = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: Optional[Callable[[int], NoneType]] = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False,
pin_memory_device: str = '',
)
# 在jupyter的cell中输入`DataLoader??` 即可看到其源码
dataset
:要载入的数据集
batch_size
:批大小,每个批中的样本数
shuffle
:是否载入数据集时是否要随机选取(打乱顺序),True
为打乱顺序,False
为不打乱。布尔型,只能取None
、True
、False
。
samper
:定义从数据集中提取样本的策略。需要是可迭代的。如果自定义了它,那shuffle就得是False
,(默认为None)。源码中有
if sampler is not None and shuffle:
raise ValueError('sampler option is mutually exclusive with '
'shuffle')
batch_sampler
:和samper
类似,但是一次只返回一个批batch
的。如果自定义了batch_samper,那参数batch_size、shuffle、samper、drop_last得是默认值。源码中
if batch_sampler is not None:
# auto_collation with custom batch_sampler
if batch_size != 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler option is mutually exclusive '
'with batch_size, shuffle, sampler, and '
'drop_last')
num_workers
:线程数。用来实现并行化加载数据。
collate_fn
:将一个list的sample组成一个mini-batch。可以自己定义函数来实现想实现的功能。
一文弄懂Pytorch的DataLoader, DataSet, Samper之间的关系:https://zhuanlan.zhihu.com/p/76893455
sampler
生成一系列(整个数据集)的索引index,而batch_samper
则将这一些列inex按照batch_size分组,每组batch_size个index。自定义采样器需要重新定义方法__iter__(self)
,它的返回需要可迭代。 对于Dataset也可以通过自定义实现对数据的加载,不过定义时需要重新定义方法__getitem__(self)
,就是通过它来规定选取数据的方式。collate_fn参数使用详解: https://zhuanlan.zhihu.com/p/361830892
drop_last
:一般数据集不会是批大小的整数倍,所以最后一批样本数可能会小于批大小。所以这个参数为True就舍弃这些样本,False则保留。默认False。
timeout
:如果为正,则为从一个线程worker中收集一批样本的超时值(等待的时间),超过这个超时值就不收集这个样本了。一般都非负,默认为0。
worker_init_fn
:每个线程worker初始化函数。默认None
pin_memory
:返回数据之前,将复制Tensor(数据)到device/cuda固定内存中(cuda pinned memory)。可以提高数据从cpu到gpu传输效率,即加速。
generator
:如果非空,使用RandomSampler采样去生成随机indexs,并且多线程的生成base_seed
prefetch_factor
:每个线程提前加载的批数。默认为2
persistent_workers
:如果为“True”,则数据加载程序在使用数据集一次后不会关闭工作进程。这允许维护工作线程“数据集”实例处于活动状态。默认False
pin_memory_device
:如果为 true,数据加载器会在返回之前将Tensor复制到device固定内存中,然后再返回它们pin_memory。
https://zhuanlan.zhihu.com/p/76893455
https://zhuanlan.zhihu.com/p/361830892