本文以 官方文档 为依据,不涉及大量的源码推导,仅在必要处展示示例代码,意在帮助各位读者理清楚Pytorch的数据加载机制,以根据自己的意志实现项目中的数据加载需求。
话不多说,我们首先来一个简单的实例演示,对整体有一个初步的把握与认知:
import torch
from torch.utils.data import *
# 先构建一个dataset
class my_tracking_dataset(Dataset):
def __init__(self, inputs, targets):
super(my_tracking_dataset, self).__init__()
self.inp = inputs
self.tgt = targets
self.len = len(inputs)
def __len__(self):
return self.len
def __getitem__(self, idx):
# 这里是以字典的方式返回数据项
disc = {
}
disc.update({
'inputs': self.inp[idx], 'targets': self.tgt[idx]})
return disc
# 构建dummy_input以及dummy_tgt, 这里我们假设有二十张三通道的128*128的图像,并且有对应的标签(识别任务)
dummy_input = torch.arange(20*3*128*128, dtype=torch.float32).view((20,3,128,128))
dummy_target = torch.arange(20)
data = my_tracking_dataset(dummy_input, dummy_target)
# 构建一个batchSampler
the_batchSampler = BatchSampler(SequentialSampler(range(20)), batch_size=4, drop_last=True)
# 构建一个dataloader
dataloader = DataLoader(data, batch_sampler=the_batchSampler)
for idx, batch in enumerate(dataloader):
print('idx:{}; input:{}, target:{}'.format(idx, batch['inputs'], batch['targets']))
# 构建一个随机的数据加载器
dataloader2 = DataLoader(data, shuffle=True, batch_size=4)
for idx, batch in enumerate(dataloader2):
print('idx:{}; input:{}, target:{}'.format(idx, batch['inputs'], batch['targets']))
感兴趣的话大家可以直接复制到自己的电脑上跑一下看看输出的结果,在基本库已经安装好的情况下,大概能够给你一个很直观的感受。
理清三者的关系,对我们进行数据加载的全局把控是十分关键的。我们首先来看一小段源码:(注意,这里省略了部分细节,仅展示核心组成部分,同时选择了num_workers=0的情况)
class DataLoader(object):
...
def __next__(self):
if self.num_workers == 0:
indices = next(self.sample_iter) # Sampler
batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
if self.pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
indices = next(self.sample_iter)
:可以看出,sampler的作用在于生成一组index,那么这组index具体作用是?batch = self.collate_fn([self.dataset[i] for i in indices])
:从这一行代码中我们可以看出,对于由sampler生成的一组index,我们抽取得到dataset中的具体数据并组织为一个列表将其传递给collate_fn方法并最终生成了一个batch。(具体来说collate_fn在这中间发挥了什么作用,我们后面再分析)batch = _utils.pin_memory.pin_memory_batch(batch)
:注意前面的条件判断,仅仅当pin_memory为true时我们才将batch整个转移到pined memory空间中,对于pined memory在这里大家不用太在意,仅需要知道该内存位置中的数据加载进入GPU的速度会大大提升(但也不是所有的数据类型都可以被防止在该空间中,有兴趣的同学可以自己查找资料了解)。DataLoader的定义方法:
torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
generator=None, *,
prefetch_factor=2,
persistent_workers=False)
以上即为dataloader的具体定义方法,在这里我们主要讲其中的几个重点以及需要特别注意的地方:
在这里需要插入一段说明:官方文档中明确说明包含有两种数据加载方式,一是automatic batching,另外则是self batching;两者的区别在于,前者是通过dataloader中的batch_size以及batch_sampler来进行组织的,也就是说我们并没有提前将数据组织为一个一个的batch,希望dataloader帮我们进行组织;而另一种方式,则是我们的dataset中就是已经完成组织的一个一个batch,我们需要dataloader一次载入一个即可,不需要它做额外的操作,这在某些情况下当然是更方便的,可以使我们更加自由地进行数据组织。
具体而言,当我们将batch_size以及batch_sampler都设置为None时,automatic batching就是关闭状态。对于dataloader而言,还有非常多需要注意的细节,例如如何多进程(num_workers)加载数据的时候避免部分操作的重复执行,如何进行子进程初始化等等,想进一步深入了解的同学可以自行检索资料,或者在评论区大家一起讨论讨论。
对于Dataset而言,总体上可以分为两种类型:一种是map-style的dataset,另一种则是iterabel-style的dataset。其中第一种是我们最常用的类型,至于第二种类型的dataset我们需要记住的是:它主要应用在数据加载“很昂贵”的场合,例如需要从数据库中进行数据提取等等,此外,由于iterabel的数据集自带序列,因此无法配合sampler使用。
对于map-style的dataset,在实际使用中,我们通常需要构建一个自己的类,它继承自torch.utils.data.dataset
,该子类一定要重写__getitem__()方法,len()方法则可以选择重写也可以不重写。其中最关键的__getitem__()方法,它需要能够实现接受一个idx/key,返回对应的数据项。该方法将在构建batch数据集的时候被调用,因此是必须的。
此外,如果我们想要使用的是self batching,那么在__getitem__()就需要能够返回一个已经被组织好的batch,具体的实现方法根据数据集的差异会有不同。
在前面的论述中我们已经对sampler的功能进行了总结,总体而言,sampler需要能够返回一段序列,以便于batch的生成。
在实际使用中,如果有需要(很多情况下,默认的sampler就已经够用了,可以通过shuffle来构建)我们可以构建一个Sampler的子类,继承自torch.utils.data.sampler
,必须重写的方法包括__iter__()以及__len__(),其中__iter__()方法用于生成对应idx的序列,__len__返回序列的总长度。
除了最常规的sampler以外,还有SequentialSampler(顺序生成序列),RamdomSampler等等。下面我们再介绍一种上面提到的BatchSampler:
torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
可以发现,batchsampler的参数中需要先提供一个sampler并且提供一个batch_size,同时我们也可以设置是否使用drop_last。下面以官方文档中的一小段代码作为示例:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
在上面的代码中,我们首先用SequentialSampler构建了一个基本的顺序序列,在将其作为Batchsampler的输入生成一个嵌套的序列。这里也可以很明显的看出:sampler的作用就是提供batch的构成序列信息。
本文对Pytorch框架中的Dataloader,Dataset以及Sampler三个元素进行介绍,理清了三者在构建pytorch数据集时所承担的角色,总结来说: