Sample和DataSet是DataLoader的两个子模块。Sampler的功能主要是生成索引。也就是样本的序号。
D a t a s e t Dataset Dataset是根据索引去读取数据以及对应的标签。DataLoader负责以特定的方式从数据集中迭代的产生一个一个 b a t c h batch batch集合。其中。DataLoader和Dataset是pytorch中数据读取的核心。
(以特定的方式从数据集中迭代产生一个一个的batch集合》
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
实例化一个DataLoader所需参数如上所示:其中:
Dataset就是一个负责处理索引(index)到样本(sample)映射的一个类(class).
i n d e x → s a m p l e index \rightarrow sample index→sample
t o r c h . u t i l s . d a t a . D a t a s e t torch.utils.data.Dataset torch.utils.data.Dataset 是一个表示数据集的抽象的类,任何自定义的数据集都需要继承这个类并腹泻相关方法。
pytorch:提供两种数据集: M a p 式 数 据 集 Map式数据集 Map式数据集、 l t e r a b l e 式 数 据 集 lterable式数据集 lterable式数据集。
一个Map式的数据集必须要重写getitem(self, index),len(self) 两个内建方法,用来表示从索引到样本的映射(Map)
g e t i t e m ( s e l f , i n d e x ) , l e n ( s e l f ) getitem(self, index),len(self) getitem(self,index),len(self)两个内建方法。
用来表示从索引到样本的映射(Map).
这样一个数据集dataset。举个例子,当使用 d a t a s e t [ i d x ] dataset[idx] dataset[idx]命令时,可以在你的硬盘中读取数据集中的第 i d x idx idx张图片以及其标签,
l e n ( d a t a s e t ) len(dataset) len(dataset):则返回这个数据集的容量。
class CustomDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
getitem最主要的方法是,其规定了如何读取数据,但是又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假定你定义好一个dataset,那么你可以直接通过Dataset[0]来访问第一个数据。
一个lterable式数据集,是抽象类:data.IterableDataset的子类,
并且腹泻了iter方法成为一个迭代器。这种数据集主要用于数据大小未知,或者以流的形式输入。本地文件不固定的情况,需要以迭代的方式来获取样本索引。
迭代器是一个可以记住遍历的位置的对象,迭代器对象从集合的第一个元素开始访问。直到所有的元素被访问完结束,迭代器只能往前而不会后退,迭代器两个基本方法: i t e r ( ) iter() iter()和 n e x t ( ) next() next().
i t e r ( ) iter() iter():方法返回一个特殊的迭代器对象,这个迭代器对象实现了next()方法,并通过$StopIteration $异常标识迭代的完成。
n e x t ( ) next() next() 方法会返回迭代器的输出。
class MyNumbers:
def __iter__(self):
self.a = 1
return self
def __next__(self):
if self.a <= 20:
x = self.a
self.a += 1
return x
else:
raise StopIteration
#StopIteration 异常用于标识迭代的完成,防止出现无限循环的情况,在 next() 方法中我们可以设置在完成指定循环次数后触发 StopIteration 异常来结束迭代。
myclass = MyNumbers()
myiter = iter(myclass)
for x in myiter:
print(x)
sampler类的源代码主要由三种方法,如下:
class Sampler(object):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an __iter__ method, providing a way
to iterate over indices of dataset elements, and a __len__ method that
returns the length of the returned iterators.
"""
# 一个 迭代器 基类
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
按顺序对数据集采样,其原理首先在初始化的时候,拿到数据集data_source,之后在__iter__方法中首先得到一个和data_source一样长度的range迭代器。每次只返回一个索引值。
随机采样
子集随机采样,用于训练,测试集和验证集合的划分。
加权随机采样。
前面的采样器每次只返回一个索引,但是我们在训练时是对批量数据进行训练。而这样的工作都需要BatchSampler来做。也就是说BatchSampler的作用就是将前面的Sampler采样得到的索引值进行合并,当数量等于一个batch大小后就将这一批的索引值返回。
慢慢的将各种采样方法,全部都将其搞定。慢慢的将其研究透彻,研究彻底。都行啦的样子与打算。