pytorch学习笔记6-数据读取机制Dataloader与Dataset

机器学习训练步骤:

数据-模型-损失函数-优化器-迭代训练

数据:

1.数据收集:图像和标签

2.数据划分:分为训练集(训练模型)、验证集(验证模型是否过拟合,简而言之用验证集来挑选模型)、测试集(测试挑选出来模型的性能)

3.数据读取:Dataloader(pytorch数据读取的核心就是Dataloader)

         Dataloader:Sampler(生成索引,也就是样本的序号)和Dataset(根据索引读取图片和标签)

4.数据预处理:transforms

 

Dataloader与Dataset

(1)Dtaloader

torch.utils.data.DataLoader

功能:搭建可迭代的数据装载器

参数(常用):

dataset:Dataset类,决定数据从哪儿读取及如何读取

batchsize:批大小

num_works:是否多进程读取数据

shuffle:每个epoch是否乱序

drop_last:当样本数不能被batchsize整除时,是否放弃最后一批数据

关系:

Epoch:所有训练样本都已输入到模型中,称为一个Epoch

Iteration:一批样本输入到模型中,称之为一个Iterration

Batchsize:批大小,决定一个Epoch有多少个Iteration

例如:样本总数为160,Batchsize=8

1Epoch=20个Iteration

当样本总数不能被Batchsize整除时,就要看drop_last的取值

例如:样本总数为165,Batchsize=8

1Epoch=20Iteration (darp_last=True)

1Epoch=21Iteration (darp_last=False)

(2)Dataset

torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写

__getitem__()

getitem:

         接受一二索引,返回一个样本

数据读取机制问题:

1.读那些数据? 采样器sampler输出的index

2.从哪儿读数据? dataset中的data_dir

3.怎么读数据?  dataset中的getitem

大致如流程图所示:

pytorch学习笔记6-数据读取机制Dataloader与Dataset_第1张图片

读取数据首先在for循环中使用DataLoader,进入DataLoader之后会根据使用单进程或者多进程进入到DataLoaderItem,然后会使用Sampler获取我们的索引index,得到索引之后给到DatasetFetcher,在这里面会调用Dataset,Dataset会根据我们给的索引在getiem冲我们的硬盘里面取读取我们的图像和标签img,label,我们读到一个batchsize大小的数据之后通过一个collate_fun将我们这些数据进行整理成一个Batch Data的形式,然后就可以输入到模型中取训练了。

 

 

你可能感兴趣的:(pytorch学习)