Dataloader的使用详解

目录

官网介绍的Dataloader

实践部分


DataLoader会自动将目标数据样本划分为多个批次,并根据需要进行数据预处理、数据增强等操作,同时也可以在数据加载过程中进行多线程并行加载,以提高数据加载效率。

Dataloader的使用详解_第1张图片

如果将图像数据集比喻为一幅扑克牌,那么dataset就是一整个数据集,而dataloader是一个加载器,就是说把数据加载到神经网络中,如果把图片中的手比作神经网络,那么每次去取几张牌、几只手去取、怎么取,这个取的过程就是 dataloader 的工作。

官网介绍的Dataloader

首先我们看下官网对Dataloader的相关介绍,“点此查看”

Dataloader的使用详解_第2张图片

常用参数解释:

  • dataset:将自定义的dataset实例化参数传入,进行类的继承
  • batch_size:每次输入进行训练时的图像分批次传入的数量(即每批次的图像数量)
  • shuffle:布尔值为True,表示此次epoch与前面训练的数据顺序不一致,若为False,表示与前面的epoch一致
  • num_works:表示进行数据加载时使用单个进程还是多进程进行加载,多进程意为加载速度更快,一般默认为0,表示使用主进程进行加载
  • drop_last:当剩余的图像数量已不足batch_size时,若布尔值为True,则将不足部分舍去,若为False,则仍需加载图像数据

注:当使用多进程加载数据出现 BrokenPipeError 报错时,可尝试将num_works设置为0。

实践部分

使用DataLoader的主要步骤如下:

  • 创建一个数据集对象,通常是torch.utils.data.Dataset的子类,以提供数据样本。
  • 使用数据集对象创建一个DataLoader实例,设置合适的batch_size、shuffle和num_workers等参数。
  • 在训练过程中,使用for循环遍历DataLoader实例,迭代获取每个批次的数据样本。

这一部分我们主要看一下 shuffle 和 drop_last 两个参数的作用。首先是drop__last,在这里我们分别将其设置为True和False,使用 tensorboard 查看具体的输出结果。Tensorboard的使用方法在这:PyTorch深度学习快速入门

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

输出结果如下:

Dataloader的使用详解_第3张图片

从图中可以看出当drop__last=True时,最后输出的结果均是批次为8*8的大小,自动舍去了不足的部分;而当drop__last=False时,仍会读取所有图像。

然后是shuffle,当shuffle=True时,每一次epoch时图像数据传入的顺序和前面的均不一致,这大大提高模型的泛化性。输出结果如下:

Dataloader的使用详解_第4张图片

相关代码如下:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

img, target = test_data[0]
print(img.shape) 
print(target)  

writer = SummaryWriter("dataloader")
for epoch in range(2):  
    step = 0
    for data in test_loader: 
        imgs, targets = data
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step = step + 1
writer.close()

总结:Dataloader 只是作为神经网络模型的输入端口,继承 dataset 类,将目标数据分批次进行输送,这可以帮助我们高效地加载和组织大规模的数据,并在训练模型时提供数据的批量处理。

你可能感兴趣的:(PyTorch深度学习快速入门,python,pytorch)