Dataset + collate_fn + Sampler + DataLoader have to be used together before you get the equivalent of tf's dataset.
The _get_iterator() method of DataLoader returns an iterator; calling next() on it yields tensors.
torch.utils.data.Dataset is an abstract class; you implement a subclass of it yourself (a minimal sketch follows the class definition below).
class Dataset(Generic[T_co]):
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
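For example, a minimal map-style subclass could look like the sketch below (the class name and fields are made up for illustration):

import torch
from torch.utils.data import Dataset

class PairDataset(Dataset):
    """Minimal map-style dataset: holds (feature, label) pairs in memory."""
    def __init__(self, pairs):
        self.pairs = pairs

    def __getitem__(self, index):
        feature, label = self.pairs[index]
        return {'feature': torch.tensor(feature), 'label': torch.tensor(label)}

    def __len__(self):
        return len(self.pairs)

# ds = PairDataset([([1.0, 2.0], 0), ([3.0, 4.0], 1)]); ds[0] -> {'feature': ..., 'label': ...}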
torch.utils.data.Sampler is also an abstract class.
The default pairing is SequentialSampler + BatchSampler (a quick usage check follows the two classes below).
class Sampler(Generic[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

class BatchSampler(Sampler[List[int]]):
    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # excerpt: this is the drop_last branch of the real implementation;
        # the non-drop_last branch (which also yields the final partial batch) is omitted here
        sampler_iter = iter(self.sampler)
        while True:
            try:
                batch = [next(sampler_iter) for _ in range(self.batch_size)]
                yield batch
            except StopIteration:
                break
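A quick check of that default pairing (a sketch; the data_source is just a list, which is enough because SequentialSampler only needs len()):

from torch.utils.data import SequentialSampler, BatchSampler

data_source = list(range(10))
sampler = SequentialSampler(data_source)
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
print(list(batch_sampler))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]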
collate_fn is the hook that turns a list of raw-typed samples into batch tensors. Its signature looks like:
my_collate_fn(feature_dict_list: List[Dict[str, Union[str, Tensor]]]) -> Dict[str, Tensor]
Q: How do we pass extra arguments?
The signature leaves no room for extra parameters, so what do we do when we want to pass in some configuration (e.g. different feature-processing rules) and keep the collator generic?
A: Pass a callable object: define a MyCollator class whose __init__ takes the configuration and which implements __call__(self, ...) with the same signature as collate_fn.
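A sketch of such a callable collator (the field names and the padding rule are illustrative, not from the original):

import torch
from torch.nn.utils.rnn import pad_sequence

class MyCollator:
    """collate_fn as a callable object so it can carry its own config."""
    def __init__(self, pad_value=0):
        self.pad_value = pad_value

    def __call__(self, feature_dict_list):
        # stack fixed-size fields, pad variable-length ones
        return {
            'label': torch.stack([d['label'] for d in feature_dict_list]),
            'tokens': pad_sequence([d['tokens'] for d in feature_dict_list],
                                   batch_first=True, padding_value=self.pad_value),
        }

# usage: DataLoader(my_dataset, batch_size=32, collate_fn=MyCollator(pad_value=0))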
class DataLoader(Generic[T_co]):
    def __init__(self, dataset: Dataset[T_co],
                 batch_size: Optional[int] = 1,
                 num_workers: int = 0,
                 collate_fn: Optional[_collate_fn_t] = None,
                 worker_init_fn: Optional[_worker_init_fn_t] = None,
                 ):
        ...
        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
                else:
                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)
DataLoader builds on the dataset and adds batching, shuffling and similar capabilities; what it returns are Tensor objects.
DataLoader#__init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, ..., collate_fn, ...)
When no collate_fn is given, torch.utils.data._utils.collate.default_collate is used; it is usually good enough, see the example in the later section.

class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        self.ended = False

    def fetch(self, possibly_batched_index):
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            data = next(self.dataset_iter)
        return self.collate_fn(data)
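For comparison, the map-style counterpart _MapDatasetFetcher is roughly the following (reconstructed from memory, so it may differ slightly from the installed version):

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # possibly_batched_index is a list of indices produced by the BatchSampler
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)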
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        return data
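So in the single-process, map-style case one next() on the loader iterator boils down to roughly the following (a conceptual sketch calling default_collate directly; the toy dataset is made up):

from torch.utils.data import SequentialSampler, BatchSampler
from torch.utils.data._utils.collate import default_collate

dataset = [{'x': i, 'y': 2 * i} for i in range(10)]
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

indices = next(iter(batch_sampler))        # _next_index(): [0, 1, 2, 3]
samples = [dataset[i] for i in indices]    # what the map-style fetcher does
batch = default_collate(samples)           # {'x': tensor([0, 1, 2, 3]), 'y': tensor([0, 2, 4, 6])}
print(batch)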
Next, _BaseDataLoaderIter (the shared base) and _MultiProcessingDataLoaderIter:
class _BaseDataLoaderIter(object):
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        self._worker_init_fn = loader.worker_init_fn
        self._worker_result_queue = multiprocessing_context.Queue()
        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers, self._shared_seed))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
        self._reset(loader, first_iter=True)

    def _next_data(self):
        while True:
            idx, data = self._get_data()
            return self._process_data(data)

    def _get_data(self):
        while True:
            success, data = self._try_get_data()
            if success:
                return data

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data
    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died unexpectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except queue.Empty:
            # excerpt: the original also checks for dead worker processes and
            # re-raises unexpected errors; a plain timeout just reports failure
            return (False, None)
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers, shared_seed):
    try:
        seed = base_seed + worker_id
        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        if init_fn is not None:
            init_fn(worker_id)

        fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)

        watchdog = ManagerWatchdog()
        while watchdog.is_alive():
            r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            idx, index = r
            data = fetcher.fetch(index)
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
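The init_fn above is the worker_init_fn passed to DataLoader; a common use (sketch) is giving each worker process its own numpy seed via torch.utils.data.get_worker_info():

import numpy as np
from torch.utils.data import get_worker_info

def seed_worker(worker_id):
    # the worker's seed is base_seed + worker_id, exposed through WorkerInfo
    info = get_worker_info()
    np.random.seed(info.seed % 2 ** 32)

# usage: DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)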
More on the worker machinery can be found in reference [1].
The typical IterableDataset scenario is data that does not fit in memory: network / database -> dataset -> model feed, running as a stream.
An example follows.
import numpy as np
from torch.utils.data import IterableDataset, DataLoader

class StreamingDataset(IterableDataset):
    def generator(self):
        i = 0
        data = np.arange(0, 10).reshape((5, 2))
        while True:
            if i == len(data):
                break
            yield {'sample_id': i, 'value': data[i]}
            i += 1

    def __iter__(self):
        return iter(self.generator())

def dataset_test():
    it = iter(StreamingDataset())
    print(next(it))
    print(next(it))

def loader_test():
    loader = DataLoader(StreamingDataset(), batch_size=2)
    it = iter(loader)
    print(next(it), next(it))

if __name__ == '__main__':
    loader_test()

"""
dataset_test()
{'sample_id': 0, 'value': array([0, 1])}
{'sample_id': 1, 'value': array([2, 3])}

loader_test()
{'sample_id': tensor([0, 1]), 'value': tensor([[0, 1],
        [2, 3]], dtype=torch.int32)}
{'sample_id': tensor([2, 3]), 'value': tensor([[4, 5],
        [6, 7]], dtype=torch.int32)}
"""
todo