In the previous post we saw that a Dataset packages the data's location, its size, and a few optional utilities for the DataLoader to consume. In this post we read through the DataLoader source code to see how raw data actually gets loaded and fed into the model.
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
             batch_sampler=None, num_workers=0, collate_fn=default_collate,
             pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None):
    self.dataset = dataset                # the wrapped dataset
    self.batch_size = batch_size          # samples per batch
    self.num_workers = num_workers        # number of subprocesses used to load data
    self.collate_fn = collate_fn          # how samples are merged into a batch
    self.pin_memory = pin_memory          # whether to place loaded data in page-locked memory;
                                          # if False it may end up in pageable (swappable) memory
                                          # (see https://blog.csdn.net/tfcy694/article/details/83270701)
    self.drop_last = drop_last            # whether to drop the last, smaller batch when
                                          # len(dataset) is not divisible by batch_size
    self.timeout = timeout                # time limit for fetching a batch from workers
    self.worker_init_fn = worker_init_fn

    if timeout < 0:
        raise ValueError('timeout option should be non-negative')

    if batch_sampler is not None:
        if batch_size > 1 or shuffle or sampler is not None or drop_last:
            raise ValueError('batch_sampler option is mutually exclusive '
                             'with batch_size, shuffle, sampler, and drop_last')
        self.batch_size = None
        self.drop_last = None

    if sampler is not None and shuffle:
        raise ValueError('sampler option is mutually exclusive with shuffle')

    if self.num_workers < 0:
        raise ValueError('num_workers option cannot be negative; '
                         'use num_workers=0 to disable multiprocessing.')

    if batch_sampler is None:
        if sampler is None:
            if shuffle:
                sampler = RandomSampler(dataset)
            else:
                sampler = SequentialSampler(dataset)
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)

    self.sampler = sampler
    self.batch_sampler = batch_sampler
    self.__initialized = True
Let's start with the simple batch and shuffle handling. As the constructor shows, when no sampler is given, shuffle=True selects a RandomSampler and shuffle=False a SequentialSampler; a BatchSampler then groups the sampled indices into lists of batch_size, dropping the final incomplete batch if drop_last is set.
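A quick standalone check of this behavior (a plain Python list stands in for a dataset, since the samplers only need __len__):

from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

dataset = list(range(10))  # the samplers only need len(dataset)

# shuffle=False -> SequentialSampler: indices in order, last short batch kept
print(list(BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# shuffle=True -> RandomSampler: indices in random order; drop_last=True
# discards the final incomplete batch
print(list(BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=True)))
# e.g. [[3, 7, 0, 9], [5, 2, 8, 1]]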
The DataLoader class itself is simple. Its core functionality lives in the overloaded iterator method __iter__(self), which returns a _DataLoaderIter object; that class is what we examine next.
Most of the attributes of _DataLoaderIter come straight from the DataLoader, so the text below does not distinguish between the two.
If num_workers > 0, the constructor spawns that many worker processes to load data, and then, depending on the pin_memory option, starts an extra thread that copies the loaded batches into page-locked memory. If num_workers == 0, all of this multiprocessing machinery is skipped and data is read synchronously in the main process.
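A minimal usage sketch of these options; pin_memory only has an effect when a CUDA device is present, and the if __name__ == '__main__' guard is required on platforms that spawn worker processes:

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == '__main__':
    dataset = TensorDataset(torch.arange(8, dtype=torch.float32))

    # num_workers=0: batches are prepared synchronously in the main process
    plain_loader = DataLoader(dataset, batch_size=4)

    # num_workers=2: two worker processes prepare batches in the background;
    # pin_memory=True additionally copies each batch into page-locked memory
    fast_loader = DataLoader(dataset, batch_size=4, num_workers=2,
                             pin_memory=torch.cuda.is_available())

    for (batch,) in fast_loader:
        print(batch)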
All loaded data ends up in the queue self.data_queue. This attribute matters: the batches we receive when iterating in our own training loop are taken from this queue, and its elements have the form (idx, samples). It is bound only after the pin_memory and multiprocessing setup: with pin_memory=True it is the output queue filled by the pin-memory thread, otherwise it is simply self.worker_result_queue.
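The element format is easy to picture with a toy queue (the sample values here are made up for illustration and are not DataLoader internals):

import multiprocessing as mp

q = mp.Queue()
q.put((0, ['sample_a', 'sample_b']))  # (idx, samples): batch 0 with two collated samples
idx, samples = q.get()
print(idx, samples)                   # 0 ['sample_a', 'sample_b']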
With this rough picture in mind, the note at https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader lays out the three rules that govern how the data-loading processes shut down:
# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
#
# Preliminary:
#
# Our data model looks like this (queues are indicated with curly brackets):
#
#          main process                             ||
#               |                                   ||
#         {index_queue}                             ||
#               |                                   ||
#        worker processes                           ||     DATA
#               |                                   ||
#     {worker_result_queue}                         ||     FLOW
#               |                                   ||
#  pin_memory_thread of main process                ||  DIRECTION
#               |                                   ||
#          {data_queue}                             ||
#               |                                   ||
#          data output                              \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.
#
#
# Terminating multiprocessing logic requires very careful design. In
# particular, we need to make sure that
#
# 1. The iterator gracefully exits the workers when its last reference is
# gone or it is depleted.
#
# In this case, the workers should be gracefully exited because the
# main process may still need to continue to run, and we want cleaning
# up code in the workers to be executed (e.g., releasing GPU memory).
# Naturally, we implement the shutdown logic in `__del__` of
# DataLoaderIterator.
#
# We delay the discussion on the logic in this case until later.
#
# 2. The iterator exits the workers when the loader process and/or worker
# processes exits normally or with error.
#
# We set all workers and `pin_memory_thread` to have `daemon=True`.
#
# You may ask, why can't we make the workers non-daemonic, and
# gracefully exit using the same logic as we have in `__del__` when the
# iterator gets deleted (see 1 above)?
#
# First of all, `__del__` is **not** guaranteed to be called when
# interpreter exits. Even if it is called, by the time it executes,
# many Python core library resources may already be freed, and even
# simple things like acquiring an internal lock of a queue may hang.
# Therefore, in this case, we actually need to prevent `__del__` from
# being executed, and rely on the automatic termination of daemonic
# children. Thus, we register an `atexit` hook that sets a global flag
# `_python_exit_status`. Since `atexit` hooks are executed in reverse
# order of registration, we are guaranteed that this flag is set before
# library resources we use are freed. (Hooks freeing those resources
# are registered at importing the Python core libraries at the top of
# this file.) So in `__del__`, we check if `_python_exit_status` is set
# or `None` (freed), and perform no-op if so.
#
# Another problem with `__del__` is also related to the library cleanup
# calls. When a process ends, it shuts all its daemonic children
# down with a SIGTERM (instead of joining them without a timeout).
# Similarly for threads, but by a different mechanism. This fact,
# together with a few implementation details of multiprocessing, forces
# us to make workers daemonic. All of our problems arise when a
# DataLoader is used in a subprocess, and are caused by multiprocessing
# code which looks more or less like this:
#
# try:
# your_function_using_a_dataloader()
# finally:
# multiprocessing.util._exit_function()
#
# The joining/termination mentioned above happens inside
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
# throws, the stack trace stored in the exception will prevent the
# frame which uses `DataLoaderIter` to be freed. If the frame has any
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
# its `__del__`, which starts the shutdown procedure, will not be
# called. That, in turn, means that workers aren't notified. Attempting
# to join in `_exit_function` will then result in a hang.
#
# For context, `_exit_function` is also registered as an `atexit` call.
# So it is unclear to me (@ssnl) why this is needed in a finally block.
# The code dates back to 2008 and there is no comment on the original
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
# the finally block and the `atexit` registration) that explains this.
#
# Another choice is to just shutdown workers with logic in 1 above
# whenever we see an error in `next`. This isn't ideal because
# a. It prevents users from using try-catch to resume data loading.
# b. It doesn't prevent hanging if users have references to the
# iterator.
#
# 3. All processes exit if any of them die unexpectedly by fatal signals.
#
# As shown above, the workers are set as daemonic children of the main
# process. However, automatic cleaning-up of such child processes only
# happens if the parent process exits gracefully (e.g., not via fatal
# signals like SIGKILL). So we must ensure that each process will exit
# even if the process that should send/receive data to/from it was
# killed, i.e.,
#
# a. A process won't hang when getting from a queue.
#
# Even with carefully designed data dependencies (i.e., a `put()`
# always corresponding to a `get()`), hanging on `get()` can still
# happen when data in queue is corrupted (e.g., due to
# `cancel_join_thread` or unexpected exit).
#
# For child exit, we register SIGCHLD handler on main process,
# which checks if any of the workers fail in the (Python) handler.
# See DataLoader.cpp.
#
# For `.get()` calls where the sender(s) is not the workers, we
# guard them with timeouts, and check the status of the sender
# when timeout happens:
# + in the workers, the `ManagerWatchdog` class checks the main
# process status.
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
# check `pin_memory_thread` status periodically until `.get()`
# returns or see that `pin_memory_thread` died.
#
# b. A process won't hang when putting into a queue;
#
# We use `mp.Queue` which has a separate background thread to put
# objects from an unbounded buffer array. The background thread is
# daemonic and usually automatically joined when the process
# exits.
#
# However, in case that the receiver has ended abruptly while
# reading from the pipe, the join will hang forever. Therefore,
# for both `worker_result_queue` (worker -> main process/pin_memory_thread)
# and each `index_queue` (main process -> worker), we use
# `q.cancel_join_thread()` in sender process before any `q.put` to
# prevent this automatic join.
#
# Moreover, having all queues called `cancel_join_thread` makes
# implementing graceful shutdown logic in `__del__` much easier.
# It won't need to get from any queue, which would also need to be
# guarded by periodic status checks.
#
# Note that this may leave corrupted data in the queue, but we
# don't care about the data anyways once we are shutting down.
#
#
# Now let's get back to 1:
# how we gracefully exit the workers when the last reference to the
# iterator is gone.
#
# To achieve this, we implement the following logic along with the design
# choices mentioned above:
#
# [worker processes]
# While loader process is alive:
# Get from index_queue.
# If got a `None`, exit.
# If get anything else,
# Check `done_event`.
# If set, continue to next iteration
# i.e., keep getting until see the `None`, then exit.
# Otherwise, process data.
# If timed out,
# No matter whether `done_event` is set (still need to see `None`),
# must continue to next iteration.
#
# [pin_memory_thread]
# # No need to check main thread. If this thread is alive, the main loader
# # thread must be alive, because this thread is set as daemonic.
# While True:
# Get from worker_result_queue.
# If got a `None`, exit.
# If get anything else,
# Check `done_event`.
# If set, continue to next iteration
# i.e., keep getting until see the `None`, then exit.
# Otherwise, process data.
#
# NOTE: we don't check the status of the main thread because
# 1. if the process is killed by fatal signal, `pin_memory_thread`
# ends.
# 2. in other cases, either the cleaning-up in __del__ or the
# automatic exit of daemonic thread will take care of it.
# This won't busy-wait either because `.get(timeout)` does not
# busy-wait.
#
# [main process]
# In the DataLoader Iter's `__del__`
# a. Set `done_event` (shared with `pin_memory_thread` and workers).
#
# Note: from here on, the workers & `pin_memory_thread` may exit at
# any time after they receive `None`.
#
# b. Exit `pin_memory_thread`
# i. Put `None` in `worker_result_queue`.
# ii. Join the `pin_memory_thread`.
#
# c. Exit the workers.
# i. Put `None` in each worker's `index_queue`.
# ii. Join the workers.
#
# NOTE: This has to be after (b) because it may leave corrupted data
# in `worker_result_queue`, which `pin_memory_thread` reads
# from.
#
# NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
# can be omitted
#
# NB: `done_event` isn't strictly needed. E.g., we can just check for
# `None` from `index_queue`, but it allows us to skip wasting resources
# processing indices already in `index_queue` if we are already shutting
# down.
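Rule 2 relies on daemonic children being torn down automatically with their parent. A minimal standalone demonstration of that behavior (not DataLoader code):

import multiprocessing as mp
import time

def worker():
    while True:          # pretend to load data forever
        time.sleep(0.1)

if __name__ == '__main__':
    p = mp.Process(target=worker, daemon=True)
    p.start()
    time.sleep(0.3)
    # No join/terminate needed: daemonic children are killed automatically
    # when the main process exits.
    print('main exiting; the daemonic worker is torn down with it')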
Several attributes used by the multiprocessing path:
self.worker_result_queue: holds the data produced by all worker processes. Each element is a tuple (idx, samples), where idx is the batch index and samples is the set of dataset elements for that batch, assembled by self.collate_fn.
self.batches_outstanding: the number of batches currently in flight (indices handed out, data not yet consumed). It is incremented in _put_indices and decremented in __next__; as long as the two keep alternating, loading is proceeding normally. When it can no longer be incremented, the dataset has been fully traversed and iteration will stop (see the toy sketch after this list).
self.send_idx: the index of the next batch whose sample indices self._put_indices() will put into self.index_queues.
self.rcvd_idx: the index of the next batch to be taken out of self.reorder_dict; covered in depth when we discuss __next__().
self.reorder_dict: re-orders the out-of-order batches produced under multiprocessing.
self.index_queues: a list of self.num_workers Queue objects, one per worker process, carrying the index batches each worker should load.
self.workers: a list of self.num_workers Process objects, the worker processes themselves.
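Together, batches_outstanding, send_idx, and rcvd_idx implement a bounded prefetch window: the iterator primes 2 * num_workers batches up front and requests one more after every __next__. A toy sketch of that accounting (plain counters instead of real queues):

num_workers = 2
sample_iter = iter([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])

batches_outstanding = 0
send_idx = 0

def put_indices():
    # mirrors _put_indices: request one more batch if any indices remain
    global batches_outstanding, send_idx
    indices = next(sample_iter, None)
    if indices is None:
        return
    batches_outstanding += 1
    send_idx += 1

for _ in range(2 * num_workers):   # prime the prefetch window
    put_indices()

rcvd_idx = 0
while batches_outstanding > 0:
    batches_outstanding -= 1       # one batch consumed in __next__ ...
    rcvd_idx += 1
    put_indices()                  # ... and one more requested immediately
    print('rcvd_idx=%d outstanding=%d send_idx=%d'
          % (rcvd_idx, batches_outstanding, send_idx))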
Dequeuing from self.data_queue is only needed in the multiprocessing path, and dequeued batches are first stored in the self.reorder_dict dictionary. The key point is the role of self.rcvd_idx: because workers finish at different speeds, the idx values of the tuples in self.data_queue arrive out of order, while we want to yield batches in ascending idx order. self.reorder_dict therefore acts as a relay: batches popped off self.data_queue are parked there until the batch numbered self.rcvd_idx shows up, so the output ends up sorted by idx.
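This reordering is easy to reproduce in isolation. The following toy simulation mirrors the rcvd_idx/reorder_dict bookkeeping; the names match the source, but this is not the actual implementation:

import random

arrivals = [(idx, 'batch-%d' % idx) for idx in range(5)]
random.shuffle(arrivals)              # workers finish in arbitrary order
data_queue = list(arrivals)           # stand-in for self.data_queue

reorder_dict = {}
rcvd_idx = 0
while rcvd_idx < 5:
    if rcvd_idx in reorder_dict:
        # the batch we were waiting for arrived earlier; yield it now
        print('yield', reorder_dict.pop(rcvd_idx))
        rcvd_idx += 1
        continue
    idx, samples = data_queue.pop(0)  # stand-in for self.data_queue.get()
    if idx == rcvd_idx:
        print('yield', samples)
        rcvd_idx += 1
    else:
        reorder_dict[idx] = samples   # arrived early; park it until its turn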
On the input side, the index sets selected by self.sample_iter (single samples or whole batches) are fed into self.index_queues.
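Roughly, _put_indices distributes these index batches round-robin over the workers, tagging each with send_idx. A simplified, self-contained simulation (plain lists instead of multiprocessing queues):

from itertools import cycle

num_workers = 2
index_queues = [[] for _ in range(num_workers)]  # real code: multiprocessing queues
sample_iter = iter([[0, 1], [2, 3], [4]])        # what the BatchSampler would yield
next_worker = cycle(range(num_workers))          # workers are fed round-robin
send_idx = 0

for indices in sample_iter:
    # tag each batch with send_idx so results can be reordered on the way out
    index_queues[next(next_worker)].append((send_idx, indices))
    send_idx += 1

print(index_queues)  # [[(0, [0, 1]), (2, [4])], [(1, [2, 3])]]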
My understanding here is limited, so the article may well contain mistakes; corrections and discussion are welcome.