torch.utils.data中Dataset TensorDataset以及Dataloader

在打包自己处理的数据时有两种方法:
1.写个数据集的类(myDataset ),并继承Dataset
在myDataset 类中实现__len__()和__getitem__两个函数, __len__返回数据集的总长,__getitem__返回每次按照 索引所取的项,即x, y
比如:在处理序列问题时:
__len__返回的是:all_len/seq_len
__getitem__返回的是:一个输入序列,一个输出序列,即:x_seq, y_seq

    if index + input_seq_len + prediction_seq_len +1 < all_Data_Len:
        train = allData[j:j+input_seq_len,:,:]
        ...
        ...
    return train, label1, label2, ...

  1. 将按顺序处理好的数据集转换为Tensor类型(Torch.Tensor Torch.from_numpy()),并放到TensorDataset中
    注意:TensorDataset是按照第一维度进行索引的,故放进去的数据第一维度必须相同。例如,train_x和train_y的长度是必须相同的。
    在处理序列时,必须把seq_len长度的数据当成一个数据,train_x= allData[j:j+input_seq_len,:,:]相当于一个x,同理train_y也是如此。即trainDataX=[train_x0, train_x1, train_x2, train_x3, …],即至少是二维的。对应的train_y也是如此。
    然后
    myDataset = TensorDataset(trainDataX,trainDataY)

综上任意一种处理完毕后将处理后的数据集放入DataLoader,就可以在训练的时候直接用了

myloader = DataLoader(dataset=myDataset , batch_size=1, shuffle=False)

训练中:

for i, data in enumerate(train_loader):

torch.autograd为tensor的所有操作自动求导(Variable类是核心),所有Tensor必须转换为Variable

Dataset Sampler Dataloader

PyTorch数据加载模块一共涉及到Dataset,Sampler,Dataloader三个类

  1. Dataset负责对raw data source封装,将其封装成Python可识别的数据结构,其必须提供提取数据个体的接口。Dataset共有Map-style datasets和Iterable-style datasets两种:
    1.1 map-style dataset:实现了__getitem__和__len__接口,表示一个从索引/key到样本数据的map。比如:datasets[10],就表示第10个样本。
    1.2 iterable-style dataset:实现了__iter__接口,表示在data samples上的一个Iterable(可迭代对象),这种形式的dataset非常不适合随机存取(代价太高),但非常适合处理流数据。比如:iter(datasets)获得迭代器,然后不断使用next迭代从而实现遍历。

  2. Sampler负责提供一种遍历数据集所有元素索引的方式。

  3. Dataloader负责加载数据,同时支持map-style和iterable-style Dataset,支持单进程/多进程,还可以设置loading order, batch size, pin memory等加载参数。

总结一下步骤:

  1. 设置Dataset,将数据data source包装成Dataset类,暴露提取接口。
  2. 设置Sampler,决定采样方式。我们是能从Dataset中提取元素了,还是需要设置Sampler告诉程序提取Dataset的策略。
  3. 将设置好的Dataset和Sampler传入DataLoader,同时可以设置shuffle,batch_size等参数。使用DataLoader对象可以快捷方便地在给定数据集上遍历。

归纳一下:即Dataloader负责总的调度,命令Sampler定义遍历索引的方式,然后用索引去Dataset中提取元素。于是就实现了对给定数据集的遍历。

1.1 Dataset

所有设计的Dataset类必须继承torch.utils.data.Dataset这个类。

  1. 这些子类必须要实现方法__getitem__(),来支持可以给定一个key(即索引)来获取对应的数据样本
  2. 这些类可以实现方法__len__(),来返回数据集的大小规模
    Dataset实现非常简洁,就只是提供了__getitem__ 和 __add__这两个接口。

前者很重要,是Dataset及其子类的核心,定义了数据元素提取(即通过索引获取样本,实际代码中常使用[]输入索引)

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
    def __add__(self, other):
        return ConcatDataset([self, other])

具体实践中,我们需要使用Dataset的子类,自己实现的或者现成的。

我们可以来看看PyTorch为我们提供的现成的Dataset子类

  • TensorDataset
  • IterableDataset
  • ConcatDataset
  • ChainDataset
  • Subset

下面着重介绍
TensorDatasetIterableDataset.

*CLASS torch.utils.data.TensorDataset(tensors)

包装了Tensor的Dataset子类,map-style dataset
每个样本可以通过tensors第一个维度的索引获取

class TensorDataset(Dataset):
    r"""
    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)
    def __len__(self):
        return self.tensors[0].size(0)

如上源码:
__init__的形参是*tensors,因此是可以传入多个tensor变量的,但需要保证每个tensor的第一个维度均是一样的。
例子

正确输入:(100*64*64*3,100*32*32*3100*16*16*31
错误输入:(100*64*64*3,200*32*32*3100*16*16*31

__getitem__提取的就是*tensors中每个张量的第index个样本(因为每个张量第一维度都是一样的)

__len__即*tensors每个张量第一个维度长度

常见用法:*tensors指定我们可以输入多个张量,我们可以同时输入train_data和train_label

dataset = TensorDataset(train_data, train_label)

CLASS torch.utils.data.IterableDataset
内部样本的组织形式是Iterable的所有dataset类都是IterableDataset类的子类,
即:所有iterable-style dataset都是IterableDataset的子类
这种形式的dataset对于处理流数据是非常有用的。
所有这些子类需要实现__iter__方法(而不是__getitem__方法了),需要据此来返回样本的迭代器,从而遍历dataset(实际代码中常使用iter+next来遍历)
关于Python中Iterable和Iterator的介绍见我的另一篇文章:刘昕宸:彻底搞懂Python的__iter__和__next__,Iterable和Iteration

class IterableDataset(Dataset[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

关于多进程的问题:
IterableDataset的某个子类被DataLoader使用时,dataset中的每个item可以通过DataLoader的Iterator迭代获取。
当num_works>0时就是多进程模式,每个工作进程都有一个不同的dataset对象的拷贝,因此我们需要独立安排每一份拷贝该如何处理(后面会有例子),以防止不同的进程会返回重复的元素。(有MPI编程经验的同学应该更能理解!)
可以通过get_worker_info方法,在某一当前进程中调用,获得当前进程信息。这个方法要么在dataset类的__iter__方法中使用,要么在DataLoader的worker_init_fn方法中设置并使用。
举2个例子(来自官网文档):

例1:在dataset类的__iter__方法中使用get_worker_info方法,划分工作空间,获得当前进程id,并根据进程id分配其需要处理的工作空间

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end
    
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # 单进程:一个进程处理全部样本
            iter_start = self.start
            iter_end = self.end
        else: # 多进程,在当前进程中
            # 划分工作空间
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))

具体使用:

>>> # 给定样本集range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # 单进程加载
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # 2个进程加载
>>> # 进程0负责[3, 4].  进程1负责[5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # 更多的进程
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

例2:先来个反例:不手动配置每个进程的工作空间的话,默认每个进程的工作空间是整个dataset,因此每个进程都会遍历一次整个数据集,导致产生重复数据。

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # 给定样本集range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # 单进程加载
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # 直接多进程加载会产生重复数据
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

除了上面在MyIterableDataset的__iter__方法中依靠get_work_info分配工作空间,还可以事先定义函数worker_init_fn分配工作空间(分配策略与例1完全一致),再将该函数传给dataloader生效:

>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # 多进程加载,使用自定义的`worker_init_fn`
>>> # 进程0负责[3, 4].  进程1负责[5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # 更多的进程
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

你以为这就完了吗???当然不!!!
贴心的PyTorch小可爱还为我们提供了计算机视觉常用的数据集,并将它们包装成了Dataset!!!

这些数据集都在torchvision.datasets下,共有这么多:

torch.utils.data中Dataset TensorDataset以及Dataloader_第1张图片

我们以CIFAR-10数据集为例来看一看:

CLASS torchvision.datasets.CIFAR10

使用举例:

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)

Sampler

CLASS torch.utils.data.Sampler(data_source: Optional[collections.abc.Sized])
所有Samplers的基类

Sampler的所有子类都需要实现__iter__,用来提供遍历dataset索引的方式。我们获得不同的索引遍历,就能以不同的方式遍历dataset,这就是samplers的目的。
PyTorch为我们提供了几种现成的Sampler子类:

SequentialSampler
RandomSampler
SubsetRandomSampler
WeightedRandomSampler
BatchSampler
DistributedSampler
下面我着重介绍一下SequentialSampler,RandomSampler和BatchSampler

CLASS SequentialSampler(Sampler[int])

SequentialSampler指定总是按照相同的次序,顺序地采样元素

关注方法__iter__,直接range生成顺序的索引,也就是为dataloader提供了顺序遍历dataset的方式。

class SequentialSampler(Sampler[int]):
    r"""
    Arguments:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

CLASS torch.utils.data.RandomSampler

RandomSampler提供了随机采样元素的方式。

如果replacement==False,则随机采样整个数据集,即num_samples==len(dataset)。此时sampler提供给dataloader以一种随机的次序遍历dataset.

如果replacement==True,则从数据集中随机采样num_samples个样本

仅贴出__iter__实现:

@property
def num_samples(self) -> int:
    # dataset size might change at runtime
    if self._num_samples is None:
        return len(self.data_source)
    return self._num_samples

def __iter__(self):
    n = len(self.data_source)
    if self.generator is None:
        generator = torch.Generator()
        generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
    else:
        generator = self.generator
    if self.replacement:
        for _ in range(self.num_samples // 32):
            yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
        yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
    else:
        yield from torch.randperm(n, generator=self.generator).tolist()

def __len__(self):
    return self.num_samples

CLASS torch.utils.data.BatchSampler

BatchSampler包装另一个sampler(输入参数),用来产生一个mini-batch大小的索引,相当于是为dataloader提供了提取dataset的1个mini-batch样本的索引。

关注__iter__和__len__方法:

class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore

DataLoader

铺垫了这么多,终于讲到DataLoader了。

在训练/测试深度学习网络的程序中,我们直接遍历Dataloader来获取数据(data,label等),并将数据feed给网络用于前向传播和反向传播。

代码形如:

for data, label in train_loader:
    data, label = data.to(device), label.to(device).squeeze()
    opt.zero_grad()
    logits = model(data)
    loss = criterion(logits, label)

那么在for data, label in train_loader这个过程中究竟发生了什么呢?一起探索!

for循环会调用dataloader iter

以此获得迭代器来遍历dataset

def __iter__(self) -> '_BaseDataLoaderIter':
    # When using a single worker the returned iterator should be
    # created everytime to avoid reseting its state
    # However, in the case of a multiple workers iterator
    # the iterator is only created once in the lifetime of the
    # DataLoader object so that workers can be reused
    if self.persistent_workers and self.num_workers > 0:
        if self._iterator is None:
            self._iterator = self._get_iterator()
        else:
            self._iterator._reset(self)
        return self._iterator
    else:
        return self._get_iterator()

其中调用了self._get_iterator()获得迭代器:

def _get_iterator(self) -> '_BaseDataLoaderIter':
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)

为了简单起见,我们只考虑单进程的代码,那我们看一下_SingleProcessDataLoaderIter的实现:

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

_SingleProcessDataLoaderIter继承自_BaseDataLoaderIter,因此_BaseDataLoaderIter的代码也需要看一下:

class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self) -> Any:
        if self._sampler_iter is None:
            self._reset()
        data = self._next_data()
        self._num_yielded += 1
        if self._dataset_kind == _DatasetKind.Iterable and \
                self._IterableDataset_len_called is not None and \
                self._num_yielded > self._IterableDataset_len_called:
            warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                        "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                              self._num_yielded)
            if self._num_workers > 0:
                warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                             "IterableDataset replica at each worker. Please see "
                             "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
            warnings.warn(warn_msg)
        return data

    next = __next__  # Python 2 compatibility

    def __len__(self) -> int:
        return len(self._index_sampler)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

dataloader获得了迭代器之后,我们的for循环需要调用__next__来获得下一个对象,从而实现遍历。

我们看一下_BaseDataLoaderIter的__next__:

def __next__(self) -> Any:
    if self._sampler_iter is None:
        self._reset()
    data = self._next_data()
    self._num_yielded += 1
    if self._dataset_kind == _DatasetKind.Iterable and \
            self._IterableDataset_len_called is not None and \
            self._num_yielded > self._IterableDataset_len_called:
        warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                    "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                            self._num_yielded)
        if self._num_workers > 0:
            warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                            "IterableDataset replica at each worker. Please see "
                            "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
        warnings.warn(warn_msg)
    return data

__next__需要调用_next_data,因此我们还需要看一下_SingleProcessDataLoaderIter的_next_data:

_next_data需要_next_index获得索引,并通过索引fetch到对应的样本。

def _next_data(self):
    index = self._next_index()  # may raise StopIteration
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    if self._pin_memory:
        data = _utils.pin_memory.pin_memory(data)
    return data

关于_next_index:

_sampler_iter来自_index_sampler,来自loader

def _next_index(self):
    return next(self._sampler_iter)  # may raise StopIteration

再看dataloader中的_index_sampler,一切就明白了:

@property
def _index_sampler(self):
    # The actual sampler used for generating indices for `_DatasetFetcher`
    # (see _utils/fetch.py) to read data at each time. This would be
    # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
    # We can't change `.sampler` and `.batch_sampler` attributes for BC
    # reasons.
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler

总结来说就是dataloader提供了sampler,然后_SingleProcessDataLoaderIter迭代sampler获得索引。

下面我们来看看Fetch:

pytorch在Dataset上又封装了一层Fetcher。

这样做是使得iterable Dataset(对应_IterableDatasetFetcher)和map Dataset(对应_MapDatasetFetcher)在Dataloader内能使用相同的接口fetch,代码更加简洁。

fetcher需要index获取数据元素。

针对map-style fetcher:
关注fetch方法:直接输入索引index,作为map的key,获得对应的样本(即value)

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

注意:这里的index可能不只是一个索引,而是一个batch的索引。

这取决于_auto_collation,_auto_collation的取值在Dataloader中定义:

有batch_sampler,_auto_collation就为True,就优先使用batch_sampler,对应在fetcher中传入的就是一个batch的索引。

@property
def _auto_collation(self):
    return self.batch_sampler is not None

针对iterable-style fetcher:
__init__方法内设置了dataset初始的迭代器
fetch方法内获取元素,index其实已经没有多大作用了。
对于batch_sampler(即auto_collation==True):直接使用往后遍历并提取len(possibly_batched_index)个样本(即1个batch的样本)
对于sampler:直接往后遍历并提取1个样本

class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            data = next(self.dataset_iter)
        return self.collate_fn(data)

另外对python中Iterable,Iterator,iter,__next__等的详细解释,参见我另一篇文章:彻底搞懂Python的__iter__和__next__,Iterable和Iteration

最后,我们通过索引传入fetcher,fetch得到想要的样本!

我们的目标终于实现了!!!

整个过程调用关系总结:

loader.iter–> _get_iterator --> _SingleProcessDataLoaderIter --> _BaseDataLoaderIter --> next --> _next_data–> self._dataset_fetcher.fetch(index) --> _next_index -->_sampler_iter --> loader._index_sampler

但愿这么细致的讲解,能真正搞清楚Dataset,Sampler,DataLoader三者的机理及其运行关系。

总结:
Dataset封装数据集(可通过索引获取元素)

Sampler提供索引次序(可迭代,用于遍历)

DataLoader是一个调度器,迭代DataLoaderIter的过程中,迭代Sampler获得下一索引,并通过该索引使用fetcher(fetcher是对dataset的封装,使得dataloader代码与iterable-style/map-style dataset解耦)获得对应元素。

2.1 实战建议
Dataset
通常使用TensorDataset,或者我们自行实现一个其继承类。

Sampler
我们一般不用管,直接使用DataLoader默认指定的就行:

if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            # Cannot statically verify that dataset is Sized
            # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
            sampler = RandomSampler(dataset, generator=generator)  # type: ignore
        else:
            sampler = SequentialSampler(dataset)

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

如上来自DataLoader的代码,

如果是iterable-style dataset,默认使用_InfiniteConstantSampler:
其实这个_InfiniteConstantSampler啥也没干,因为我们遍历iterable-style dataset依靠的是迭代器,根本就不需要索引!(上面介绍的_IterableDatasetFetcher已经说明了这一点!)

class _InfiniteConstantSampler(Sampler):
    r"""
    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self):
        super(_InfiniteConstantSampler, self).__init__(None)

    def __iter__(self):
        while True:
            yield None

如果是map-style dataset,有shuffle则默认使用RandomSampler;没有shuffle则默认使用SequentialSampler
batch_sampler就是对上面已经生成的sampler,进一步包装。
DataLoader
直接用就完事了!

2.2 具体例子
这是来自于DGCNN的PyTorch版本官方实现:WangYueFt/dgcnn

DGCNN是非常著名的点云特征学习网络,感兴趣的朋友可以参考我这一篇文章的解读:搞懂DGCNN,这篇就够了!论文及代码完全解析
自己实现Dataset,用于装载ModelNet40数据集:

torch.utils.data中Dataset TensorDataset以及Dataloader_第2张图片

class ModelNet40(Dataset):
    def __init__(self, num_points, partition='train'):
        self.data, self.label = load_data(partition)
        self.num_points = num_points
        self.partition = partition        

    def __getitem__(self, item):
        pointcloud = self.data[item][:self.num_points]
        label = self.label[item]
        if self.partition == 'train':
            pointcloud = translate_pointcloud(pointcloud)
            np.random.shuffle(pointcloud)
        return pointcloud, label

    def __len__(self):
        return self.data.shape[0]

将ModelNet40装载至DataLoader:

Sampler使用默认的,因为shuffle==True,因此使用的应该是RandomSampler:

train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=8,
                          batch_size=args.batch_size, shuffle=True, drop_last=True)

使用DataLoader:

直接for循环遍历就完事了:使用DataLoader:

直接for循环遍历就完事了:

for data, label in train_loader:
    data, label = data.to(device), label.to(device).squeeze()
    data = data.permute(0, 2, 1)
    batch_size = data.size()[0]
    opt.zero_grad()
    logits = model(data)
    loss = criterion(logits, label)
    loss.backward()
    opt.step()
    preds = logits.max(dim=1)[1]
    count += batch_size
    train_loss += loss.item() * batch_size
    train_true.append(label.cpu().numpy())
    train_pred.append(preds.detach().cpu().numpy())

3 参考资料
torch.utils.data - PyTorch 1.7.0 documentation

你可能感兴趣的:(pytorch学习)