[pytorch] Learning Data Loading with the PyTorch DataLoader (Part 1)

DataLoader

DataLoader converts your data into Tensors and then iterates over it efficiently. It greatly simplifies the data-reading pipeline and makes training more convenient.

1. A simple example to start with:
  1. Import the required modules:
import torch
import torch.utils.data as Data
torch.manual_seed(1)
  2. Generate some torch data:
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
  3. Wrap the generated data in a Dataset and a DataLoader:
BATCH_SIZE = 5
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)
  4. Iterate over the data with the DataLoader:
for epoch in range(3):
    for step, (batchX, batchY) in enumerate(loader):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batchX.numpy(), '| batch y: ', batchY.numpy())

Output:

Epoch:  0 | Step:  0 | batch x:  [ 4.  6.  7. 10.  8.] | batch y:  [7. 5. 4. 1. 3.]
Epoch:  0 | Step:  1 | batch x:  [5. 3. 2. 1. 9.] | batch y:  [ 6.  8.  9. 10.  2.]
Epoch:  1 | Step:  0 | batch x:  [ 4.  2.  5.  6. 10.] | batch y:  [7. 9. 6. 5. 1.]
Epoch:  1 | Step:  1 | batch x:  [3. 9. 1. 8. 7.] | batch y:  [ 8.  2. 10.  3.  4.]
Epoch:  2 | Step:  0 | batch x:  [ 4. 10.  9.  8.  7.] | batch y:  [7. 1. 2. 3. 4.]
Epoch:  2 | Step:  1 | batch x:  [6. 1. 2. 5. 3.] | batch y:  [ 5. 10.  9.  6.  8.]
2. When the dataset length is not evenly divisible by the batch size

In the toy example above, batch_size = 5 and the dataset has 10 samples, so two steps consume the data exactly. What if batch_size = 8? Then only 2 samples are left for the second step of each epoch:

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2
)
for epoch in range(3):
    for step, (batchX, batchY) in enumerate(loader):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batchX.numpy(), '| batch y: ', batchY.numpy())

Output:

Epoch:  0 | Step:  0 | batch x:  [10.  2.  9.  5.  6.  4.  8.  7.] | batch y:  [1. 9. 2. 6. 5. 7. 3. 4.]
Epoch:  0 | Step:  1 | batch x:  [1. 3.] | batch y:  [10.  8.]
Epoch:  1 | Step:  0 | batch x:  [7. 2. 8. 9. 6. 5. 3. 1.] | batch y:  [ 4.  9.  3.  2.  5.  6.  8. 10.]
Epoch:  1 | Step:  1 | batch x:  [10.  4.] | batch y:  [1. 7.]
Epoch:  2 | Step:  0 | batch x:  [ 1.  6.  3.  7. 10.  8.  4.  2.] | batch y:  [10.  5.  8.  4.  1.  3.  7.  9.]
Epoch:  2 | Step:  1 | batch x:  [9. 5.] | batch y:  [2. 6.]

Notice that the last step of each epoch only yields the 2 remaining samples (10 - 8).

If we don't want this incomplete batch, we can set drop_last=True when constructing the DataLoader:

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
    drop_last=True
)
for epoch in range(3):
    for step, (batchX, batchY) in enumerate(loader):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
              batchX.numpy(), '| batch y: ', batchY.numpy())

Output:

Epoch:  0 | Step:  0 | batch x:  [ 6.  5.  7.  3.  8. 10.  9.  2.] | batch y:  [5. 6. 4. 8. 3. 1. 2. 9.]
Epoch:  1 | Step:  0 | batch x:  [ 1. 10.  5.  2.  4.  6.  9.  8.] | batch y:  [10.  1.  6.  9.  7.  5.  2.  3.]
Epoch:  2 | Step:  0 | batch x:  [3. 4. 1. 8. 6. 5. 2. 7.] | batch y:  [ 8.  7. 10.  3.  5.  6.  9.  4.]

3. About Dataset and DataLoader

As we have seen, PyTorch data loading relies mainly on Dataset and DataLoader; this section gives a brief introduction to both.

Dataset

Source code from the PyTorch documentation:
https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#Dataset

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

The toy example above used TensorDataset, which is a subclass of Dataset.

class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

It has three member functions in total: __init__, __getitem__, and __len__. __init__ initializes the dataset, __getitem__ returns a single sample (note: one sample at a time, not a batch), and __len__ returns the length of the dataset.

This raises a question: in the toy example we simply created 10 samples and handed them to the Dataset. In practice, however, the data is often far too large to load into memory all at once (doing so would exhaust RAM), so assigning everything to a TensorDataset is basically infeasible. In that case we need to write our own Dataset subclass, which is covered in more detail later.

The key point is that to build our own subclass, we only need to follow TensorDataset and implement three member functions: __init__, __getitem__, and __len__. A minimal sketch is shown below.
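
As an illustration only (the directory layout, file format, and class name here are my assumptions, not from the original post), a lazily loading Dataset subclass might look like this: each sample is read from disk inside __getitem__, so the full dataset never has to sit in memory at once.

import os
import torch
import torch.utils.data as Data

class LazyFileDataset(Data.Dataset):
    """Hypothetical Dataset that loads one sample file from disk on demand."""

    def __init__(self, root_dir):
        # Keep only the list of file paths in memory, not the data itself.
        self.paths = sorted(
            os.path.join(root_dir, name) for name in os.listdir(root_dir)
        )

    def __getitem__(self, index):
        # Each file is assumed to contain a dict like {'x': tensor, 'y': tensor}.
        sample = torch.load(self.paths[index])
        return sample['x'], sample['y']

    def __len__(self):
        return len(self.paths)

A DataLoader built on such a class is used exactly like the TensorDataset example above, because the loader only ever asks the dataset for one index at a time.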

DataLoader

Source code from: https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader
For now we only need to understand the parameters it accepts.

The DataLoader parameters:
dataset: a Dataset instance (or an instance of a Dataset subclass)
batch_size: the number of samples in each batch
shuffle: whether to shuffle the data every epoch
sampler: defines the strategy for drawing samples from the dataset; the default is usually fine
num_workers: the number of worker processes used to load the data (0 means the data is loaded in the main process)
collate_fn: merges the individual samples returned by the Dataset into a batch (see the sketch after this list)
drop_last: whether to drop the last incomplete batch
timeout: if positive, how long to wait when collecting a batch from the workers; if the batch is not collected within this time, it is skipped. This value must be non-negative; the default is 0
pin_memory: page-locked (pinned) memory; generally set it to True when training on a GPU and False when training on a CPU
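
The collate_fn hook is easiest to see with a small sketch (this example is mine, not from the original post): the DataLoader passes collate_fn a Python list containing one sample per element, and collate_fn must return the assembled batch.

import torch
import torch.utils.data as Data

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
torch_dataset = Data.TensorDataset(x, y)

def my_collate(batch):
    # `batch` is a list of (x_i, y_i) tuples, one tuple per sample.
    xs = torch.stack([item[0] for item in batch])
    ys = torch.stack([item[1] for item in batch])
    return xs, ys

loader = Data.DataLoader(torch_dataset, batch_size=4, collate_fn=my_collate)
for batch_x, batch_y in loader:
    print(batch_x.shape, batch_y.shape)  # torch.Size([4]), except for the smaller last batch

This is essentially what the default collate function already does for tensors; writing your own is only needed for custom sample types or custom batching logic.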

pin_memory refers to page-locked (pinned) host memory. When a DataLoader is created with pin_memory=True, the returned Tensors are placed in pinned host memory, which makes the subsequent copy from host memory to GPU memory faster.
Host memory comes in two flavors: pinned (page-locked) and pageable. Pinned memory is never swapped out to virtual memory (i.e., to disk), whereas pageable memory can be swapped out when the host runs low on RAM.
GPU memory, by contrast, is effectively all page-locked.
When the machine has plenty of RAM, set pin_memory=True; if the system stalls or swap usage is high, set pin_memory=False. Because the benefit depends on the hardware, and the PyTorch developers cannot assume every user has a high-end machine, pin_memory defaults to False.
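
A minimal sketch of how pinned memory is usually combined with an asynchronous host-to-GPU copy (the non_blocking=True pattern is standard PyTorch usage; the data is just the toy data from above):

import torch
import torch.utils.data as Data

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
loader = Data.DataLoader(
    Data.TensorDataset(x, y),
    batch_size=5,
    shuffle=True,
    pin_memory=torch.cuda.is_available()  # pinning only helps when copying to the GPU
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for batch_x, batch_y in loader:
    # With a pinned source tensor, non_blocking=True lets the copy overlap with computation.
    batch_x = batch_x.to(device, non_blocking=True)
    batch_y = batch_y.to(device, non_blocking=True)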


class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your ``collate_fn`` returns a batch that is a custom type
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)

    .. note:: When ``num_workers != 0``, the corresponding worker processes are created each time
              iterator for the DataLoader is obtained (as in when you call
              ``enumerate(dataloader,0)``).
              At this point, the dataset, ``collate_fn`` and ``worker_init_fn`` are passed to each
              worker, where they are used to access and initialize data based on the indices
              queued up from the main process. This means that dataset access together with
              its internal IO, transforms and collation runs in the worker, while any
              shuffle randomization is done in the main process which guides loading by assigning
              indices to load. Workers are shut down once the end of the iteration is reached.

              Since workers rely on Python multiprocessing, worker launch behavior is different
              on Windows compared to Unix. On Unix fork() is used as the default
              muliprocessing start method, so child workers typically can access the dataset and
              Python argument functions directly through the cloned address space. On Windows, another
              interpreter is launched which runs your main script, followed by the internal
              worker function that receives the dataset, collate_fn and other arguments
              through Pickle serialization.

              This separate serialization means that you should take two steps to ensure you
              are compatible with Windows while using workers
              (this also works equally well on Unix):

              - Wrap most of you main script's code within ``if __name__ == '__main__':`` block,
                to make sure it doesn't run again (most likely generating error) when each worker
                process is launched. You can place your dataset and DataLoader instance creation
                logic here, as it doesn't need to be re-executed in workers.
              - Make sure that ``collate_fn``, ``worker_init_fn`` or any custom dataset code
                is declared as a top level def, outside of that ``__main__`` check. This ensures
                they are available in workers as well
                (this is needed since functions are pickled as references only, not bytecode).

              By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraies
              may be duplicated upon initializing workers (w.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use :func:`torch.initial_seed()` to access the PyTorch seed for
              each worker in :attr:`worker_init_fn`, and use it to set other
              seeds before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.

    The default memory pinning logic only recognizes Tensors and maps and iterables
    containg Tensors.  By default, if the pinning logic sees a batch that is a custom type
    (which will occur if you have a ``collate_fn`` that returns a custom batch type),
    or if each element of your batch is a custom type, the pinning logic will not
    recognize them, and it will return that batch (or those elements)
    without pinning the memory.  To enable memory pinning for custom batch or data types,
    define a ``pin_memory`` method on your custom type(s).

    Example::

        class SimpleCustomBatch:
            def __init__(self, data):
                transposed_data = list(zip(*data))
                self.inp = torch.stack(transposed_data[0], 0)
                self.tgt = torch.stack(transposed_data[1], 0)

            def pin_memory(self):
                self.inp = self.inp.pin_memory()
                self.tgt = self.tgt.pin_memory()
                return self

        def collate_wrapper(batch):
            return SimpleCustomBatch(batch)

        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        dataset = TensorDataset(inps, tgts)

        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                            pin_memory=True)

        for batch_ndx, sample in enumerate(loader):
            print(sample.inp.is_pinned())
            print(sample.tgt.is_pinned())

    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
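
Following the note in the docstring above about duplicated seeds in other libraries, a common pattern (my sketch, not part of the quoted source) is to reseed NumPy and random inside worker_init_fn, deriving the seed from torch.initial_seed(), which already differs per worker:

import random
import numpy as np
import torch
import torch.utils.data as Data

def seed_worker(worker_id):
    # Inside each worker, torch.initial_seed() equals base_seed + worker_id.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
loader = Data.DataLoader(
    Data.TensorDataset(x, y),
    batch_size=5,
    num_workers=2,
    worker_init_fn=seed_worker
)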
