「@Author:Runsen」
有时候,在处理大数据集时,一次将整个数据加载到内存中变得非常难。
因此,唯一的方法是将数据分批加载到内存中进行处理,这需要编写额外的代码来执行此操作。对此,PyTorch 已经提供了 Dataloader 功能。
下面显示了 PyTorch
库中DataLoader
函数的语法及其参数信息。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
几个重要参数
dataset
:必须首先使用数据集构造 DataLoader 类。
Shuffle
:是否重新整理数据。
Sampler
:指的是可选的 torch.utils.data.Sampler 类实例。采样器定义了检索样本的策略,顺序或随机或任何其他方式。使用采样器时应将 Shuffle 设置为 false。
Batch_Sampler
:批处理级别。
num_workers
:加载数据所需的子进程数。
collate_fn
:将样本整理成批次。Torch 中可以进行自定义整理。
MNIST 是一个著名的包含手写数字的数据集。下面介绍如何使用DataLoader
功能处理 PyTorch 的内置 MNIST 数据集。
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
上面代码,导入了 torchvision
的torch计算机视觉模块。通常在处理图像数据集时使用,并且可以帮助对图像进行规范化、调整大小和裁剪。
对于 MNIST 数据集,下面使用了归一化技术。
ToTensor()能够把灰度范围从0-255变换到0-1之间。
transform = transforms.Compose([transforms.ToTensor()])
下面代码用于加载所需的数据集。使用 PyTorchDataLoader
通过给定 batch_size = 64
来加载数据。shuffle=True
打乱数据。
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
为了获取数据集的所有图像,一般使用iter函数和数据加载器DataLoader
。
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.shape)
print(labels.shape)
plt.imshow(images[1].numpy().squeeze(), cmap='Greys_r')
下面的代码创建一个包含 1000 个随机数的自定义数据集。
from torch.utils.data import Dataset
import random
class SampleDataset(Dataset):
def __init__(self,r1,r2):
randomlist=[]
for i in range(120):
n = random.randint(r1,r2)
randomlist.append(n)
self.samples=randomlist
def __len__(self):
return len(self.samples)
def __getitem__(self,idx):
return(self.samples[idx])
dataset=SampleDataset(1,100)
dataset[100:120]
在这里插入图片描述
最后,将在自定义数据集上使用 dataloader
函数。将 batch_size
设为 12,并且还启用了num_workers =2
的并行多进程数据加载。
from torch.utils.data import DataLoader
loader = DataLoader(dataset,batch_size=12, shuffle=True, num_workers=2 )
for i, batch in enumerate(loader):
print(i, batch)
通过几个示例了解了 PyTorch Dataloader 在将大量数据批量加载到内存中的作用。
往期精彩回顾
适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑黄海广老师《机器学习课程》课件合集
本站qq群851320808,加入微信群请扫码: