pytorch Dataloader Sampler参数深入理解

DataLoader函数

参数与初始化

 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):

其中几个常用的参数:

  • dataset 数据集:map-style and iterable-style 可以用index取值的对象
  • batch_size: 大小
  • shuffle: 取batch是否随机取,默认为False。
  • sampler:定义取batch的方法,是一个迭代器,每次生成一个 k e y key key,用于读取dataset中的值.
  • batch_sampler:也是一个迭代器,每次生成一个batch_size的key。
  • num_workers: 参与工作的线程数。
  • collate_fn: 对取出的batch进行处理。
  • drop_last: 对最后不足的batchsize的数据的处理方法。
    下面看两段取自DataLoader中的__init__代码,帮助我们理解几个常用参数之间的关系。
	if sampler is None:  # give default samplers
	    if self._dataset_kind == _DatasetKind.Iterable:
	        # See NOTE [ Custom Samplers and IterableDataset ]
	        sampler = _InfiniteConstantSampler()
	    else:  # map-style
	        if shuffle:
	            sampler = RandomSampler(dataset)
	        else:
	            sampler = SequentialSampler(dataset)

可以看出,当dataset类型是map style时, s h u f f l e shuffle shuffle其实就是改变sampler的取值。

  • shuffle为默认false时,sampler是SequentialSampler 就是按顺序取样。
  • shuffle为true时,sample就是 R a n d o m S a m p l e r RandomSampler RandomSampler,就是按随机取样。
    所以当我们sampler有输入时,shuffle值就没有意义啦。后面我们再看sampler的定义方法。

在看一段初始化代码:

    if batch_size is not None and batch_sampler is None:
        # auto_collation without custom batch_sampler
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        
    self.sampler = sampler
    self.batch_sampler = batch_sampler

在看看,batchsamper的生成过程:

# 略去类的初始化
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

就是按照batch_size从sampler中读取索引,并形成生成器返回
以上可以看出 b a t c h s a m p l e r batch_sampler batchsampler s a m p l e r sampler sampler,batch_size和drop_last之间的关系。

  • 如果batch_sampler没有定义的话,且batch_size有定义,会根据:
    sampler,batch_szie,drop_last生成一个batch_sampler。

  • 自带的注释中对batch_sampler有一句话:
    Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.

  • 意思就是batch_sampler与这些参数冲突, 即如果你定义了batch_sampler,其他参数都不需要有了。

再看batch生成过程

每个batch都是由迭代器产生的

# DataLoader中iter的部分
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

# 再看调用的另一个类
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def __next__(self):
        index = self._next_index()  
        data = self._dataset_fetcher.fetch(index)  
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

对上面的代码进行一一略读,初始化略过。

  • 先对__next__index()一步一步溯源。

    def _next_index(self):
        return next(self._sampler_iter) 
    ///
	self._sampler_iter = iter(self._index_sampler)  
# 以上又用了一个迭代器生成索引 
 	self._index_sampler = loader._index_sampler
 	///
    def _index_sampler(self):
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler 

可以看出, _next_index其实就是batch_sample或是sampler用迭代器生成了一遍。
而sampler返回的就dataset中对应的索引值

在看_dataset_fetcher函数

 def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

# 按map-style往下看
class _MapDatasetFetcher(_BaseDatasetFetcher):
    
	# 略过初始化
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
            # 关键
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

可以看到DataLoader中迭代器生成的data,就是根据ssmpler或者batch_sampler生成的索引,从dataset中取值,然后经过collate_fn处理。
这里还要提一下auto_collation参数。

    def _auto_collation(self):
        return self.batch_sampler is not None

其实就是判断batch_sampler是否为None的情况,而根据batch_sampler的定义,只有初始化参数batch_size,和batch_sampler都为None时,才为False。
这时从fetch函数可以看出,就是每次取一个值,_next_index()取的也是sampler, 此时相当与batch_size等于1
由此,明白了整个大致过程,我们都可以i对sampler进行定义了。来获得我们想要的batch。

sampler参数的使用

sampler其实都是用来定义取batch的方法的一个函数或者类,返回的是一个迭代器

我们可以看一下自带的RandomSampler类中最重要的 i t e r iter iter函数。

    def __iter__(self):
        n = len(self.data_source)
        # dataset的长度, 按顺序索引
        if self.replacement:# 对应的replace参数
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())        

可以看出,其实就是生成索引,然后随机的取值,然后再迭代
其实还有一些细节需要注意理解:

  • 比如__len__函数,包括DataLoader的len和sample的len
  • 两者区别,这部分代码比较简单,可以自行阅读,其实参考着:
  • RandomSampler写也不会出现问题。
  • 比如生成器和迭代器的使用,以及区别。

关于dataset预处理和collate_fn的一些问题。

在dataset预处理中,曾遇到这样一个问题,interrupted by signal 9: SIGKILL.

  • 经过查询才知道,是内存溢出,后来经过查看以上链接中michuanhaohao的代码才知道,预处理并不是对整个dataset同时进行预处理,然后传入DataLoader。而是把raw_dataset直接传入DataLoader,当读取一个batch时,
  • 对batch进行处理,这样确实节省内存。
class ImageDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, pid, camid = self.dataset[index]
        img = read_image(img_path)

        if self.transform is not None:
            img = self.transform(img)

        return img, pid, camid, img_path

`
之前说过, DataLoader的参数dataset类型可以是列表,字典等。
可以用索引去读取值的类型。此处,则是一个类,类中定义了__getitem__函数,使其能够用index去取值。
这个类初始化输入的dataset其实都是图片地址和id参数,
而当有index来访问时,再去读取一个batch的图片,然后再对图片进行ransform。可以节省内存。

collate_fn 函数,可以从上面的fetch部分中看到, 也是对读取到的batch进行处理的一个对象,所以,预处理实际上也可以放在collate_fn中。``

#总结
慢慢的理解,争取多搞几篇文章,争取将该方法全部都将其搞定。
慢慢的自己理解透彻。理解彻底都行啦的样子与打算。

你可能感兴趣的:(模块复现,pytorch,python,深度学习)