The Dataset class is defined in torch.utils.data, so it is imported with from torch.utils.data import Dataset.
What a Dataset defines: how to read a single sample, plus any preprocessing that may be needed.
For a custom Dataset we need to implement three methods: __init__, __getitem__, and __len__.
Method | Purpose |
---|---|
__init__(self, *loader_args, **loader_kwargs) | Initializes the Dataset |
__getitem__(self, index) | Returns the sample at the given index, e.g. (sample, label) |
__len__(self) | Returns the total number of samples |
Take loading the COVID-19 dataset as an example (LiHongYee, ML Spring 2022 HW1):
import numpy as np
import torch
from torch.utils.data import Dataset

class COVID19Dataset(Dataset):
    def __init__(self,
                 covid_features,
                 covid_labels,
                 select_features=None,
                 select_features_model=None):
        self.covid_features = np.array(covid_features)
        self.covid_labels = covid_labels
        self.select_features = select_features
        self.select_features_model = select_features_model
        # optionally reduce the raw features with a pre-fitted feature-selection model
        if select_features is not None and select_features_model is not None:
            self.covid_features = self.select_features_model.transform(self.covid_features)
        self.covid_features = torch.from_numpy(self.covid_features).float()
        # labels are None for the test set
        if self.covid_labels is not None:
            self.covid_labels = torch.from_numpy(np.array(self.covid_labels)).float()
        self.input_dim = self.covid_features.shape[1]

    def __getitem__(self, index):
        # return a single sample (and its label, if available) by index
        if self.covid_labels is None:
            return self.covid_features[index]
        else:
            return self.covid_features[index], self.covid_labels[index]

    def __len__(self):
        # total number of samples
        return len(self.covid_features)
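A minimal usage sketch (the array shapes are hypothetical; any pre-fitted selector exposing a transform method would do for select_features_model):

```python
import numpy as np

# hypothetical training data: 100 samples with 117 raw features each
train_x = np.random.rand(100, 117)
train_y = np.random.rand(100)

train_set = COVID19Dataset(train_x, train_y)
print(len(train_set))        # 100
x0, y0 = train_set[0]        # one (features, label) pair of float tensors
print(x0.shape, y0.shape)    # torch.Size([117]) torch.Size([])
```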
DataLoader is defined in torch.utils.data, so it is imported with from torch.utils.data import DataLoader.
What a DataLoader does: assemble the individual samples from a Dataset into mini-batches.
To customize the DataLoader class itself you generally need to implement __init__ and __len__; if no extra configuration is needed, simply pass your custom Dataset to DataLoader. Its parameters (quoted from the PyTorch docstring, each followed by a short note) are:
dataset (Dataset): dataset from which to load the data.
Your custom Dataset.
batch_size (int, optional): how many samples per batch to load (default: ``1``).
The size of a mini-batch; batch_size is usually increased somewhat and chosen as a power of 2.
shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch (default: ``False``).
Reshuffle the dataset at every epoch.
sampler (Sampler or Iterable, optional): defines the strategy to draw samples from the dataset. Can be any ``Iterable`` with ``__len__`` implemented. If specified, :attr:`shuffle` must not be specified.
A custom strategy (some ordering) for drawing samples from the Dataset; if this is specified, shuffle must not be set.
Setting shuffle amounts to sampling with the built-in RandomSampler; otherwise SequentialSampler is used.
RandomSampler's __iter__ comes down to one line: yield from torch.randperm(n, generator=self.generator).tolist()
SequentialSampler's __iter__ is return iter(range(len(self.data_source))); both inherit from Sampler[int].
batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
Returns the indices of a whole batch at a time; mutually exclusive with batch_size, shuffle, sampler, and drop_last.
Passing a batch_sampler already tells PyTorch how many samples to draw from the Dataset and how to assemble them into a mini-batch, so the parameters above are no longer needed.
num_workers (int, optional): how many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process. (default: ``0``)
How many worker subprocesses load the data; the default 0 means the data is loaded in the main process.
collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
The collate function post-processes a batch: it defines what to do once a batch of samples has been gathered, and returns the processed batch.
The usual default is _utils.collate.default_collate; its source mostly performs type checks and simply combines the samples, doing no substantial work.
The default collate_fn is declared as def default_collate(batch):, so a custom collate_fn must take a batch as input and return the processed batch as output.
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, see the example below.
Copies tensors into pinned (page-locked) host memory before returning them, which speeds up the subsequent transfer to the GPU.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: ``False``)
Whether to drop the last mini-batch: the number of samples may not be divisible by the batch size, so drop_last decides whether to keep the final, possibly smaller, batch.
timeout (numeric, optional): if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: ``0``)
The timeout for collecting a batch of data from the worker processes.
worker_init_fn (callable, optional): If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: ``None``)
A function used to initialize each worker subprocess.
prefetch_factor (int, optional, keyword-only arg): Number of samples loaded in advance by each worker. ``2`` means there will be a total of 2 * num_workers samples prefetched across all workers. (default: ``2``)
Controls how many samples each worker prefetches ahead of time; default 2.
persistent_workers (bool, optional): If ``True``, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers `Dataset` instances alive. (default: ``False``)
Controls whether worker processes are kept alive after the Dataset has been consumed once; default False.
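Putting the common parameters together, a typical construction looks like the sketch below (reusing the train_set built earlier; the concrete batch_size / num_workers values are illustrative, not prescriptive):

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set,
                          batch_size=256,    # mini-batch size, a power of 2
                          shuffle=True,      # reshuffle at every epoch
                          num_workers=2,     # 2 worker subprocesses
                          pin_memory=True,   # pinned memory for faster GPU transfer
                          drop_last=False)   # keep the final smaller batch

for features, labels in train_loader:
    # features: [batch_size, input_dim], labels: [batch_size]
    pass
```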
In DataLoader's __init__ we can see the default-argument logic it implements:
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset, generator=generator)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler
self.generator = generator
if collate_fn is None:
if self._auto_collation:
collate_fn = _utils.collate.default_collate
else:
collate_fn = _utils.collate.default_convert
Here self._dataset_kind == _DatasetKind.Iterable holds only when the Dataset is an IterableDataset:
if isinstance(dataset, IterableDataset):
self._dataset_kind = _DatasetKind.Iterable
IterableDataset is for datasets that are too large to load entirely into memory (for example several TB of data). In that case you build an iterable-style Dataset: the custom class inherits from torch.utils.data.IterableDataset and overrides __iter__ to return an iterable (usually a generator via yield).
So an IterableDataset has no need for a Sampler: samples are read out one by one through __iter__, and the wrapping DataLoader simply draws batch_size items from that iterator to form a mini-batch.
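A minimal sketch of such an iterable-style Dataset, streaming samples line by line from a large file (the file path and the comma-separated format are hypothetical):

```python
import torch
from torch.utils.data import IterableDataset, DataLoader

class BigFileDataset(IterableDataset):
    def __init__(self, path):
        super().__init__()
        self.path = path

    def __iter__(self):
        # stream the file instead of loading it into memory all at once
        with open(self.path) as f:
            for line in f:
                values = [float(v) for v in line.strip().split(',')]
                # yield one (features, label) sample at a time
                yield torch.tensor(values[:-1]), torch.tensor(values[-1])

# DataLoader still groups the streamed samples into mini-batches of batch_size
loader = DataLoader(BigFileDataset('huge_data.csv'), batch_size=32)
```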
The _InfiniteConstantSampler paired with IterableDataset is:
class _InfiniteConstantSampler(Sampler):
r"""Analogous to ``itertools.repeat(None, None)``.
Used as sampler for :class:`~torch.utils.data.IterableDataset`.
Args:
data_source (Dataset): dataset to sample from
"""
def __init__(self):
super(_InfiniteConstantSampler, self).__init__(None)
def __iter__(self):
while True:
yield None
As we can see, its __iter__ is a generator that yields None forever.
So, for an ordinary (map-style) custom Dataset, RandomSampler is used when shuffle is True, otherwise SequentialSampler:
class RandomSampler(Sampler[int]):
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify :attr:`num_samples` to draw.
Args:
data_source (Dataset): dataset to sample from
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
is supposed to be specified only when `replacement` is ``True``.
generator (Generator): Generator used in sampling.
"""
data_source: Sized
replacement: bool
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
if not isinstance(self.replacement, bool):
raise TypeError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
if self._num_samples is not None and not replacement:
raise ValueError("With replacement=False, num_samples should not be specified, "
"since a random permute will be performed.")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(self.num_samples))
@property
def num_samples(self) -> int:
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
yield from torch.randperm(n, generator=generator).tolist()
def __len__(self) -> int:
return self.num_samples
We mainly care about the __iter__ method, where we can see:
n is the dataset size.
If replacement is set to True, num_samples must also be specified: the number of indices the sampler should return.
The source draws random integers in [0, n-1] with torch.randint in chunks of 32 indices, num_samples // 32 full chunks in total, plus a final chunk of num_samples % 32 indices, so the total number of sampled indices is exactly num_samples.
With the default replacement=False, torch.randperm(n) returns a random permutation of 0..n-1, i.e. n indices.
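A quick check of both branches (the indices are random, so the exact output will differ from run to run):

```python
from torch.utils.data import RandomSampler

# replacement=False (default): a random permutation of all n indices
print(list(RandomSampler(range(5))))
# e.g. [3, 0, 4, 1, 2]

# replacement=True: indices may repeat, num_samples can be chosen freely
print(list(RandomSampler(range(5), replacement=True, num_samples=8)))
# e.g. [2, 2, 0, 4, 1, 3, 4, 0]
```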
class SequentialSampler(Sampler[int]):
r"""Samples elements sequentially, always in the same order.
Args:
data_source (Dataset): dataset to sample from
"""
data_source: Sized
def __init__(self, data_source: Sized) -> None:
self.data_source = data_source
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
SequentialSampler's __iter__ returns a sequential iterator: each call to __iter__ yields the indices in their original order.
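For example:

```python
from torch.utils.data import SequentialSampler

print(list(SequentialSampler(range(5))))   # [0, 1, 2, 3, 4]
```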
class BatchSampler(Sampler[List[int]]):
r"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler or Iterable): Base sampler. Can be any iterable object
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self) -> Iterator[List[int]]:
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self) -> int:
# Can only be called if self.sampler has __len__ implemented
# We cannot enforce this condition, so we turn off typechecking for the
# implementation below.
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
if self.drop_last:
return len(self.sampler) // self.batch_size # type: ignore[arg-type]
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type]
BatchSampler wraps another Sampler and groups the indices that sampler produces into mini-batch index lists.
Focusing on BatchSampler's __iter__: the for loop iterates over the wrapped sampler, appending each index to the batch list; once the list reaches batch_size, the batch is yielded and the list is reset, after which indices keep being drawn from the sampler. If drop_last is False and the final batch is non-empty, the last batch of indices, smaller than batch_size, is yielded as well.
__len__ returns the total number of batches, i.e. how many batches all the samples are split into.
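Combining the samplers reproduces what DataLoader builds internally when shuffle=True and batch_size is set (again, the exact indices are random):

```python
from torch.utils.data import RandomSampler, BatchSampler

batch_sampler = BatchSampler(RandomSampler(range(10)), batch_size=4, drop_last=False)
print(list(batch_sampler))
# e.g. [[7, 2, 9, 0], [5, 1, 3, 8], [6, 4]]  -- last batch smaller, kept because drop_last=False
print(len(batch_sampler))   # 3
```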
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, collections.abc.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError('each element in list of batch should be of equal size')
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))
default_collate mostly performs sanity checks; in essence it converts the relevant data to tensors and assembles the values returned by the Dataset's __getitem__. For example, [(img0, label0), (img1, label1), (img2, label2), ...] is reorganized into [[img0, img1, img2, ...], [label0, label1, label2, ...]], which requires the imgs to share the same size. (The isinstance(elem, collections.abc.Sequence) branch is also why we get list data when iterating over a DataLoader.)
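By contrast, when per-sample sizes differ (variable-length sequences, for instance), default_collate cannot stack them and a custom collate_fn is needed. A minimal sketch that pads sequences within each batch (the toy dataset here is hypothetical):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch: list of (sequence, label) pairs returned by Dataset.__getitem__
    seqs, labels = zip(*batch)
    padded = pad_sequence(seqs, batch_first=True)    # pad to the longest sequence in this batch
    lengths = torch.tensor([len(s) for s in seqs])   # keep the original lengths
    return padded, torch.tensor(labels), lengths

# toy map-style dataset: variable-length 1-D tensors with integer labels
samples = [(torch.arange(n, dtype=torch.float32), n % 2) for n in (3, 5, 2, 4)]
loader = DataLoader(samples, batch_size=2, collate_fn=pad_collate)

for padded, labels, lengths in loader:
    print(padded.shape, labels, lengths)
    # torch.Size([2, 5]) tensor([1, 1]) tensor([3, 5])
    # torch.Size([2, 4]) tensor([0, 0]) tensor([2, 4])
```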
collate_fn post-processes one batch of data: using the mini-batch indices given by BatchSampler, a batch of samples is fetched through the Dataset's __getitem__(self, index) and then passed to collate_fn for processing. To pin down how collate_fn is invoked, let's first look at the DataLoader source's __iter__ method to see how it fetches data.
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'.
def __iter__(self) -> '_BaseDataLoaderIter':
# When using a single worker the returned iterator should be
# created everytime to avoid reseting its state
# However, in the case of a multiple workers iterator
# the iterator is only created once in the lifetime of the
# DataLoader object so that workers can be reused
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
As we can see, DataLoader's __iter__ calls _get_iterator to create the iterator, so we read _get_iterator next:
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
Let's focus on the default case, num_workers=0: _get_iterator returns a _SingleProcessDataLoaderIter instance, which inherits from the base class _BaseDataLoaderIter. The base class implements __iter__ and __next__, which is what lets us iterate over this iterator and pull data from it.
As a reminder of Python's iteration protocol: an iterable implements __iter__ and supports repeated traversal, but next(iterable) is not allowed; an iterator does not support repeated traversal, is obtained via iter(iterable), and, since it implements __next__, next(iterator) returns its next value.
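A quick illustration of this protocol in plain Python:

```python
nums = [1, 2, 3]            # iterable: implements __iter__, can be traversed repeatedly
it = iter(nums)             # ask the iterable for an iterator
print(next(it), next(it))   # 1 2  -- the iterator implements __next__
print(next(it))             # 3
# one more next(it) raises StopIteration; next(nums) raises TypeError
```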
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
self._prefetch_factor = loader.prefetch_factor
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
self._persistent_workers = loader.persistent_workers
self._num_yielded = 0
self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
def __iter__(self) -> '_BaseDataLoaderIter':
return self
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
self._IterableDataset_len_called = loader._IterableDataset_len_called
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def _next_data(self):
raise NotImplementedError
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data()
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
"samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
self._num_yielded)
if self._num_workers > 0:
warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
"IterableDataset replica at each worker. Please see "
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
warnings.warn(warn_msg)
return data
next = __next__ # Python 2 compatibility
def __len__(self) -> int:
return len(self._index_sampler)
def __getstate__(self):
# TODO: add limited pickling support for sharing an iterator
# across multiple threads for HOGWILD.
# Probably the best way to do this is by moving the sample pushing
# to a separate thread and then just sharing the data queue
# but signalling the end is tricky without a non-blocking API
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
Inside _SingleProcessDataLoaderIter we can see that this iterator in turn instantiates a data fetcher, _dataset_fetcher, and that this fetcher receives the collate_fn argument.
_SingleProcessDataLoaderIter implements _next_data: it first calls _next_index to obtain the next batch of indices. _next_index lives in the base class _BaseDataLoaderIter and calls next() on the iterator of _index_sampler; in the default case (auto_collation is True, because DataLoader creates a BatchSampler whenever batch_size is not None and sets _index_sampler to that BatchSampler), this is the BatchSampler iterator described earlier, which yields the next batch of indices.
In the base class, __next__ calls _next_data to obtain the next batch of data.
_SingleProcessDataLoaderIter's _next_data then calls the data fetcher's fetch method:
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
In fetch, if auto_collation is True (batch_size was set and a BatchSampler was created automatically), the data is assembled through the Dataset's __getitem__ for each index in the incoming batch of indices, producing a list; otherwise (auto_collation is False, in which case the _index_sampler above is the plain Sampler), a single index from the sampler's iterator is used to take one item from the dataset.
Here, at last, we see collate_fn being called, which confirms that collate_fn really does post-process the fetched batch.
The input to collate_fn is a list when auto_collation is True, where each element is a value returned by the Dataset's __getitem__; when auto_collation is False, it is the single-sample value returned by __getitem__ directly. The whole pipeline can be emulated by hand, as in the sketch below.
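This is only an illustrative sketch (reusing the train_set built earlier) of what DataLoader wires together for us in the single-process case: BatchSampler produces the indices, the fetcher indexes the Dataset, and collate_fn assembles the batch.

```python
from torch.utils.data import RandomSampler, BatchSampler
from torch.utils.data._utils.collate import default_collate

sampler = RandomSampler(train_set)                       # what shuffle=True sets up
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)

for indices in batch_sampler:                            # _next_index()
    samples = [train_set[i] for i in indices]            # _MapDatasetFetcher.fetch
    features, labels = default_collate(samples)          # collate_fn
    print(features.shape, labels.shape)                  # [4, input_dim], [4]
    break
```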
References:
https://blog.csdn.net/weixin_35757704/article/details/119715900
https://www.daimajiaoliu.com/daima/4ede05ecd1003fc
https://blog.csdn.net/mieleizhi0522/article/details/82142856/