通常在使用pytorch训练神经网络时,DataLoader模块是整个网络训练过程中的基础前提且尤为重要,其主要作用是根据传入接口的参数将训练集分为若干个大小为batch size的batch以及其他一些细节上的操作。一个典型的数据加载以及batch训练过程如下:(其中的args后面会详细解释)
loader = torch.utils.data.DataLoader(args)
for data, label in loader:
training
这次主要解读DataLoader模块的源码(Pytorch版本为1.8.0),在解读源码之前首先需要明确好几个概念。
两者从字面翻译层面来看分别为可迭代和迭代器的意思。这两者概念很相近,然而在底层实现上面有些区别。
iterable: 表示某个对象是可迭代的,底层只实现了__iter__方法
iterator: 表示某个对象是迭代器,底层不仅实现了__iter__,同时还实现了__next__方法
举个例子:python语言中的list、dict、str都是可迭代的,即可以使用for循环。而对于迭代器,不仅可以使用for循环来实现遍历,还可以通过next()函数来获取下一个元素。可以通过iter()函数来将一个可迭代对象转换为一个迭代器。DataLoader是可迭代的,不是迭代器。
#先获取iterator对象
it = iter([1,2,3,4,5])
while True:
try:
#获取下一个值
x = next(it);
except StopIteration:
# 遇到StopIteration就退出循环
break
本次源码解读针对Pytorch1.8版本。让我们从以下代码片段开始把!
loader = torch.utils.data.DataLoader(args)
for data, label in loader:
training
首先,解释一下DataLoader需要哪些参数以及这些参数的意义。
参数 | 参数意义 |
---|---|
dataset | 数据集(Dataset对象) |
shuffle | 是否在每个epoch之前重新打乱数据集的顺序 |
sampler | 采样器,定义从dataset中采样数据的方式(Sampler对象) |
batch_sampler | 同上,区别在于它一次返回的是一个batch大小的数据 |
num_workers | 处理数据加载的进程数量,等于0表示单进程 |
collate_fn | 将采样得到的样本整合成一个一个batch |
pin_memory | 在返回张量前,是否要将其加载到GPU中的固定内存中 |
drop_last | 数据集大小不被batch size整除时,是否将最后比较小的batch给丢弃 |
timeout | 数据加载的超时设置 |
worker_init_fn | 进程初始化函数 |
generator | 生成器 |
prefetch_factor | 多进程环境下的一个预取因素 |
persistent_workers | 多进程环境下,决定进程的生命周期 |
1.当我们调用torch.utils.data.DataLoader后,往底层逐步深入,调用DataLoader类中的**iter方法。下面来看一下iter**方法。
DataLoader.iter
可以看到,这里分了多进程和单进程的情况(注释写错了,应该是多进程)。分两种情况的原因其实作者已经用注释了,大致意思就是,单进程环境下,返回的迭代器应该在每次返回的时候都应该重新创建一次来重置它的状态;多进程环境下,返回的迭代器应该在DataLoader的整个声明周期中一直存在,这样做可以确保这个迭代器可以被多个进程重复使用。仔细看代码可以发现该方法调用了另外一个方法_get_iterator()。注意__iter__方法返回的是一个**_BaseDataLoaderIter**对象,这个类后面会细讲。
2.DataLoader._get_iterator
由于前面讲到的DataLoader对象是可迭代的,其内部没有实现__next__方法,因此在该类中需要自己定义一个迭代器来实现上层用户使用到的效果(即迭代器的效果)。因为单进程和多进程的处理逻辑不一样,因此返回的迭代器也不一样。值得注意的是,_SingleProcessDataLoaderIter和_MultiProcessingDataLoaderIter均继承了_BaseDataLoaderIter类
3._BaseDataLoaderIter
与DataLoader不同,_BaseDataLoaderIter是一个迭代器类,因此其不仅实现了__iter__方法还实现了**next**方法。其中__next__方法就是用来获取下一个元素从而实现数据的遍历。
仔细看代码可以发现,整个__next__方法中最重要最关键的是
data = self._next_data()这一句,而这个方法在该基类中并没有实现,因此需要继承它的子类去实现。
同样的,还可以看到代码中有self._sampler_iter,这是一个采样器迭代器,来获取一个batch中的数据的索引。
4.单进程的处理逻辑_SingleProcessDataLoaderIter
从上面的分析可以看出,_SingleProcessDataLoaderIter是_BaseDataLoaderIter的子类,其最主要的作用是去实现**_next_data**这个函数来实现单进程加载数据。
_SingleProcessDataLoaderIter.
分析_next_data函数,首先获取一个batch的数据的索引,然后使用一个fetcher来根据batch数据的索引来取出一个batch的数据,然后将这些数据整合成一个batch,最后再返回一个batch的数据。(self._next_index是其父类实现的)让我们来看看这个函数是怎么实现的
这里面的self._sampler_iter是Sampler对象对应的迭代器,根据Sampler对象的类型来返回数据。
至此,单进程的数据加载处理已经完成,让我们来捋一捋这整个流程的逻辑。这里只是粗略的讲了一遍整个的逻辑流程,其中还有很多细节,如果想知道这些细节,建议去看源码(有不理解的也可以来问我哦)
从上面的第二步中根据进程数的数量来决定返回哪个迭代器,那么在多进程环境下,返回的则是_MultiProcessingDataLoaderIter,同样的,该迭代器也是继承自_BaseDataLoaderIter类。与单进程不同的是,多进程采用了多个进程同时来加载数据,来提升数据加载的速度。
首先,需要了解一些数据结构:
_MultiProcessingDataLoaderIter.__init
这里仅列出比较重要的代码片段。上面的代码片段主要是启动了num_workers个进程,每个进程对应有一个index_queue,以及每个进程都运行_worker_loop这个函数。该函数的主要作用是从index_queue中取index、读数据、处理数据、返回数据、将数据插入data_queue中。
注意在这里一个worker每次只处理一个batch
_MultiProcessingDataLoaderIter.__reset
在该函数中,需要对每个worker预取若干个batch放入其index_queue中,防止在一启动进程时出现空队列的情况。
前面讲到,每个继承自_BaseDataLoaderIter类都需要实现_next_data函数,下面来看看,在多进程环境中,这个函数是如何实现的。
_MultiProcessingDataLoaderIter._next_data
首先看关键代码片段
首先从self._get_data()来获取batch的idx和data,此时获取完数据后,需要将_tasks_outstanding的值减一。为了保证现在从_data_queue中获取到的batch与预期中要获取的batch相同,引入一个if判断,如果与预期中的不同,则暂时将该batch的idx和数据暂时保存到_task_info中,防止重复取数据;如果相同,则从_task_info中删除该条记录,并将获取到的batch进行下一步的处理。
在取数据之前,我们需要借助_task_info来判断这个数据到底需要不需要调用_get_data。同样也是在_next_data函数中。
第一个while循环用来判断下一个要取出的batch的idx即_rcvd_idx是否还有数据,简而言之,就是来得到一个有效的_rcvd_idx。
_MultiProcessingDataLoaderIter._get_data
分析这段代码,可以看出其实这个函数在其内部实现调用了_try_get_data这个函数。下面让我们来看看_try_get_data是如何实现的。
_MultiProcessingDataLoaderIter._try_get_data
可以看到,这里是直接从_data_queue中取出数据,并返回取数据的状态以及数据。
_MultiProcessingDataLoaderIter._process_data
从上面的_next_data中可以看到,当取出的batch和预期要取出的batch相同时,就可以处理这个batch的数据了。整个处理逻辑如上代码所示。可以看到,每处理一个batch,_rcvd_idx就加一来表示下次要取出的batch的index,同时,每处理完一个batch,就需要将一个待处理的batch放入_index_queue中等待某个进程来处理。因此这里调用了_try_put_index.
_MultiProcessingDataLoaderIter._try_put_index
首先从sampler迭代器中获取下一个batch的index,然后将该index放入到第一个还存活的worker的_index_queue中。最后对一些数据结构的状态进行更新。
至此,多进程的数据加载处理流程解释完毕。多进程的处理逻辑与单进程唯一不同的地方在于将单进程中的_SingleProcessDataLoaderIter改为_MultiProcessingDataLoaderIter。其内部的_next_data方法十分不同,这个需要仔细体会,感兴趣的同学建议阅读源码。
数据加载模块在整个神经网络训练过程中十分重要,因此了解数据加载模块的底层实现有助于我们编写更加高效的代码。其次,整个数据加载模块的代码是用python编写,还为涉及到很多太底层的东西,阅读起来会相对轻松一点。
[1]torch.utils.data — PyTorch 1.9.0 documentation
[2]Pytorch Dataloader 学习笔记 · 大专栏 (dazhuanlan.com)
[3]PyTorch学习笔记(6)——DataLoader源代码剖析_g11d111的博客-CSDN博客_dataloader返回值