一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系

很多文章都是从 D a t a s e t Dataset Dataset等对象自下网上进行介绍的,但是对于初学者而言,其实这并不好理解,因为有时候,会不自觉的陷入到一些细枝末节中去,而不能把握重点,所以本文将自上而下的对 P y t o r c h Pytorch Pytorch数据读取方法进行介绍。

自上而下理解三者关系

首先,我们看一下 D a t a L o a d e r . n e x t DataLoader.next DataLoader.next的源代码长什么样,为方便理解,我只选取了num_works为0的情况,(num_works)简单理解都是能够并行化读取数据

 def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

在阅读上面代码时候,我们可以假设,我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取的数据就只需要对应index即可,即上面代码中的 i n d i c e s indices indices,而选取index的方式有多种:有按顺序的,也有乱序的,所以这个工作需要 S a m p l e r Sampler Sampler来完成,现在你不需要具体的细节,后面会介绍,只需要了解 D a t a L o a d e r DataLoader DataLoader S a m p l e r Sampler Sampler在这里产生关系.
那么 D a t a s e t Dataset Dataset D a t a L o a d e r DataLoader DataLoader在什么时候产生关系呢?没错就是下面一行,我们已经拿到了 i n d i c e s indices indices,那么下一步,我们只需要根据 i n d i c e s indices indices对数据进行读取即可.

在下面 i f if if语句的作用都是,如果 p i n m e m o r y = T r u e , pin_memory=True, pinmemory=True,,那么 P y t o r c h Pytorch Pytorch会采用一系列操作把数据拷贝到GPU中,总之为了加速.

综上,可以了解DataLoader Sampler和Dataset三者关系如下:
一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系_第1张图片
在阅读后文中,始终需要将上面的关系记在心里,这样能帮助你更好的理解

Sampler

参数传递

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None)

要更加细致的理解 S a m p l e r Sampler Sampler原理,我们需要先阅读以下 D a t a L o a d e r DataLoader DataLoader的源代码 如下:
可以看到初始化参数有两种 S a m p l e r Sampler Sampler : Sampler和batch_sampler
都默认为None,前者作用是生成一系列 i n d e x index index,而batch_sampler则是将sampler生成indices打包分组,得到一个又一个batch的index,例如,下面所示示例:
Batchsampler将 S e q u e n t i a l S a m p l e r SequentialSampler SequentialSampler,生成的index按照指定的batchsize分组.

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

pyTorch已经实现的sampler有以下几种

  • SequentialSampler

  • RandomSampler

  • WeightedSampler

  • SubsetRandomSampler
    需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读理解源码更深刻的理解,这里只做总结:

  • 源码

    • 如果自定义batch_sampler,那么这些参数都必须使用默认值:batch_size Shuffle sampler drop_last.
    • 如果自定义了sampler :那么shuffle需要设置为false
    • 如果sampler和batch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:
      • 若shuffle = True时,则sampler=RandomSampler(dataset)
      • 若shuffle = False时,则sampler=SequentialSampler(dataset)

如果定义sampler和BatchSampler

仔细查看源代码可以发现,所有采样器其实都继承同一个父类,即 S a m p l e r Sampler Sampler,其代码定义如下:

class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """
 
    def __init__(self, data_source):
        pass
 
    def __iter__(self):
        raise NotImplementedError
		
    def __len__(self):
        return len(self.data_source)

所以,你要做好的都是定义好__iter__(self) 函数,不过要注意的是该函数的返回值需要是可迭代的,例如 S e q u e n t i a l S a m p l e r SequentialSampler SequentialSampler返回的是:
iter(range(len(self.data_source)))
另外 B a t c h S a m p l e r BatchSampler BatchSampler与其他 S a m p l e r Sampler Sampler的主要区别是其需要将 S a m p l e r Sampler Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表,也就是说后面读取数据的过程中都是 b a t c h s a m p l e r batch sampler batchsampler.

Dataset

定义如下

class Dataset(object):
	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...

上面三个方法最基本的,其中__getitem__是最主要的方法,其规定了如何读取数据,但是其又不同于一般的方法,因为它是 p y t h o n b u i l t − i n python built-in pythonbuiltin方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问,加入你定义好一个dataset,那么可以直接通过dataset[0]来访问第一个数据,在之前,我一值没弄清__getitem__是什么作用,所以一值不知道该怎么进入这个函数进行调试,现在如果你想对__getitem__方法进行调试,可以写一个for循环遍历dataset来进行调试,而不用构建dataloader等一大堆东西啦,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:

class DataLoader(object): 
    ... 
     
    def __next__(self): 
        if self.num_workers == 0:   
            indices = next(self.sample_iter)  
            batch = self.collate_fn([self.dataset[i] for i in indices]) # this line 
            if self.pin_memory: 
                batch = _utils.pin_memory.pin_memory_batch(batch) 
            return batch

我们仔细可以发现,前面有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前,我们需要知道每个参数的含义:

  • indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表
  • self.dataset[i] 这里对第i个数据进行读取操作.
    一般来说:self.dataset[i]=(img, label)

我们不难猜出,collate_fn的作用就是将一个batch的数据进行合并的操作,默认的是collate_fn是将img和label分别合并成 i m g s imgs imgs l a b e l s labels labels,所以,如果你的__getitem__方法只是返回img,label.那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。

自己理解一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系_第2张图片

DataLoader Dataset和Sampler之间的关系

  • Sampler产生对数据进行采样
  • Dataset:产生数据
  • DataLoader将数据迭代产生batch_size数据格式.

总结

会自己看源代码,根据源代码了解,这里只是做总结
慢慢的将各种数据之间的关系都搞明白,全部都将其搞透彻.

你可能感兴趣的:(Torch的使用及参数解释,pytorch,python,深度学习)