Dataloader重要参数与内部机制

@[TOC]

一、pytorch数据输入

Dataset负责生产数据,DataLoader负责数据的分批(batch_size)、采样(sampler)、传输
Pytorch版本:1.0.1

1. Dataset

继承torch.utils.data.Dataset,实现两个函数即可:

  • def len(self) 数据总数
  • def getitem(self, index) 根据下标获取其中一条数据

2. DataLoader

将Dataset作为参数,构造一个torch.utils.data.DataLoader对象即可。
DataLoader其他参数见下文。

二、Dataloader参数汇总

  • dataset(Dataset):
    传入的数据集

  • batch_size(int, optional):
    每个batch有多少个样本

  • shuffle(bool, optional):
    在每个epoch开始的时候,对数据进行重新打乱

  • sampler(Sampler, optional):
    自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False

  • batch_sampler(Sampler, optional):
    与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)

  • num_workers (int, optional):
    这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)

  • collate_fn (callable, optional):
    将一个list的sample组成一个mini-batch的函数

  • pin_memory (bool, optional):
    如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.

  • drop_last (bool, optional):
    如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
    如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。

  • timeout(numeric, optional):
    如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

  • worker_init_fn (callable, optional):
    每个worker初始化函数 If not None, this will be called on each
    worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

2.1 sampler:分布式训练需DistributedSampler

train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)

DataLoader构造函数中相关代码:

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  ##如果shuffer就随机  
                else:
                    sampler = SequentialSampler(dataset)  ##否则顺序采样  
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler  

batch_sampler是sampler的封装,可自定义批次数据的构造。默认BatchSampler相关源码:

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)      ##遍历sampler获取数据,满batch_size就yield  
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

2.2 collate_fn:将batch的数据重新组装

例如cirtorch中将数据拆成input_data和target两个数据。
因Dataset中get_item返回input_data和target两个值,如果不用该函数,每个batch的数据应该是[batch_size,2(先input_data再target),,,],经过该函数将变成([batch_size,,,],[batch_size,,]),第一个数据全是input_data,第二个数据全是target。

2.3 pin_memory=True:提高数据从cpu到gpu传输效率

pin_memory可在cpu主存(内存)中分配不可交换到swap(缓存)的内存。。默认内存分配中的数据都可交换到swap中,那CUDA驱动会通过DRAM机制将数据从内存传到GPU显存时会复制2次(先复制到一临时不可见pinned固定内存,再往显存中复制),因此pin_memory=True可提高约2倍cpu到gpu传输效率(.cuda()或 .to(device)的时候)。相见CPU和GPU内存交互。

【拓展】Elasticsearch中的Memlock(内存锁定)可申请固定大小且不可交换内存空间。

三、DataLoader的并行

# Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
    
  • 基于multiprocessing多进程
  • 每个子进程的输入输出,通过两个主要的队列(multiprocessing.Queue()): index_queue要处理的下标、worker_result_queue要返回的下标。
  • 每个worker一次产生一个batch的数据
  • 返回batch数据前放入下一个批次数据下标
  • 构造函数子进程初始化:
            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue() # 1.每个子进程一个队列放要处理的下标
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_utils.worker._worker_loop, # 每个子进程循环执行的函数  
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event, #2.self.worker_result_queue 多子进程公用要返回batch数据的队列  
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True
                # NB: Process.start() actually take some time as it needs to
                #     start a process and pass the arguments over via a pipe.
                #     Therefore, we only add a worker to self.workers list after
                #     it started, so that we do not call .join() if program dies
                #     before it starts, and __del__ tries to join but will get:
                #     AssertionError: can only join a started process.
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

3.1 index_queue 要处理的数据下标

每个worker有一个index_queue dataloader.py#L544-L552
每个worker从index_queue取要处理的下标 dataloader.py#L124
dataloader输出一次数据前先往index_queue中放一次下标, _process_next_batch函数:

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()  ## 先放下一批数据下标
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch         ## 再返回该批数据

_put_indices依次往不同worker所属的index_queue中放 dataloader.py#L644-L652

完整的dataloader next函数:

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch) ## 5. 之前以及取出来该下标数据,直接返回

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:  ## 1.直到取的数据下标正确才return
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()  ## 2.从worker_result_queue中获取数据
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch  ## 3.下标不对先存一下
                continue
            return self._process_next_batch(batch) ## 4.内部先放下一批数据下标再返回batch数据  

3.2 worker_result_queue 返回结果

每个worker一直在执行的循环_worker_loop,其中worker_result_queue作为_worker_loop函数的data_queue传入(dataloader.py#L544-L552),相见:

def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        global _use_shared_memory
        _use_shared_memory = True

        # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal happened again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        _set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) ##从index_queue中获取要处理的下标
            except queue.Empty:
                continue
            if r is None:
                # Received the final signal
                assert done_event.is_set()
                return
            elif done_event.is_set():
                # Done event is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, batch_indices = r
            try:
                samples = collate_fn([dataset[i] for i in batch_indices]) ##1.根据下标取样本数据  
            except Exception:
                # It is important that we don't store exc_info in a variable,
                # see NOTE [ Python Traceback Reference Cycle Problem ]
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else: ## 2. 没有抛异常就将样本数据放入结果返回队列  
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass

参考文献

  • https://github.com/pytorch/pytorch/blob/v1.0.1/torch/utils/data/dataloader.py
  • https://blog.csdn.net/g11d111/article/details/81504637
  • CPU和GPU内存交互

你可能感兴趣的:(Dataloader重要参数与内部机制)