pytorch笔记5-数据读取机制DataLoader

简介

p y t o r c h pytorch pytorch的数据读取机制 D a t a L o a d e r DataLoader DataLoader包括两个子模块: S a m p l e r Sampler Sampler模块、主要生成索引 i n d e x index index D a t a S e t DataSet DataSet模块,主要根据索引读取数据, D a t a s e t Dataset Dataset类是一个抽象类,其可以用来表示数据集,我们通过继承 D a t a s e t Dataset Dataset类来自定义数据集格式、大小和其它属性。后面都可以供 D a t a L o a d e r DataLoader DataLoader类直接使用。
在实际项目中,如果数据量很大,考虑到内存有限,I/O速度等问题,在训练过程中不可能一次性的将所有数据全部加载到内存中,也不能只用一个进程去加载,所以都需要多进程,迭代加载。而 D a t a L o a d e r DataLoader DataLoader都是基于这些需要被设计出来的, D a t a L o a d e r DataLoader DataLoader是一个迭代器,最基本的使用方法都是传入一个 D a t a s e t Dataset Dataset对象,其会根据参数 b a t c h s i z e batch_size batchsize值生成一个batch数据,节省内存的同时,还可以实现多进程、数据打乱等处理。

torch.utils.data.Dataset
class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

D a t a s e t Dataset Dataset是用来解决数据从哪里读取以及如何读取的问题, p y t o r c h pytorch pytorch给定的 D a t a s e t Dataset Dataset是一个抽象类,所有自定义的 D a t a s e t Dataset Dataset都需要继承它,并且复写__getitem__和__len__()l类的方法,getitem()作用接受一个索引,返回一个样本或者标签,下面通过一个实例构造一个数据集:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)
    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

结合代码可以看到,我们定义了一个名字为 M y D a t a s e t MyDataset MyDataset的数据集,在构造函数中,传入 T e n s o r Tensor Tensor类型的数据和标签,在KaTeX parse error: Expected group after '_' at position 1: _̲_len__函数中,直接返回 T e n s o r Tensor Tensor大小,在KaTeX parse error: Expected group after '_' at position 1: _̲_getitem__函数中返回索引的数据和标签。

接下来,我们看如何调用刚才定义的数据集,先随机生成一个 10*3 维的数据 Tensor,然后生成 10 维的标签 Tensor,与数据 Tensor 相对应。利用这两个 Tensor,生成一个 MyDataset 的对象。查看数据集的大小可以直接用 len() 函数,索引调用数据可以直接使用下标

# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1

# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)

# 查看数据集大小
print('Dataset size:', len(my_dataset))
'''
输出:
Dataset size: 10
'''

# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])
'''
输出:
tensor_data[0]:  (tensor([ 0.4931, -0.0697,  0.4171]), tensor(0))

D a t a L o a d e r DataLoader DataLoader的功能构建可迭代的数据装载器,在训练的时候,每一个for循环,每一次 l t e r a t i o n lteration lteration 就是从 D a t a s e t Dataset Dataset中获取一个 b a t c h s i z e batch_size batchsize大小的数据:

torch.utils.data.DataLoader
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,
pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)

DataLoader有很多参数,但常用的有下面五个:

  • dataset表示Dataset类,它决定了数据从哪读取以及如何读取;
  • batch_size表示批大小;
  • num_works表示是否多进程读取数据;
  • shuffle表示每个epoch是否乱序;
  • drop_last表示当样本数不能被batch_size整除时,是否舍弃最后一批数据;
    这里提到了 e p o c h epoch epoch,所有训练样本都已输入模型之中,称为一个 E p o c h Epoch Epoch,也就是说将样本都训练一遍,称为一个 e p o c h epoch epoch,一批样本输入到模型中,这样称为一个 l t e r a t i o n lteration lteration,决定了一个 E p o c h Epoch Epoch有多少个 l t e r a t i o n lteration lteration,为批次大小的 B a t c h s i z e Batchsize Batchsize

流程图。

通过下面留程图,我们认识下数据读取机制。
pytorch笔记5-数据读取机制DataLoader_第1张图片

总结

慢慢的将数据读取机制,全部都将其搞清楚,研究彻底,

你可能感兴趣的:(Torch的使用及参数解释,pytorch,python,深度学习)