Reading PyTorch's Data I/O and Preprocessing Code

The very first stage of training a PyTorch model is the data reading and preprocessing module, which has two important entry points. The first is the ImageFolder method, which organizes the paths of the data on disk so they are ready to be read:

train_dataset = datasets.ImageFolder(traindir,transform_train)

The second is DataLoader, which performs the actual data I/O and preprocessing. During training, train_loader is passed into the train() function and the corresponding data is pulled out batch by batch through its iterator.

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)

Below we walk through this flow in detail; the goal is to make this the most thorough walkthrough of PyTorch's data pipeline you can find:

  • Overall reading order: class ImageFolder(DatasetFolder) -> class DatasetFolder(VisionDataset) -> make_dataset() and __getitem__ -> self.loader(path) for reading a sample -> default_loader(), the function that finally reads the image

I. datasets.ImageFolder

Following that order, we first look at the class ImageFolder(DatasetFolder):

This class contains only an __init__ for initialization, which accepts the following arguments:

Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)

The most important parts are root, the path to the dataset, and transform, the data augmentation pipeline.
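
For concreteness, here is a minimal sketch of how these two arguments are typically assembled; the directory path and the specific augmentations are illustrative assumptions, not taken from the original training script:

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# hypothetical augmentation pipeline passed as `transform`
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
])

traindir = "/path/to/train"         # hypothetical root laid out as root/class_x/xxx.jpg
train_dataset = datasets.ImageFolder(traindir, transform_train)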

Next we go into the DatasetFolder class; the commentary is written directly into the code.


class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::
    (Goal: further organize the data under the given root so that every sample becomes a (path string, class index) entry that is easy to read later.)
        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext
        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext
    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.
     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root, loader, extensions=None, transform=None,
                 target_transform=None, is_valid_file=None):
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)  # extract the class names under root
        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)  # organize the files under root into (path string, class index) entries
        if len(samples) == 0:
            raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
                                "Supported extensions are: " + ",".join(extensions)))

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

where make_dataset is:

def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    Args:
        directory (str): root dataset directory
        class_to_idx (Dict[str, int]): dictionary mapping class name to class index
        extensions (optional): A list of allowed extensions.
            Either extensions or is_valid_file should be passed. Defaults to None.
        is_valid_file (optional): A function that takes path of a file
            and checks if the file is a valid file
            (used to check of corrupt files) both extensions and
            is_valid_file should not be passed. Defaults to None.

    Raises:
        ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class)
    """
    instances = []
    directory = os.path.expanduser(directory)
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
    is_valid_file = cast(Callable[[str], bool], is_valid_file)
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]  # each class name has a corresponding class id
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index  # pack the sorted file path together with its class id; all items are collected in instances and returned
                    instances.append(item)
    return instances
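
To make the return shape concrete, make_dataset can also be called on its own; a quick sketch (the directory and class names here are hypothetical):

from torchvision.datasets.folder import make_dataset

samples = make_dataset("/path/to/train",
                       class_to_idx={'cat': 0, 'dog': 1},
                       extensions=('.jpg', '.png'))
print(samples[:2])  # e.g. [('/path/to/train/cat/0.jpg', 0), ('/path/to/train/cat/1.jpg', 0)]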

The __getitem__ function is used to access a specific sample: it takes an index and calls sample = self.loader(path) to do the actual I/O.

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)  # read the sample from disk
        if self.transform is not None:
            sample = self.transform(sample)  # apply the data augmentation
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

Here self.loader is the loader argument, which defaults to loader: Callable[[str], Any] = default_loader, i.e.:

def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

which ultimately calls:

def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGB')
        return img
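
Because loader is just a callable that takes a path, it can be swapped out. A minimal sketch of passing a custom loader to ImageFolder (the function name my_loader is our own; traindir and transform_train are as defined earlier):

from PIL import Image
import torchvision.datasets as datasets

def my_loader(path: str) -> Image.Image:
    # same behaviour as pil_loader, but could be replaced by cv2, turbojpeg, etc.
    with open(path, 'rb') as f:
        return Image.open(f).convert('RGB')

train_dataset = datasets.ImageFolder(traindir, transform_train, loader=my_loader)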

Coming back to the original call: the ImageFolder class we end up with mainly gives us samples, the packaged list of sample paths, and targets, the corresponding class indices.

II. torch.utils.data.DataLoader

1. num_workers == 0

  • Reading order: class DataLoader(object) -> __iter__ -> _SingleProcessDataLoaderIter -> _MapDatasetFetcher (the data fetcher)

We then pass the prepared ImageFolder dataset into torch.utils.data.DataLoader:

   train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

This class provides an iterable over the dataset, so that the for loop inside train() can read one preprocessed batch at a time.
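
In other words, the typical consumption pattern inside train() looks roughly like the following sketch (model and criterion stand for whatever the training script defines):

for i, (images, target) in enumerate(train_loader):
    images = images.cuda(non_blocking=True)  # non_blocking pairs with pin_memory=True
    target = target.cuda(non_blocking=True)
    # output = model(images)
    # loss = criterion(output, target)
    # ...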

This code is somewhat more complex.

The arguments passed in are: the packaged training data, the batch size, whether to shuffle, the number of workers, whether to use pinned memory, and so on.

The __init__ function mostly consists of validating these arguments and other safety checks.

The most important function here, in my view, is __iter__:

    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

This splits into two cases. The first is the single-process DataLoader, which enters:

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration; generates the next batch of indices to fetch
        data = self._dataset_fetcher.fetch(index)  # call the fetcher with those indices
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
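
self._next_index() draws the next batch of dataset indices from the (batch) sampler that the DataLoader builds from batch_size, shuffle and sampler. A standalone sketch of the index stream it produces:

from torch.utils.data.sampler import BatchSampler, RandomSampler

batch_sampler = BatchSampler(RandomSampler(range(10)), batch_size=4, drop_last=False)
for index in batch_sampler:
    print(index)  # e.g. [7, 2, 9, 4] -- one batch worth of dataset indices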

The iterator first creates a fetcher through _DatasetKind.create_fetcher, which for a map-style dataset ends up in _MapDatasetFetcher:

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):  # fetch the data for the given sample indices
        if self.auto_collation: 
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

The line data = [self.dataset[idx] for idx in possibly_batched_index] inside fetch connects back to the __getitem__ function above: it invokes self.loader plus the transform, i.e. the I/O and preprocessing, and collects the fresh samples in the data list.
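
The collate_fn then merges those per-sample results into one batch. The real default_collate handles many input types; for (image tensor, class index) pairs it effectively does something like this simplified sketch, which could also be passed explicitly as DataLoader(..., collate_fn=simple_collate):

import torch

def simple_collate(batch):
    images = torch.stack([sample for sample, _ in batch], dim=0)  # [B, C, H, W]
    targets = torch.tensor([target for _, target in batch])       # [B]
    return images, targets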

2. num_workers > 0

  • Reading order: _MultiProcessingDataLoaderIter -> __init__ -> _try_put_index -> _worker_loop -> __next__ -> _next_data

The most important part of the data pipeline is the multi-process case:

_MultiProcessingDataLoaderIter(_BaseDataLoaderIter)

The class comes with a long block of English comments; you do not need to read all of it, just follow the walkthrough below.

First, the class's __init__ function:

    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        assert self._num_workers > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        self._worker_result_queue = multiprocessing_context.Queue()  # holds the processed data
        self._worker_pids_set = False
        self._shutdown = False
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        self._workers_status = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()  # holds the indices this worker should fetch
            # index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(  # create one process per worker
                target=_utils.worker._worker_loop,  # the function that does the actual reading and preprocessing
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last, 
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers))
            w.daemon = True  # daemon process
            # NB: Process.start() actually take some time as it needs to
            #     start a process and pass the arguments over via a pipe.
            #     Therefore, we only add a worker to self._workers list after
            #     it started, so that we do not call .join() if program dies
            #     before it starts, and __del__ tries to join but will get:
            #     AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
            self._workers_status.append(True)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue

        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True

        # prime the prefetch loop
        for _ in range(2 * self._num_workers):  # pre-issue 2x num_workers fetch requests to prime the pipeline
            self._try_put_index()

A few parts deserve extra attention. Inside the loop, index_queue = multiprocessing_context.Queue() creates a fresh index queue for each worker process, and self._index_queues.append(index_queue) collects all the per-worker queues into _index_queues.

And in the _try_put_index function:

    def _try_put_index(self):
        assert self._tasks_outstanding < 2 * self._num_workers
        try:
            index = self._next_index()  # generate the indices for one batch
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return
        # Put the indices into _index_queues: since Python passes object references,
        # updating _index_queues[worker_queue_idx] updates the very same queue object
        # that was handed to that worker process above.
        self._index_queues[worker_queue_idx].put((self._send_idx, index))  # enqueue this batch's indices
        self._task_info[self._send_idx] = (worker_queue_idx,)  # record which worker handles this batch
        self._tasks_outstanding += 1
        self._send_idx += 1  # total number of batch requests issued so far

The index_queue is passed into each worker process:

w = multiprocessing_context.Process(
    target=_utils.worker._worker_loop,
    args=(self._dataset_kind, self._dataset, index_queue,
          self._worker_result_queue, self._workers_done_event,
          self._auto_collation, self._collate_fn, self._drop_last,
          self._base_seed + i, self._worker_init_fn, i, self._num_workers))

Now for the main event: what each worker process actually does, namely target=_utils.worker._worker_loop:

def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
                 num_workers):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    # index_queue: the indices this worker will process next
    # data_queue: where the processed data is placed

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)  # build the dataset fetcher
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        # When using Iterable mode, some worker can exit earlier than others due
        # to the IterableDataset behaving differently for different workers.
        # When such things happen, an `_IterableDatasetStopIteration` object is
        # sent over to the main process with the ID of this worker, so that the
        # main process won't send more tasks to this worker, and will send
        # `None` to this worker to properly exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing step and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()  # used to check that the main process is still alive

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)  # get the indices of the data to process
            except queue.Empty:
                continue
            if r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r  # unpack the task id and the batch of dataset indices
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)  # read the data through the fetcher created by _DatasetKind.create_fetcher
                except Exception as e:
                    if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        #   (1) to save future `next(...)` calls, and
                        #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where="in DataLoader worker process {}".format(worker_id))
            data_queue.put((idx, data))  # put the loaded and preprocessed data into the shared result queue
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()

Once the n worker processes are started, they keep producing preprocessed data continuously. That covers the producer side; how does the consumer work?

Reading from the iterable naturally goes through the __next__ function:

    def __next__(self):
        data = self._next_data()
        self._num_yielded += 1
        if self._dataset_kind == _DatasetKind.Iterable and \
                self._IterableDataset_len_called is not None and \
                self._num_yielded > self._IterableDataset_len_called:
            warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                        "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                              self._num_yielded)
            if self._num_workers > 0:
                warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                             "IterableDataset replica at each worker. Please see "
                             "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
            warnings.warn(warn_msg)
        return data

This function goes into data = self._next_data(), i.e. the implementation in _MultiProcessingDataLoaderIter:

    def _next_data(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]  # check whether the batch at _rcvd_idx has already been processed
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    # If the batch at _rcvd_idx has been processed, len(info) == 2;
                    # if not, it is 1 (only the worker id, no data). If that worker is
                    # still alive, break and trust that it will deliver this batch.
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                self._shutdown_workers()  # all workers are dead, shut everything down
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:  # if the data has already been produced, pop it out
                data = self._task_info.pop(self._rcvd_idx)[1]  # element 1 is the data itself
                return self._process_data(data)  # one finished batch is taken out, so _process_data also enqueues a new unprocessed batch of indices

            assert not self._shutdown and self._tasks_outstanding > 0
            # if the required batch is not ready yet, the main process goes and fetches results itself
            idx, data = self._get_data()  # take the next (batch id, data) pair, in arrival order, from the result queue
            self._tasks_outstanding -= 1
            
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    self._shutdown_worker(data.worker_id)
                    self._try_put_index() 
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)  # the fetched batch belongs to a later step, so store it and loop again
            else:
                del self._task_info[idx]
                return self._process_data(data)  # return this batch and issue a new fetch request

In short, the main process issues requests for batches of indices, appending them to the per-worker index queues, and each request carries a sequence number; each worker takes a request and processes it. Because the workers run at different speeds, batches are not finished strictly in the order they were issued, so later batches may be done early. When the main process cannot find the batch it currently needs, it pulls results from _worker_result_queue via self._get_data() and keeps them in its own bookkeeping (_task_info): if the result it pulls happens to be the one it needs, it returns it; otherwise it stores the result so it can be used directly when that batch's turn comes. The while True loop keeps going until the required batch is obtained.
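
To make this producer/consumer choreography tangible, here is a stripped-down toy version of the same pattern: the main process round-robins (send_idx, index) tasks into per-worker index queues, workers push (idx, data) results into a shared result queue, and the main process reorders results by idx. This is purely illustrative and far simpler than the real DataLoader:

import multiprocessing as mp

def toy_worker(index_queue, result_queue):
    while True:
        task = index_queue.get()
        if task is None:                          # final signal, like the `r is None` branch above
            break
        idx, batch_indices = task
        data = [i * 10 for i in batch_indices]    # stands in for fetcher.fetch(index)
        result_queue.put((idx, data))

if __name__ == "__main__":
    num_workers = 2
    result_queue = mp.Queue()
    index_queues = [mp.Queue() for _ in range(num_workers)]
    workers = [mp.Process(target=toy_worker, args=(q, result_queue), daemon=True)
               for q in index_queues]
    for w in workers:
        w.start()
    # dispatch like _try_put_index + _worker_queue_idx_cycle (round-robin)
    for send_idx in range(4):
        index_queues[send_idx % num_workers].put((send_idx, [2 * send_idx, 2 * send_idx + 1]))
    # consume like _next_data: buffer out-of-order results until rcvd_idx arrives
    buffered = {}
    for rcvd_idx in range(4):
        while rcvd_idx not in buffered:
            idx, data = result_queue.get()
            buffered[idx] = data
        print(rcvd_idx, buffered.pop(rcvd_idx))
    for q in index_queues:
        q.put(None)                               # tell every worker to exit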

If a worker looks for work and its index queue has no indices waiting, it just keeps continuing in its loop:

            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)  # get the indices of the data to process
            except queue.Empty:  # if nothing is there yet, keep looping and waiting
                continue

And whenever a finished batch is taken out, a new batch of indices is put in; which worker's index queue receives it is decided by the cycle below:

worker_queue_idx = next(self._worker_queue_idx_cycle)
self._index_queues[worker_queue_idx].put((self._send_idx, index))

So every time a batch is consumed, the cycle produces the next worker id and the new indices are put into that worker's index queue. In other words, taking a batch produced by worker n does not necessarily refill worker n's queue; the refill is handed out round-robin.
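
The round-robin itself is just itertools.cycle over the worker ids, which is easy to see in isolation:

import itertools

worker_queue_idx_cycle = itertools.cycle(range(3))       # 3 workers
print([next(worker_queue_idx_cycle) for _ in range(7)])  # [0, 1, 2, 0, 1, 2, 0]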
