Pytorch框架学习记录5——DataLoader的使用

Pytorch框架学习记录5——DataLoader的使用

1. DataLoader方法介绍

Pytorch官网上对DataLoader方法进行了详细的介绍,数据加载器。结合数据集和采样器,并提供给定数据集的可迭代对象。DataLoader支持具有单进程或多进程加载、自定义加载顺序和可选的自动批处理(整理)和内存固定的地图样式和可迭代样式数据集。

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device='')

参数

  • dataset ( Dataset ) – 从中加载数据的数据集。
  • batch_size ( int , optional ) – 每批要加载多少样本(默认值:1)。
  • shuffle ( bool , optional ) – 设置为True在每个 epoch 重新洗牌数据(默认值:False)。
  • num_workers ( int , optional ) – 用于数据加载的子进程数。0表示数据将在主进程中加载。(默认:0
  • drop_last ( bool , optional ) –True如果数据集大小不能被批次大小整除,则设置为丢弃最后一个不完整的批次。如果False数据集的大小不能被批大小整除,那么最后一批将更小。(默认:False

2. 实例

这里使用CIFAR10数据集,通过DataLoader方法将数据集以64一组打包,在windows系统中num_workers=0,最后在tensorboard中将打包好的图像展示。

注意,对于打包的图片展示,使用的方法是add_images()方法,单张图片展示使用add_image()方法

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

train_set = torchvision.datasets.CIFAR10(root='C:\\Users\\hp\\PycharmProjects\\pythonProject\\Pytorch_Learning\\p11-dataset_transform\\dataset',
                             train=True, transform=torchvision.transforms.ToTensor(), download=True)

train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

writer = SummaryWriter("logs")
step = 0
for data in train_loader:
    img, target = data
    writer.add_images("test_data", img, step)
    step += 1

writer.close()

Pytorch框架学习记录5——DataLoader的使用_第1张图片

你可能感兴趣的:(Pytorch学习笔记,pytorch,学习,深度学习)