DataLoader、Dataset and Sampler

目录

一、DataLoader、DataSet和Sampler之间的关系

二、Dataloader

三、DataSet

1、Map式数据集

2、Iterable式数据集

⭐迭代器

三、Sampler

(1)SequentialSampler

(2)RandomSampler

(3)SubsetRandomSampler

(4)WeightedRandomSampler

(5)BatchSampler

四、总结


一、DataLoader、DataSet和Sampler之间的关系

Sampler和DataSet是DataLoader的两个子模块;Sampler的功能是生成索引,也就是样本的序号;Dataset是根据索引去读取数据以及对应的标签。DataLoader负责以特定的方式从数据集中迭代的产生 一个个batch的样本集合。其中,DataLoader和Dataset是pytorch中数据读取的核心。

二、Dataloader

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

实例化一个DataLoader所需的参数如上所示。其中:

1、dataset:定义好的Map式或者Iterable式数据集;

2、batch_size:一个batch中的样本个数,默认为1;

3、shuffle:每一个epoch的batch样本是相同还是随机;

4、sampler:数据集中采样的方法. 如果有,则shuffle参数必须为False;

5、batch_sampler:和 sampler 类似,但是一次返回的是一个batch内所有样本的index;

6、num_workers:多少个子程序同时工作来获取数据,多线程;

7、collate_fn:合并样本列表以形成小批量;

8、pin_menory:如果为True,数据加载器在返回前将张量复制到CUDA固定内存中;

9、drop_last:如果数据集大小不能被batch_size整除,设置为True可删除最后一个不完整的批处理。如果设为False并且数据集的大小不能被batch_size整除,则最后一个batch将更小;

10、timeout:如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。numeric应总是大于等于0;

三、DataSet

DataSet就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。Pytorch提供两种数据集:Map式数据集 和Iterable式数据集

1、Map式数据集

一个Map式的数据集必须要重写getitem(self, index),len(self) 两个内建方法,用来表示从索引到样本的映射(Map)。.

这样一个数据集dataset,举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签(如果有的话);len(dataset)则会返回这个数据集的容量。

自定义类的结构一般如下:

class CustomDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

 __getitem__是最主要的方法,它规定了如何读取数据。但是它又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假如你定义好了一个dataset,那么你可以直接通过dataset[0]来访问第一个数据。

2、Iterable式数据集

一个Iterable(迭代)式数据集是抽象类data.IterableDataset的子类,并且覆写了iter方法成为一个迭代器。这种数据集主要用于数据大小未知,或者以流的形式的输入,本地文件不固定的情况,需要以迭代的方式来获取样本索引。

⭐迭代器

迭代器是一个可以记住遍历的位置的对象。迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。迭代器有两个基本的方法:iter() 和 next()。

iter() 方法返回一个特殊的迭代器对象, 这个迭代器对象实现了 next() 方法并通过 StopIteration 异常标识迭代的完成。

next() 方法会返回迭代器的输出。

创建迭代器的一般格式为:

class MyNumbers:
  def __iter__(self):
    self.a = 1
    return self

  def __next__(self):
    if self.a <= 20:
      x = self.a
      self.a += 1
      return x
    else:
      raise StopIteration
#StopIteration 异常用于标识迭代的完成,防止出现无限循环的情况,在 next() 方法中我们可以设置在完成指定循环次数后触发 StopIteration 异常来结束迭代。

myclass = MyNumbers()
myiter = iter(myclass)

for x in myiter:
  print(x)

三、Sampler

Sampler类的源代码主要有三种方法,如下:

class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """
    # 一个 迭代器 基类
    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
  • __init__: 就是初始化
  • __iter__: 用来产生迭代索引值的,也就是指定每个step需要读取哪些数据
  • __len__: 用来返回每次迭代器的长度

Pytorch提供了给我们几种采样器,如下:

(1)SequentialSampler

按顺序对数据集采样。其原理是首先在初始化的时候拿到数据集data_source,之后在__iter__方法中首先得到一个和data_source一样长度的range可迭代器。每次只会返回一个索引值。

(2)RandomSampler

随机采样

(3)SubsetRandomSampler

子集随机采样,用于训练集、测试集和验证集的划分

(4)WeightedRandomSampler

加权随机采样

(5)BatchSampler

前面的采样器每次都只返回一个索引,但是我们在训练时是对批量的数据进行训练,而这个工作就需要BatchSampler来做。也就是说BatchSampler的作用就是将前面的Sampler采样得到的索引值进行合并,当数量等于一个batch大小后就将这一批的索引值返回。

四、总结

以上都是我在学习时针对我现在所需做的笔记,如果需要更加细节的了解,请参考下面我参考的文章:

一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系 - marsggbo - 博客园

极市开发者平台-计算机视觉算法开发落地平台

Pytorch Sampler详解 - 知乎

Sampler类与4种采样方式_Wanderer001的博客-CSDN博客_sample采样

你可能感兴趣的:(深度学习,python代码,python,机器学习)