DataLoader in PyTorch 1.0

In the previous post we looked at Dataset, which packages the data's location, its size, and a few other optional facilities for the DataLoader to call into. In this post we walk through the DataLoader source to see how raw data is loaded into the model.

The DataLoader constructor __init__

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None):
        self.dataset = dataset                   # the wrapped Dataset
        self.batch_size = batch_size             # number of samples per batch
        self.num_workers = num_workers           # number of worker processes loading data
        self.collate_fn = collate_fn             # how samples are assembled into a batch
        self.pin_memory = pin_memory             # whether loaded batches go into page-locked (pinned) memory; if not, they may be paged out to virtual memory, see https://blog.csdn.net/tfcy694/article/details/83270701
        self.drop_last = drop_last               # whether to drop the last, smaller batch when the dataset size is not divisible by batch_size
        self.timeout = timeout                   # timeout for fetching a batch
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

Let's start with the simpler options, batching and shuffling:

  1. If you want to pass a custom sampler or batch_sampler, shuffle must stay off; conversely, when you ask for shuffled input via shuffle=True, sampler and batch_sampler must be left as None, and the internal logic then fills them in for you with a RandomSampler (or a SequentialSampler when shuffle=False) wrapped in a corresponding BatchSampler.
  2. If you pass a custom batch_sampler, then batch_size must be left at its default of 1, and shuffle, sampler, and drop_last must all be left off. These options are already subsumed by the batch_sampler, so they cannot be specified on top of it; see the sketch below.
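
To make these rules concrete, here is a minimal sketch using a toy TensorDataset (the dataset and variable names are just for illustration):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.sampler import BatchSampler, SequentialSampler

    dataset = TensorDataset(torch.arange(10.0))

    # shuffle=True is shorthand: internally a RandomSampler gets wrapped in a BatchSampler.
    loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)

    # The explicit equivalent of shuffle=False: build the BatchSampler yourself
    # and leave batch_size, shuffle, sampler, and drop_last at their defaults.
    batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=True)
    loader2 = DataLoader(dataset, batch_sampler=batch_sampler)

    # Mixing mutually exclusive options fails immediately:
    try:
        DataLoader(dataset, batch_size=4, batch_sampler=batch_sampler)
    except ValueError as err:
        print(err)  # batch_sampler option is mutually exclusive with ...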

The DataLoader class itself is quite simple. Next let's look at its core feature, the iterator overload __iter__(self), i.e. the _DataLoaderIter class.

The _DataLoaderIter constructor __init__

Most of the variables in _DataLoaderIter come straight from the DataLoader, and the text below does not bother to distinguish the two.
If num_workers > 0, the constructor sets up multiprocess loading (the workers) and then, depending on pin_memory, starts a thread that copies each batch into page-locked (pinned) memory. If num_workers == 0, all of that multiprocess machinery is skipped and data is read the ordinary way, in the main process.
Every loaded batch ends up in the queue self.data_queue. This variable matters: when our own code iterates over the input data it is ultimately pulling from this queue, whose elements have the form (idx, samples). It is assigned only after the pin_memory and multiprocess setup.
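
Schematically, the multiprocess setup looks roughly like this (heavily condensed from the 1.0 source; seeding, worker_init_fn plumbing, signal handling, and the exact argument lists of _worker_loop/_pin_memory_loop are all omitted):

    if self.num_workers > 0:
        self.worker_result_queue = multiprocessing.Queue()  # workers -> main / pin thread
        self.index_queues, self.workers = [], []
        for _ in range(self.num_workers):
            index_queue = multiprocessing.Queue()           # main -> one worker
            w = multiprocessing.Process(
                target=_worker_loop,                        # reads indices, puts (idx, samples)
                args=(self.dataset, index_queue,
                      self.worker_result_queue, self.collate_fn))
            w.daemon = True                                 # see shutdown rule 2 below
            w.start()
            self.index_queues.append(index_queue)
            self.workers.append(w)

        if self.pin_memory:
            # An extra thread copies each batch into pinned memory,
            # then forwards it to data_queue.
            self.data_queue = queue.Queue()
            self.pin_memory_thread = threading.Thread(
                target=_pin_memory_loop,
                args=(self.worker_result_queue, self.data_queue))
            self.pin_memory_thread.daemon = True
            self.pin_memory_thread.start()
        else:
            self.data_queue = self.worker_result_queue      # no extra hop
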
With this rough picture in mind, the note at https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader lays out the three rules of the loading processes' shutdown logic:

  1. the workers must terminate when the iterator is exhausted or its last reference is gone;
  2. the iterator exits the workers when the loader process and/or the worker processes exit, normally or with an error;
  3. if any process dies from a fatal signal, all processes terminate.

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
    #
    #
    # Terminating multiprocessing logic requires very careful design. In
    # particular, we need to make sure that
    #
    #   1. The iterator gracefully exits the workers when its last reference is
    #      gone or it is depleted.
    #
    #      In this case, the workers should be gracefully exited because the
    #      main process may still need to continue to run, and we want cleaning
    #      up code in the workers to be executed (e.g., releasing GPU memory).
    #      Naturally, we implement the shutdown logic in `__del__` of
    #      DataLoaderIterator.
    #
    #      We delay the discussion on the logic in this case until later.
    #
    #   2. The iterator exits the workers when the loader process and/or worker
    #      processes exits normally or with error.
    #
    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
    #
    #      You may ask, why can't we make the workers non-daemonic, and
    #      gracefully exit using the same logic as we have in `__del__` when the
    #      iterator gets deleted (see 1 above)?
    #
    #      First of all, `__del__` is **not** guaranteed to be called when
    #      interpreter exits. Even if it is called, by the time it executes,
    #      many Python core library resources may already be freed, and even
    #      simple things like acquiring an internal lock of a queue may hang.
    #      Therefore, in this case, we actually need to prevent `__del__` from
    #      being executed, and rely on the automatic termination of daemonic
    #      children. Thus, we register an `atexit` hook that sets a global flag
    #      `_python_exit_status`. Since `atexit` hooks are executed in reverse
    #      order of registration, we are guaranteed that this flag is set before
    #      library resources we use are freed. (Hooks freeing those resources
    #      are registered at importing the Python core libraries at the top of
    #      this file.) So in `__del__`, we check if `_python_exit_status` is set
    #      or `None` (freed), and perform no-op if so.
    #
    #      Another problem with `__del__` is also related to the library cleanup
    #      calls. When a process ends, it shuts all its daemonic children
    #      down with a SIGTERM (instead of joining them without a timeout).
    #      Similarly for threads, but by a different mechanism. This fact,
    #      together with a few implementation details of multiprocessing, forces
    #      us to make workers daemonic. All of our problems arise when a
    #      DataLoader is used in a subprocess, and are caused by multiprocessing
    #      code which looks more or less like this:
    #
    #          try:
    #              your_function_using_a_dataloader()
    #          finally:
    #              multiprocessing.util._exit_function()
    #
    #      The joining/termination mentioned above happens inside
    #      `_exit_function()`. Now, if `your_function_using_a_dataloader()`
    #      throws, the stack trace stored in the exception will prevent the
    #      frame which uses `DataLoaderIter` from being freed. If the frame has any
    #      reference to the `DataLoaderIter` (e.g., in a method of the iter),
    #      its  `__del__`, which starts the shutdown procedure, will not be
    #      called. That, in turn, means that workers aren't notified. Attempting
    #      to join in `_exit_function` will then result in a hang.
    #
    #      For context, `_exit_function` is also registered as an `atexit` call.
    #      So it is unclear to me (@ssnl) why this is needed in a finally block.
    #      The code dates back to 2008 and there is no comment on the original
    #      PEP 371 or patch https://bugs.python.org/issue3050 (containing both
    #      the finally block and the `atexit` registration) that explains this.
    #
    #      Another choice is to just shutdown workers with logic in 1 above
    #      whenever we see an error in `next`. This isn't ideal because
    #        a. It prevents users from using try-catch to resume data loading.
    #        b. It doesn't prevent hanging if users have references to the
    #           iterator.
    #
    #   3. All processes exit if any of them die unexpectedly by fatal signals.
    #
    #      As shown above, the workers are set as daemonic children of the main
    #      process. However, automatic cleaning-up of such child processes only
    #      happens if the parent process exits gracefully (e.g., not via fatal
    #      signals like SIGKILL). So we must ensure that each process will exit
    #      even if the process that should send/receive data to/from it was
    #      killed, i.e.,
    #
    #        a. A process won't hang when getting from a queue.
    #
    #           Even with carefully designed data dependencies (i.e., a `put()`
    #           always corresponding to a `get()`), hanging on `get()` can still
    #           happen when data in queue is corrupted (e.g., due to
    #           `cancel_join_thread` or unexpected exit).
    #
    #           For child exit, we register SIGCHLD handler on main process,
    #           which checks if any of the workers fail in the (Python) handler.
    #           See DataLoader.cpp.
    #
    #           For `.get()` calls where the sender(s) is not the workers, we
    #           guard them with timeouts, and check the status of the sender
    #           when timeout happens:
    #             + in the workers, the `ManagerWatchdog` class checks the main
    #               process status.
    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
    #               check `pin_memory_thread` status periodically until `.get()`
    #               returns or see that `pin_memory_thread` died.
    #
    #        b. A process won't hang when putting into a queue;
    #
    #           We use `mp.Queue` which has a separate background thread to put
    #           objects from an unbounded buffer array. The background thread is
    #           daemonic and usually automatically joined when the process
    #           exits.
    #
    #           However, in case that the receiver has ended abruptly while
    #           reading from the pipe, the join will hang forever. Therefore,
    #           for both `worker_result_queue` (worker -> main process/pin_memory_thread)
    #           and each `index_queue` (main process -> worker), we use
    #           `q.cancel_join_thread()` in sender process before any `q.put` to
    #           prevent this automatic join.
    #
    #           Moreover, having all queues called `cancel_join_thread` makes
    #           implementing graceful shutdown logic in `__del__` much easier.
    #           It won't need to get from any queue, which would also need to be
    #           guarded by periodic status checks.
    #
    #           Note that this may leave corrupted data in the queue, but we
    #           don't care about the data anyways once we are shutting down.
    #
    #
    # Now let's get back to 1:
    #   how we gracefully exit the workers when the last reference to the
    #   iterator is gone.
    #
    # To achieve this, we implement the following logic along with the design
    # choices mentioned above:
    #
    # [worker processes]
    #   While loader process is alive:
    #     Get from index_queue.
    #       If got a `None`, exit.
    #       If get anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data.
    #       If timed out,
    #          Whether or not `done_event` is set (we still need to see `None`),
    #          we must continue to the next iteration.
    #
    # [pin_memory_thread]
    #   # No need to check main thread. If this thread is alive, the main loader
    #   # thread must be alive, because this thread is set as daemonic.
    #   While True:
    #     Get from worker_result_queue.
    #       If got a `None`, exit.
    #       If get anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data.
    #
    #   NOTE: we don't check the status of the main thread because
    #           1. if the process is killed by fatal signal, `pin_memory_thread`
    #              ends.
    #           2. in other cases, either the cleaning-up in __del__ or the
    #              automatic exit of daemonic thread will take care of it.
    #              This won't busy-wait either because `.get(timeout)` does not
    #              busy-wait.
    #
    # [main process]
    #   In the DataLoader Iter's `__del__`
    #     a. Set `done_event` (shared with `pin_memory_thread` and workers).
    #
    #        Note: from here on, the workers & `pin_memory_thread` may exit at
    #              any time after they receive `None`.
    #
    #     b. Exit `pin_memory_thread`
    #          i.   Put `None` in `worker_result_queue`.
    #          ii.  Join the `pin_memory_thread`.
    #
    #     c. Exit the workers.
    #          i.   Put `None` in each worker's `index_queue`.
    #          ii.  Join the workers.
    #
    #        NOTE: This has to be after (b) because it may leave corrupted data
    #              in `worker_result_queue`, which `pin_memory_thread` reads
    #              from.
    #
    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
    #         can be omitted
    #
    # NB: `done_event` isn't strictly needed. E.g., we can just check for
    #     `None` from `index_queue`, but it allows us to skip wasting resources
    #     processing indices already in `index_queue` if we are already shutting
    #     down.

Several variables involved in the multiprocess path:

self.worker_result_queue: holds the data produced by all worker processes. Each element is a tuple of the form (idx, samples), where idx is the batch index and samples is the set of dataset elements for that batch, assembled by self.collate_fn.
self.batches_outstanding: the number of batches currently in flight. It is incremented in _put_indices and decremented in __next__. As long as the two keep alternating, data is flowing normally; once it can no longer be incremented, the dataset has been fully traversed and iteration will stop.
self.send_idx: the index of the next batch whose sample indices self._put_indices() will push into self.index_queues.
self.rcvd_idx: the index of the next batch to be taken from self.reorder_dict; covered in depth when we discuss the __next__() function.
self.reorder_dict: puts the out-of-order batches produced under multiprocessing back into order.
self.index_queues: a list of the queues that feed the worker processes, one Queue per worker (self.num_workers of them).
self.workers: a list of the worker processes, self.num_workers Process objects.
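
As a worked example of how these counters interact: in the 1.0 source the constructor primes the pipeline by calling _put_indices() 2 * num_workers times, so right after construction send_idx == 2 * num_workers and rcvd_idx == 0; from then on, each __next__ retires one batch (rcvd_idx increments) and, while the sampler lasts, submits one more (send_idx increments).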

The _DataLoaderIter method _get_batch()

Only used in the multiprocess path: it dequeues one element from self.data_queue. The dequeued data is then stored into the self.reorder_dict dictionary.
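
The body is short; roughly (a lightly simplified rendering of the 1.0 logic, which only adds timeout handling on top of a plain queue get):

    import queue  # only queue.Empty is needed here

    def _get_batch(self):
        # With a positive timeout, a stalled queue raises instead of hanging forever.
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()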

The _DataLoaderIter iteration method __next__()

The point worth dwelling on in the multiprocess path is self.rcvd_idx. The tuples in self.data_queue arrive with their idx values out of order, yet we want batches emitted in ascending idx order, so self.reorder_dict is introduced as a staging area: batches whose turn has not yet come are parked there, and output leaves sorted by idx.
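
The reordering core of __next__ boils down to the following sketch (_next_in_order is a hypothetical condensed name; the real method additionally decrements batches_outstanding, calls _put_indices() to keep the pipeline full, re-raises worker exceptions, and handles the single-process case separately):

    def _next_in_order(self):
        while True:
            # If the batch we are waiting for arrived earlier, emit it now.
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                self.rcvd_idx += 1
                return batch
            # Otherwise block for the next worker result.
            idx, batch = self._get_batch()
            if idx != self.rcvd_idx:
                # Arrived out of order: stash it until its turn comes.
                self.reorder_dict[idx] = batch
                continue
            self.rcvd_idx += 1
            return batch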

The _DataLoaderIter method _put_indices()

Feeds the next set of sample indices drawn from self.sample_iter (a single index, or a whole batch of them) into self.index_queues.
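
Roughly, under the same caveats as the sketches above (condensed from the 1.0 logic; batches are handed to workers round-robin):

    def _put_indices(self):
        # Never run more than 2 * num_workers batches ahead of consumption.
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)  # e.g. [3, 7, 1, 9] for batch_size=4
        if indices is None:
            return  # sampler exhausted: nothing more to submit
        # Each worker has its own index queue; rotate through them.
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1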

The shutdown helper _shutdown_workers() and the destructor __del__() are not covered in detail here; interested readers can explore them on their own.

My understanding is limited and there may well be mistakes above; discussion and corrections are welcome.

