PyTorch|数据读取机制之Dataloader与Dataset

01 | Dataloader与DataSet数据读取方法

DataLoader与DataSet是PyTorch数据读取的核心。

torch.utils.DataLoader”的作用是构建一个可迭代的数据装载器,每次执行循环的时候,就从中读取一批Batchsize大小的样本进行训练。

其主要参数有五项:

  • dataset:隶属DataSet类,表示数据从哪里读取以及如何读取

  • batchsize:批大小

  • num_works:是否多进程读取数据

  • shuffle:每个epoch是否乱序

  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

上述主要参数中num_works通常在单进程时默认为“0”,也可以在支持多进程的设备上设置为允许的“4 | 8 | 16”等。shuffle则通常设置为使用乱序(True),以使得每次数据读取具有随机性。

这里颇为重要的是“Epoch、Iteration和Batchsize”之间的关系:1)Epoch表示所有训练样本都已输入到模型中,记为一个Epoch;2)Iteration表示一批样本输入到模型中,记为一个Iteration;3)Batchsize表示批大小,决定一个Epoch中有多少个Iteration。当样本数可以被Batchsize整除时,三者成立关系,即全体样本分成Batchsize分批次输入模型,每批次记为一次Iteration。

若样本总数80个,当Batchsize=8时,可以知道“1 Epoch = 10 Iteration”。
若样本总数87个,当Batchsize-8时,可以知道:1)若“drop_last=True”,则“1 Epoch = 10
Iteration”;2)若“drop_last=False”,则“1 Epoch = 11 Iteration”,其最后一个Iteration时样本个数为7,小于既定Batchsize。

torch.utils.data.Dataset”主要用于定义数据从哪里读取以及如何读取的问题。其定义为DataSet抽象类,所有自定义的Dataset都需要继承它,并复写“getitem()”内构函数,该函数接受一个索引,并返回一个样本。

PyTorch|数据读取机制之Dataloader与Dataset_第1张图片

02 | DataLoader与DataSet数据读取机制

PyTorch的数据读取机制通常围绕三个核心问题展开,即:

  • 读取哪些数据?
  • 从哪里读取数据?
  • 怎么读取数据?

事实上,通过在PyCharm中进行代码调试,我们可以简要回答上述问题:1)通过Sampler取样器按序或随机挑选出Batchsize数量的索引列表;2)使用DataSet中的data_dir指定硬盘上的数据访问路径;3)使用DataSet中自定义的getitem()方法,基于Sampler返回的索引列表读取相应数据和标签,并拼接成新的列表数据。


事实上,PyTorch的数据读取经过了诸多函数的跳转。在for循环中首先调用了“DataLoader”,进而使用Sampler、Dataset和getitem解决“数据读哪些?从哪读?怎么读?”的问题。最后,我们提供一份PyTorch中DataLoader数据读取机制的函数跳转流程图,供大家参考学习。

PyTorch|数据读取机制之Dataloader与Dataset_第2张图片

你可能感兴趣的:(AI+Security,pytorch,python,人工智能)