pytorch – 数据读取机制中的Dataloader与Dataset

引自

pytorch - 数据读取机制中的Dataloader与Dataset - 全栈程序员必看

怎么建立一个预测模型呢?考虑上一个博客中的机器学习模型训练五大步骤;第一是数据,第二是模型,第三是损失函数,第四是优化器,第五个是迭代训练过程。

这里主要学习数据模块当中的数据读取,数据模块通常还会分为四个子模块,数据收集、数据划分、数据读取、数据预处理。

在进行实验之前,需要收集数据,数据包括原始样本和标签;

有了原始数据之后,需要对数据集进行划分,把数据集划分为训练集、验证集和测试集;训练集用于训练模型,验证集用于验证模型是否过拟合,也可以理解为用验证集挑选模型的超参数,测试集用于测试模型的性能,测试模型的泛化能力;

第三个子模块是数据读取,也就是这里要学习的DataLoader,pytorch中数据读取的核心是DataLoader;

第四个子模块是数据预处理,把数据读取进来往往还需要对数据进行一系列的图像预处理,比如说数据的中心化,标准化,旋转或者翻转等等。pytorch中数据预处理是通过transforms进行处理的;

第三个子模块DataLoader还会细分为两个子模块,Sampler和DataSet;Sample的功能是生成索引,也就是样本的序号;Dataset是根据索引去读取图片以及对应的标签;

这里主要学习第三个子模块中的Dataloader和Dataset;

2、DataLoader与Dataset

DataLoader和Dataset是pytorch中数据读取的核心;

2.1) DataLoader

(1)torch.utils.data.DataLoader

  • 功能:构建可迭代的数据装载器;
  • dataset:Dataset类,决定数据从哪里读取及如何读取;
  • batchsize:批大小;
  • num_works:是否多进程读取数据;
  • shuffle:每个epoch是否乱序;
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;

Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

样本总数:80,Batchsize:8 (样本能被Batchsize整除)
1 Epoch = 10 Iteration

样本总数:87,Batchsize=8 (样本不能被Batchsize整除)
1 Epoch = 10 Iteration,drop_last = True
1 Epoch = 11 Iteration, drop_last = False

(2)torch.utils.data.Dataset

  • Dataset是用来定义数据从哪里读取,以及如何读取的问题;
  • 功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__();
  • getitem:接收一个索引,返回一个样本

你可能感兴趣的:(pytorch,深度学习,人工智能)