def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
其中几个常用的参数:
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
可以看出,当dataset类型是map style时, s h u f f l e shuffle shuffle其实就是改变sampler的取值。
在看一段初始化代码:
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
# 略去类的初始化
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
就是按照batch_size从sampler中读取索引,并形成生成器返回。
以上可以看出 b a t c h s a m p l e r batch_sampler batchsampler和 s a m p l e r sampler sampler,batch_size和drop_last之间的关系。
如果batch_sampler没有定义的话,且batch_size有定义,会根据:
sampler,batch_szie,drop_last生成一个batch_sampler。
自带的注释中对batch_sampler有一句话:
Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
意思就是batch_sampler与这些参数冲突, 即如果你定义了batch_sampler,其他参数都不需要有了。
每个batch都是由迭代器产生的。
# DataLoader中iter的部分
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
# 再看调用的另一个类
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def __next__(self):
index = self._next_index()
data = self._dataset_fetcher.fetch(index)
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
对上面的代码进行一一略读,初始化略过。
def _next_index(self):
return next(self._sampler_iter)
///
self._sampler_iter = iter(self._index_sampler)
# 以上又用了一个迭代器生成索引
self._index_sampler = loader._index_sampler
///
def _index_sampler(self):
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
可以看出, _next_index其实就是batch_sample或是sampler用迭代器生成了一遍。
而sampler返回的就dataset中对应的索引值:
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
# 按map-style往下看
class _MapDatasetFetcher(_BaseDatasetFetcher):
# 略过初始化
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
# 关键
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
可以看到DataLoader中迭代器生成的data,就是根据ssmpler或者batch_sampler生成的索引,从dataset中取值,然后经过collate_fn处理。
这里还要提一下auto_collation参数。
def _auto_collation(self):
return self.batch_sampler is not None
其实就是判断batch_sampler是否为None的情况,而根据batch_sampler的定义,只有初始化参数batch_size,和batch_sampler都为None时,才为False。
这时从fetch函数可以看出,就是每次取一个值,_next_index()取的也是sampler, 此时相当与batch_size等于1。
由此,明白了整个大致过程,我们都可以i对sampler进行定义了。来获得我们想要的batch。
sampler其实都是用来定义取batch的方法的一个函数或者类,返回的是一个迭代器。
我们可以看一下自带的RandomSampler类中最重要的 i t e r iter iter函数。
def __iter__(self):
n = len(self.data_source)
# dataset的长度, 按顺序索引
if self.replacement:# 对应的replace参数
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
可以看出,其实就是生成索引,然后随机的取值,然后再迭代。
其实还有一些细节需要注意理解:
在dataset预处理中,曾遇到这样一个问题,interrupted by signal 9: SIGKILL.
class ImageDataset(Dataset):
"""Image Person ReID Dataset"""
def __init__(self, dataset, transform=None):
self.dataset = dataset
self.transform = transform
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
img_path, pid, camid = self.dataset[index]
img = read_image(img_path)
if self.transform is not None:
img = self.transform(img)
return img, pid, camid, img_path
`
之前说过, DataLoader的参数dataset类型可以是列表,字典等。
可以用索引去读取值的类型。此处,则是一个类,类中定义了__getitem__函数,使其能够用index去取值。
这个类初始化输入的dataset其实都是图片地址和id参数,
而当有index来访问时,再去读取一个batch的图片,然后再对图片进行ransform。可以节省内存。
collate_fn 函数,可以从上面的fetch部分中看到, 也是对读取到的batch进行处理的一个对象,所以,预处理实际上也可以放在collate_fn中。``
#总结
慢慢的理解,争取多搞几篇文章,争取将该方法全部都将其搞定。
慢慢的自己理解透彻。理解彻底都行啦的样子与打算。