Datasets in PyTorch

Overview

Dataset + collate_fn + Sampler + DataLoader have to be combined to get the equivalent of TensorFlow's tf.data.Dataset.

  • DataLoader: the class you interact with. Its _get_iterator() method returns an iterator; calling next() on it yields tensors.
  • Sampler: the sampling strategy over the dataset; it produces the indices to use at each step (possibly_batched_index).
  • Fetcher: given possibly_batched_index, pulls the raw samples out of the dataset object.
  • collate_fn: called by the Fetcher on the raw samples to build the tensor batch that is fed to the model.
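A minimal end-to-end sketch of how these pieces cooperate (using the built-in TensorDataset as the map-style dataset; the default Sampler, BatchSampler and collate_fn are filled in automatically):

import torch
from torch.utils.data import TensorDataset, DataLoader

features = torch.arange(10.).reshape(5, 2)   # 5 samples, 2 features each
labels = torch.arange(5)
dataset = TensorDataset(features, labels)    # a built-in map-style Dataset
loader = DataLoader(dataset, batch_size=2)   # SequentialSampler + BatchSampler + default_collate
it = iter(loader)                            # DataLoader._get_iterator() under the hood
xb, yb = next(it)                            # one step's worth of tensors
print(xb.shape, yb.shape)                    # torch.Size([2, 2]) torch.Size([2])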

1. Dataset

torch.utils.data.Dataset is an abstract class; you implement your own subclass of it.

class Dataset(Generic[T_co]):
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
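A typical map-style subclass just implements these two methods. A sketch over an in-memory list of (text, label) pairs (the ListDataset name is illustrative):

from torch.utils.data import Dataset

class ListDataset(Dataset):
    def __init__(self, samples):
        # samples: a list of (text, label) pairs already loaded in memory
        self.samples = samples

    def __getitem__(self, index):
        text, label = self.samples[index]
        return {'text': text, 'label': label}

    def __len__(self):
        return len(self.samples)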

2. Sampler

torch.utils.data.Sampler is also an abstract class.
The default pairing is SequentialSampler + BatchSampler.

class Sampler(Generic[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

class BatchSampler(Sampler[List[int]]):
    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        sampler_iter = iter(self.sampler)
        while True:
            try:
                batch = [next(sampler_iter) for _ in range(self.batch_size)]
                yield batch
            except StopIteration:
                break
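For instance, with a data source of length 5 the default SequentialSampler + BatchSampler pairing yields these possibly_batched_index lists:

from torch.utils.data import SequentialSampler, BatchSampler

data = list(range(5))                      # stands in for a Dataset of length 5
sampler = SequentialSampler(data)
batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=False)
print(list(batch_sampler))                 # [[0, 1], [2, 3], [4]]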

3. collate_fn

An interface that turns plain-typed samples into a batch of tensors. The signature looks like this:

  • def my_collate_fn(feature_dict_list: List[Dict[str, Union[str, Tensor]]]) -> Dict[str, Tensor]
    • feature_dict_list: has batch_size elements, each a Dict[str, Any], usually holding plain Python/NumPy values.
    • return: Dict[str, Tensor], where each tensor's shape[0] is usually the batch_size.

Q: How do I pass extra arguments?
The signature leaves no room for extra parameters, so how do we pass in configuration (e.g. different feature-processing rules) and keep the collator reusable?
A: Pass a callable object instead of a function. Define a MyCollator class, accept the configuration in __init__, and implement __call__(self, ...) with the same signature as collate_fn; see the sketch below.
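A minimal sketch of such a collator (the MyCollator name and the padding rule are illustrative assumptions):

import torch
from typing import Any, Dict, List

class MyCollator:
    def __init__(self, pad_token_id: int = 0):
        # configuration is passed once here and reused at every step
        self.pad_token_id = pad_token_id

    def __call__(self, feature_dict_list: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # pad variable-length 'input_ids' to the longest sample in the batch
        max_len = max(len(f['input_ids']) for f in feature_dict_list)
        input_ids = [f['input_ids'] + [self.pad_token_id] * (max_len - len(f['input_ids']))
                     for f in feature_dict_list]
        return {
            'input_ids': torch.tensor(input_ids),
            'label': torch.tensor([f['label'] for f in feature_dict_list]),
        }

# usage: loader = DataLoader(my_dataset, batch_size=8, collate_fn=MyCollator(pad_token_id=0))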

4. DataLoader

class DataLoader(Generic[T_co]):
    def __init__(self, dataset: Dataset[T_co],
                 batch_size: Optional[int] = 1,
                 shuffle: bool = False,
                 sampler: Optional[Sampler] = None,
                 batch_sampler: Optional[Sampler[List[int]]] = None,
                 num_workers: int = 0,
                 collate_fn: Optional[_collate_fn_t] = None,
                 drop_last: bool = False,
                 worker_init_fn: Optional[_worker_init_fn_t] = None,
                 generator=None,
                 ):
        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
                else:
                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]
        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

It wraps a dataset and layers on batching, shuffling, and so on; what it yields are Tensor objects.

  • DataLoader#__init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,...,collate_fn, ...)
    • collate_fn (callable, optional): merges a list of samples to form a
      mini-batch of Tensor(s).
      If not given, it defaults to torch.utils.data._utils.collate.default_collate, which covers most cases; see the sketch below and the loader_test output in section 6.
      collate_fn is the hook for flexible per-batch processing, the equivalent of tf.data.Dataset.map(map_fn).
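A quick sketch of what default_collate does with a list of plain samples (the import path below is the internal one mentioned above; recent PyTorch versions also expose it as torch.utils.data.default_collate):

import numpy as np
from torch.utils.data._utils.collate import default_collate

batch = [
    {'sample_id': 0, 'value': np.array([0, 1])},
    {'sample_id': 1, 'value': np.array([2, 3])},
]
# ints are gathered into a 1-D tensor, arrays are stacked along a new batch dim
print(default_collate(batch))
# {'sample_id': tensor([0, 1]), 'value': tensor([[0, 1], [2, 3]])}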

4.1 _DatasetKind

class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

4.2 Fetcher

class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        self.ended = False

    def fetch(self, possibly_batched_index):
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            data = next(self.dataset_iter)
        return self.collate_fn(data)
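The iterable-style fetcher above pulls from the dataset's iterator; for comparison, the map-style fetcher simply indexes into the dataset (simplified from the PyTorch source):

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # a batch_sampler is in use: the index is a list of indices
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)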

4.3 _SingleProcessDataLoaderIter

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        return data
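The glue that makes next(it) reach _next_data lives in _BaseDataLoaderIter (shown more fully in section 5.1); roughly, simplified from the PyTorch source:

class _BaseDataLoaderIter(object):
    def __next__(self):
        data = self._next_data()   # may raise StopIteration and end the epoch
        self._num_yielded += 1
        return data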

5. DataLoader with multiple worker processes

5.1 Main process

_MultiProcessingDataLoaderIter

class _BaseDataLoaderIter(object):
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError


class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        self._worker_init_fn = loader.worker_init_fn
        self._worker_result_queue = multiprocessing_context.Queue()
        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers, self._shared_seed))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
        self._reset(loader, first_iter=True)

    def _next_data(self):
        idx, data = self._get_data()
        return self._process_data(data)

    def _get_data(self):
        while True:
            success, data = self._try_get_data()
            if success:
                return data

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data
        
    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died unexpectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except queue.Empty:
            # (the real implementation also checks for dead workers before giving up)
            return (False, None)
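_process_data above calls _try_put_index to hand the next batch of indices to a worker; a simplified sketch of that method (worker-status bookkeeping and other details elided):

    def _try_put_index(self):
        try:
            index = self._next_index()      # next possibly_batched_index from the sampler
        except StopIteration:
            return
        worker_queue_idx = next(self._worker_queue_idx_cycle)   # round-robin over workers
        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._tasks_outstanding += 1
        self._send_idx += 1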

5.2 Worker processes

def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers, shared_seed):
    try:
        seed = base_seed + worker_id
        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)
        if init_fn is not None:
            init_fn(worker_id)

        fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        watchdog = ManagerWatchdog()
        while watchdog.is_alive():
            r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            idx, index = r
            data = fetcher.fetch(index)
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass

6. IterableDataset: streaming datasets

See Reference [1] for details.
The typical scenario is data that does not fit in memory: a remote store/database -> dataset -> model feed, all running as a stream.
Example below.

import numpy as np
from torch.utils.data import IterableDataset, DataLoader


class StreamingDataset(IterableDataset):

    def generator(self):
        i = 0
        data = np.arange(0,10).reshape((5, 2))
        while True:
            if i == len(data):
                break
            yield {'sample_id': i, 'value': data[i]}
            i += 1

    def __iter__(self):
        return iter(self.generator())


def dataset_test():
    it = iter(StreamingDataset())
    print(next(it))
    print(next(it))


def loader_test():
    loader = DataLoader(StreamingDataset(), batch_size=2)
    it = iter(loader)
    print(next(it), next(it)) 

if __name__ == '__main__':
    loader_test()

"""
dataset_test()
{'sample_id': 0, 'value': array([0, 1])}
{'sample_id': 1, 'value': array([2, 3])}


loader_test()
{'sample_id': tensor([0, 1]), 'value': tensor([[0, 1],
        [2, 3]], dtype=torch.int32)}
{'sample_id': tensor([2, 3]), 'value': tensor([[4, 5],
        [6, 7]], dtype=torch.int32)}
"""

References

todo
