Pytorch学习小记1:torch.utils.data.Dataset类和datasat.py文件的初读笔记

昨天在使用torch.utils.data.DataLoader类时遇到了一些问题,通过粗略的学习了解到了Pytorch的数据处理主要基于三个类:Dataset,DatasetLoader和DatasetLoaderIter,并且它们依次构成封装关系。关于他们之间的关系和解读,这篇文章总结得比较好理解:https://zhuanlan.zhihu.com/p/30934236

于是追根溯源找到了定义Dataset类的文档dataset.py学习了一下,顺便恶补一下我一个月突击学习的python编程基础...

1. 初读dataset.py

通读了一下dataset.py文件里的注释,了解到它定义了Dataset类,一个函数random_split()以及它的四个子类IterableDataset类、TensorDataset类、ConcatDataset类和Subset类;其中IterableDataset类又衍生出一个子类ChainDataset类。这里仅对他们做一个简要的介绍,主要理解各个类的作用是什么,以及注释表明的注意事项,具体细节日后再做学习。

1.1 Dataset类

Dataset类的描述如下:

r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

Dataset类是一个描述数据集的抽象类,可供我们定义自己的数据集。只要是以“标签-数据”保存的数据集,要想利用pytorch框架进行处理,都需要继承Dataset类来定义,原文的描述是:所有以键和数据样本构成映射的数据集都需要继承Dataset类。

要想通过继承定义自己的数据集,需要改写数据集中的成员函数__getitem__()以构建键->样本的索引形式。同样可以改写成员函数__len__()用来实现返回自己定义的数据集的大小

1.2 IterableDataset类

IterableDataset类是一个可迭代的Dataset类,定义中与Dataset不同的是它的__add__()返回值返回了一个ChainDataset()数据集,而Dataset返回的是ContactDataset().当数据集的数据样本需要做迭代处理的时候,需要继承IterableDataset类,尤其是对数据流的处理非常有用。

当基于IterableDataset定义的数据集被Dataloader处理的时候,Dataloader会为它的成员产生一个迭代器。(这也解开了之前我对TensorDataset被Dataloader处理之后进行迭代时报错的疑惑,由于TensorDataset直接继承Dataset类,不可迭代,所以对Dataloader迭代操作的时候会报错TypeError: 'DataLoader' object is not an iterator)

此外,注释还对多线程处理数据进行了补充说明,给了两个例子,还是直接上代码比较清楚:

Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end >= start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]

        >>> # Mult-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
        [3, 4, 5, 6]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end >= start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Mult-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]

这段程序定义可迭代类MyIterableDataset用于实现输入起止值,保存起止值之间的所有自然数(包括开始数但不包括结束的数)。如果不像Example1 一样在成员函数__iter__中定义多线程处理的方案的话,直接丢给Dataloader处理的时候,如果定义多线程,会把全部数据存入各个指定的线程处理,那么结果就会重复输出(见Example2 中定义worker_init_fn()函数之前的代码)。

1.3 TensorDataset类

TensorDataset类专门用来存储张量,它可以沿第一维度的数据索引张量的每个样本,所以我们载入的张量的第一维数据的大小就必须要相同才行。

这也解释了为什么之前学习过程中载入数据时,要先对数据做一个torch.unsqueeze()处理了。当数据只有一维时,是没法通过一维数据索引的(因为仅有的一个维度就是样本本身),程序就会报错。

1.4 ConcatDataset类

ConcatDataset类是一个串联数据集,可以用来组合不同的现有数据集。具体实现办法可以从构造函数中看出:如果传入的数据集非空->那么就把数据集封装进一个列表->传入成员self.dataset->确保成员不可迭代之后->更新数据集大小。从构造函数中也可以看出,这个类同样不支持迭代。(同样解释了IterableDataset类和Dataset类的__add__()函数调用的类不一样的原因)

看代码就可以大概推测出这个类主要就是用于拼接不同数据集的。

1.5 ChainDataset类

理解了ConcatDataset,同样可以类比出ChainDataset的大概含义。ChainDataset就是用来链接多个可迭代数据集(IterableDataset类及其派生类)的。值得一提的是,链接操作是即时完成的,所以适用于处理大规模的数据流。构造时需要输入待链接的数据集,很好理解。

1.6 Subset类

Subset类用于存储数据集和索引。构造参数为数据集(母集)及目标索引,该类的成员函数__getitem__()可以根据输入索引返回索引对应的数据集。

1.7 random_split()函数

用于随机将数据集拆分为给定长度的非重叠数据集。通过使用函数加深一下理解:

# 制作伪数据用于实验
x = torch.unsqueeze(torch.linspace(-1, 1, 10), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))
print(x)
print(y)

# 载入torchDataset
torch_dataset = Data.TensorDataset(x, y)

# 制作迭代用的列表,含义是:每次随机抽取两个样本长度的数据组成队列
list1 = []
for i in range(torch_dataset.__len__()//2):
    list1.append(2)

# 随即拆分并打印结果
subset = Data.random_split(torch_dataset, list1)
for i in range(subset.__len__()):
    print(subset.__getitem__(i)[1])

输出结果是:

tensor([[-1.0000],
        [-0.7778],
        [-0.5556],
        [-0.3333],
        [-0.1111],
        [ 0.1111],
        [ 0.3333],
        [ 0.5556],
        [ 0.7778],
        [ 1.0000]])
tensor([[1.0000],
        [0.6049],
        [0.3086],
        [0.1111],
        [0.0123],
        [0.0123],
        [0.1111],
        [0.3086],
        [0.6049],
        [1.0000]])
(tensor([-0.1111]), tensor([0.0123]))
(tensor([-0.5556]), tensor([0.3086]))
(tensor([0.5556]), tensor([0.3086]))
(tensor([0.7778]), tensor([0.6049]))
(tensor([-0.7778]), tensor([0.6049]))

输入数据中,x是-1到1之间均等分割出来的10个数,y是x对应数的平方,程序把x和y组合到了TensorDataset里,并进行随机分割,分割成了五组数据集,并且能看出来是随机分割的。

需要补充说明的是,random_split()函数对输入数据要求非常苛刻,分割索引必须以可迭代对象表示,并且对象所有的成员之和必须等于输入数据集的和,否则会报错:

ValueError: Su(subset[1])m of input lengths does not equal the length of the input dataset!

2. torch.utils.data.Dataset类学习

通读一遍dataset.py之后对Dataset类的作用和基本用法有了大概的了解,接下来对源码进行学习并做总结。

2.1 基本用法

这部分看一下pytorch文档就非常清楚了:https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/

Dataset参数:

  • data_tensor (Tensor) - 包含样本数据
  • target_tensor (Tensor) - 包含样本目标(标签)

此外还可以基于Dataset类自定义自己的数据集,具体的方法有博文描述得非常清楚了,这里不再赘述:https://blog.csdn.net/u012436149/article/details/69061711

2.2 代码学习

Dataset类得定义代码非常简短:

class Dataset(object):


    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

首先它是一个类,关于定义的时候为什么需要继承“object”我还专门去了解了一下,据说是因为python2版本的遗留问题:当一个类B继承了母类A,并且又派生出了子类C,并且类A和B分别定义了同一个函数func()时,继承了“全部家当”的C类在调用函数func()时会产生歧义:时调用A.func()还是B.func()呢?经典的类(没有继承object)会采用深度优先的搜索策略去调用A.func();而新式类(object)会采用广度优先的搜索策略调用B.func()。基本就是这么一个差异,更详细的分析可以参考:https://www.zhihu.com/question/19754936

其次关于成员函数__getitem__()的理解,这个函数主要是根据索引来返回索引所指向的数据样本的。目前函数体只有raise 语句,说明这个函数不能直接使用,在构造自己的数据集时需要根据自己的数据类型自定义__getitem__()函数,否则程序在执行这一语句时会报错NotImplementedError.

具体的使用我用TensorDataset尝试了一下(因为TensorDataset类时Dataset类的子类,并且已经定义好了自己的__getitem__()函数),代码如下:

# 制作伪数据用于实验
x = torch.unsqueeze(torch.linspace(-1, 1, 10), dim=1)
y = x.pow(2)


# 载入torchDataset
torch_dataset = Data.TensorDataset(x, y)

# 展示数据内容
print(torch_dataset)
for i in range(torch_dataset.__len__()):
    print(torch_dataset.__getitem__(i))

数据还是之前测试random_split()的数据,输出结果为:


(tensor([-1.]), tensor([1.]))
(tensor([-0.7778]), tensor([0.6049]))
(tensor([-0.5556]), tensor([0.3086]))
(tensor([-0.3333]), tensor([0.1111]))
(tensor([-0.1111]), tensor([0.0123]))
(tensor([0.1111]), tensor([0.0123]))
(tensor([0.3333]), tensor([0.1111]))
(tensor([0.5556]), tensor([0.3086]))
(tensor([0.7778]), tensor([0.6049]))
(tensor([1.]), tensor([1.]))

第一行是print(torch_dataset)的数据,可见直接输出只会给出数据集的地址;而采用成员函数__getitem__()输出便可得到对应的数据。其实还有一个更简单的实现方法,就是直接用数组的形式索引:

for i in range(torch_dataset.__len__()):
    print(torch_dataset[i])

结果是一样的~

此外还有一个有趣的地方是,在实验random_split()函数时用到了subset类的实例subset,如果将输出语句

print(subset.__getitem__(i)[1])改为
print(subset.__getitem__(i))

则会导致输出结果变成:





没错,是一堆地址,说明subset并没有直接存储分好的子数据集,而是保存了母数据集以及子集的索引,我们__getitem__()得到的是"保存索引的数据集",代码表示就是:

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

感觉动手实践一下豁然开朗了!

你可能感兴趣的:(Pytorch学习小记1:torch.utils.data.Dataset类和datasat.py文件的初读笔记)