一文理清Pytorch数据加载三大件——DataLoader,Dataset,Sampler

本文以 官方文档 为依据,不涉及大量的源码推导,仅在必要处展示示例代码,意在帮助各位读者理清楚Pytorch的数据加载机制,以根据自己的意志实现项目中的数据加载需求。

目录

    • 零、实例演示
    • 一、DataLoader,Dataset,Sampler之间如何协调工作
    • 二、详解DataLoader
    • 三、Dataset
    • 四、Sampler
    • 五、结语

零、实例演示

话不多说,我们首先来一个简单的实例演示,对整体有一个初步的把握与认知:

import torch
from torch.utils.data import *

# 先构建一个dataset
class my_tracking_dataset(Dataset):
    def __init__(self, inputs, targets):
        super(my_tracking_dataset, self).__init__()
        self.inp = inputs
        self.tgt = targets
        self.len = len(inputs)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # 这里是以字典的方式返回数据项
        disc = {
     }
        disc.update({
     'inputs': self.inp[idx], 'targets': self.tgt[idx]})
        return disc
        
# 构建dummy_input以及dummy_tgt, 这里我们假设有二十张三通道的128*128的图像,并且有对应的标签(识别任务)
dummy_input = torch.arange(20*3*128*128, dtype=torch.float32).view((20,3,128,128))
dummy_target = torch.arange(20)

data = my_tracking_dataset(dummy_input, dummy_target)

# 构建一个batchSampler
the_batchSampler = BatchSampler(SequentialSampler(range(20)), batch_size=4, drop_last=True)

# 构建一个dataloader
dataloader = DataLoader(data, batch_sampler=the_batchSampler)

for idx, batch in enumerate(dataloader):
    print('idx:{}; input:{}, target:{}'.format(idx, batch['inputs'], batch['targets']))

# 构建一个随机的数据加载器
dataloader2 = DataLoader(data, shuffle=True, batch_size=4)
for idx, batch in enumerate(dataloader2):
    print('idx:{}; input:{}, target:{}'.format(idx, batch['inputs'], batch['targets']))

感兴趣的话大家可以直接复制到自己的电脑上跑一下看看输出的结果,在基本库已经安装好的情况下,大概能够给你一个很直观的感受。

一、DataLoader,Dataset,Sampler之间如何协调工作

理清三者的关系,对我们进行数据加载的全局把控是十分关键的。我们首先来看一小段源码:(注意,这里省略了部分细节,仅展示核心组成部分,同时选择了num_workers=0的情况)

class DataLoader(object):
    ...
    
    def __next__(self):
        if self.num_workers == 0:  
            indices = next(self.sample_iter)  # Sampler
            batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch
  • indices = next(self.sample_iter):可以看出,sampler的作用在于生成一组index,那么这组index具体作用是?
  • batch = self.collate_fn([self.dataset[i] for i in indices]):从这一行代码中我们可以看出,对于由sampler生成的一组index,我们抽取得到dataset中的具体数据并组织为一个列表将其传递给collate_fn方法并最终生成了一个batch。(具体来说collate_fn在这中间发挥了什么作用,我们后面再分析)
  • batch = _utils.pin_memory.pin_memory_batch(batch):注意前面的条件判断,仅仅当pin_memory为true时我们才将batch整个转移到pined memory空间中,对于pined memory在这里大家不用太在意,仅需要知道该内存位置中的数据加载进入GPU的速度会大大提升(但也不是所有的数据类型都可以被防止在该空间中,有兴趣的同学可以自己查找资料了解)。
    至此我们就得到了一次训练所需要的batch数据,整体而言:
  • DataLoader作为一个整体容器,为数据加载提供了工作流水线。
  • Dataset即我们的数据集(当然是以特定的格式进行存储)。
  • Sampler用于生成数据索引,该索引将用于batch的构成。

二、详解DataLoader

DataLoader的定义方法:

torch.utils.data.DataLoader(dataset, 
							batch_size=1, 
							shuffle=False, 
							sampler=None, 
							batch_sampler=None, 
							num_workers=0, 
							collate_fn=None, 
							pin_memory=False, 
							drop_last=False, 
							timeout=0, 
							worker_init_fn=None, 
							multiprocessing_context=None, 
							generator=None, *, 
							prefetch_factor=2, 
							persistent_workers=False)

以上即为dataloader的具体定义方法,在这里我们主要讲其中的几个重点以及需要特别注意的地方:

  • dataset:即为我们的Dataset,留待后面讲解
  • batch-size:即一个batch所包含的数据数量,对于图像数据可以理解为N维度大小。
  • shuffle:意义为随机顺序,如果该参数设置为True,那么就会为该DataLoader生成一个对应的sampler,下面的sampler不用再设置。
  • sampler:该参数既可以为一个sampler对象(序列生成器),也可以为一个iterable的对象,例如list等。这也正印证了前面的例子,sampler的作用就是生成一组indexes,因此我们直接提供一组indexes当然也是可以的。注意:如果这里提供了sampler,那么shuffle参数一定要为false。)
  • batch-sampler:该参数与sampler相同,既可以传递一个batch-sampler对象,也可以传递一个iterable对象,如嵌套list,E.g.[[1,2,3],[4,5,6],[7,8,9]],需要注意的是:batch-sampler的功能与sampler配合batch-size的作用是相同的(可能需要配合drop-last一起使用),区别在于普通的sampler会生成一串index,再由batch-size进行组织,batch-sampler直接一步到位。
  • num_workers:即用于读取数据的线程数量,默认为0,即仅仅由主程序进行数据读写。这里还有一个需要特别关注的点:num_workers在python中的的实现方式是subprocess,但是在不同的平台上有不同的表现,其中在Windows平台上使用时需要进行额外的配置。详细情况可以查看官方文档。
    在这里需要插入一段说明:官方文档中明确说明包含有两种数据加载方式,一是automatic batching,另外则是self batching;两者的区别在于,前者是通过dataloader中的batch_size以及batch_sampler来进行组织的,也就是说我们并没有提前将数据组织为一个一个的batch,希望dataloader帮我们进行组织;而另一种方式,则是我们的dataset中就是已经完成组织的一个一个batch,我们需要dataloader一次载入一个即可,不需要它做额外的操作,这在某些情况下当然是更方便的,可以使我们更加自由地进行数据组织。具体而言,当我们将batch_size以及batch_sampler都设置为None时,automatic batching就是关闭状态。
  • collate_fn:在前面对三大件关系的讨论中我们就了解过collate_fn,其作用在于将sampler筛选出来的数据组织成一个batch。需要注意的是,默认状态下的collate_fn在automatic batching开启与关闭的状态下行为是存在差异的,具体为:当automatic batching关闭时 collate一次接受一个数据单元(因为这时候数据单元已经是被我们提前组织好的batch),其作用仅仅是将数据格式转换为tensor以方便后续操作,其余部分不再变动; 当automatic batching开启时,collate_fn主要有三个特性:1.对于数据而言,总是会新生层一个batch维度,例如我们传入多张(ch, w, h)的图像,那么将返回(N, ch, w, h)的batch;2.将数据转换为tensor方便后续操作;3.保留原始的数据结构不变,例如你可以输入一个字典,那么返回的结果中将保持字典这个结构不变,不过字典中的值变味了tensor。
    对于dataloader而言,还有非常多需要注意的细节,例如如何多进程(num_workers)加载数据的时候避免部分操作的重复执行,如何进行子进程初始化等等,想进一步深入了解的同学可以自行检索资料,或者在评论区大家一起讨论讨论。

三、Dataset

对于Dataset而言,总体上可以分为两种类型:一种是map-style的dataset,另一种则是iterabel-style的dataset。其中第一种是我们最常用的类型,至于第二种类型的dataset我们需要记住的是:它主要应用在数据加载“很昂贵”的场合,例如需要从数据库中进行数据提取等等,此外,由于iterabel的数据集自带序列,因此无法配合sampler使用。
对于map-style的dataset,在实际使用中,我们通常需要构建一个自己的类,它继承自torch.utils.data.dataset,该子类一定要重写__getitem__()方法,len()方法则可以选择重写也可以不重写。其中最关键的__getitem__()方法,它需要能够实现接受一个idx/key,返回对应的数据项。该方法将在构建batch数据集的时候被调用,因此是必须的。
此外,如果我们想要使用的是self batching,那么在__getitem__()就需要能够返回一个已经被组织好的batch,具体的实现方法根据数据集的差异会有不同。

四、Sampler

在前面的论述中我们已经对sampler的功能进行了总结,总体而言,sampler需要能够返回一段序列,以便于batch的生成。
在实际使用中,如果有需要(很多情况下,默认的sampler就已经够用了,可以通过shuffle来构建)我们可以构建一个Sampler的子类,继承自torch.utils.data.sampler,必须重写的方法包括__iter__()以及__len__(),其中__iter__()方法用于生成对应idx的序列,__len__返回序列的总长度。
除了最常规的sampler以外,还有SequentialSampler(顺序生成序列),RamdomSampler等等。下面我们再介绍一种上面提到的BatchSampler:

torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

可以发现,batchsampler的参数中需要先提供一个sampler并且提供一个batch_size,同时我们也可以设置是否使用drop_last。下面以官方文档中的一小段代码作为示例:

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

在上面的代码中,我们首先用SequentialSampler构建了一个基本的顺序序列,在将其作为Batchsampler的输入生成一个嵌套的序列。这里也可以很明显的看出:sampler的作用就是提供batch的构成序列信息。

五、结语

本文对Pytorch框架中的Dataloader,Dataset以及Sampler三个元素进行介绍,理清了三者在构建pytorch数据集时所承担的角色,总结来说:

  • DataLoader是数据加工厂,在具体的机器学习任务中,它能够为我们自动包装出一个一个的batch数据集,当然我们也可以从自己的需求出发构建我们想要的数据集。
  • Sampler是抽样器,对其最直接的理解是:sampler将根据我们的需求输出一串数字序列以构成当前的batch。
  • Dataset即为我们的数据集,在automatic batching环境下数据集中的数据项就是基础数据项(以图像数据为例,一个数据项可以是RGB三个通道数据以及一个label);而当我们以self batching进行数据组织时,dataset中的数据项将是已经组织好的一个一个batch。
    除了以上这些关键点,Pytorch的数据加载机制中还有非常多值得关注的信息与难点,本篇文章并不能很好地涵盖,如果有可能的话将会继续补充!感兴趣的同学可以自行检索学习。
    以上即为本文的全部内容,感谢各位的阅读,希望有收获的朋友可以的点赞+收藏!

你可能感兴趣的:(机器学习框架——pytorch,pytorch,python,机器学习)