Size: 28×28 灰度手写数字图像
Num: 训练集 60000 和 测试集 10000,一共70000张图片
Classes: 0,1,2,3,4,5,6,7,8,9
train-images-idx3-ubyte.gz:训练集图像(9912422 字节)55000张训练集 + 5000张验证集;
train-labels-idx1-ubyte.gz:训练集标签(28881 字节)训练集对应的标签;
t10k-images-idx3-ubyte.gz:测试集图像(1648877 字节)10000张测试集;
t10k-labels-idx1-ubyte.gz:测试集标签(4542 字节)测试集对应的标签;
PyTorch加载数据:Dataset和Dataloader
dataset:取数据以及其对应的label,并未其添加索引
提供一种方式去获取数据及其label
实现:
(1)如何获取每一个数据及其label
(2)告诉我们总共有多少的数据
dataloader:为后面的网络提供不同的数据形式
datasets.MNIST是Pytorch的内置函数torchvision.datasets.MNIST,通过这个可以导入数据集。
train=True 代表我们读入的数据作为训练集(如果为true则从training.pt创建数据集,否则从test.pt创建数据集download=True则是当我们的根目录(root)下没有数据集时,便自动下载。
transform则是读入我们自己定义的数据预处理操作
如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:
from torchvision import datasets, transforms
train_data = datasets.MNIST(root="./MNIST",
train=True,
transform=transforms.ToTensor(),
download=True)
test_data = datasets.MNIST(root="./MNIST",
train=False,
transform=transforms.ToTensor(),
download=True)
print(train_data)
print(test_data)
使用 DataLoader 为训练准备数据
Dataset每次检索数据集中的一组特征和标签。然而,在训练模型时,通常希望
以“小批量”的形式传递样本,这种小批量集合了数个样本(比如16个、64个
等);并且在每个回合( epoch)都能打乱数据,以减少模型过拟合,并使用
Python的multiprocessing加速数据检索。在这种需求的引导下,DataLoader应
运而生。PyTorch提供了一个简单易用的API——DataLoader,能实现上述数据
进入训练的复杂过程的高度抽象化。
DataLoader其实可以理解成一个高效率的迭代器(iterater),并被PyTorch很好的封装。他的使用非常简单:
from torch.utils.data import DataLoader
#训练数据的DataLoder,小批量数量64,打乱数据可选项为True
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
#测试数据的DataLoder,小批量数量64,打乱数据可选项为True
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
遍历 DataLoader
将该数据集加载到DataLoader后,可以根据需要对数据集进行迭代和遍历。下面的每次迭代都会返回一批train_特征和train_标签(分别包含batch_size=64个特征和标签)。因为已经指定了shuffle=True,所以在迭代所有批处理之后,数据会被打散。如需要更细粒度地控制数据加载顺序,可以使用Samplers)。
迭代和显示特征与标签(image and label)
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
train_data = datasets.MNIST(root="./MNIST",
train=True,
transform=transforms.ToTensor(),
download=False)
train_loader = DataLoader(dataset=train_data,
batch_size=64,
shuffle=True)
for num, (image, label) in enumerate(train_loader):
image_batch = torchvision.utils.make_grid(image, padding=2)
plt.imshow(np.transpose(image_batch.numpy(), (1, 2, 0)), vmin=0, vmax=255)
plt.show()
print(label)