def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
其中几个常用的参数
当我们sampler有输入时,shuffle的值就没有意义,
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
当dataset类型是map style时, shuffle其实就是改变sampler的取值
sampler 是用来定义取batch方法的一个函数或者类,返回的是一个迭代器。
我们可以看下自带的RandomSampler类中最重要的iter函数
def __iter__(self):
n = len(self.data_source)
# dataset的长度, 按顺序索引
if self.replacement:# 对应的replace参数
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
可以看出,其实就是生成索引,然后随机的取值, 然后再迭代。
其实还有一些细节需要注意理解:
比如__len__函数,包括DataLoader的len和sample的len, 两者区别, 这部分代码比较简单,可以自行阅读,其实参考着RandomSampler写也不会出现问题。
比如,迭代器和生成器的使用, 以及区别
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
BatchSampler的生成过程
# 略去类的初始化
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
就是按batch_size从sampler中读取索引, 并形成生成器返回。
以上可以看出, batch_sampler和sampler, batch_size, drop_last之间的关系
如果batch_sampler没有定义的话且batch_size有定义, 会根据sampler, batch_size, drop_last生成一个batch_sampler
自带的注释中对batch_sampler有一句话: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
意思就是batch_sampler 与这些参数冲突 ,即 如果你定义了batch_sampler, 其他参数都不需要有
每个batch都是由迭代器产生的
# DataLoader中iter的部分
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
# 再看调用的另一个类
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def __next__(self):
index = self._next_index()
data = self._dataset_fetcher.fetch(index)
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data