MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。我们将使用类似但更复杂的Fashion-MNIST数据集
%matplotlib inline
import torch
import torchvision
from torch.utils import data
#transforms 对数据操作的包
from torchvision import transforms
#存在d2l
from d2l import torch as d2l
d2l.use_svg_display()
通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式
# 并处以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="./data/FashionMNIST", train=True,transform=trans,download = False)
mnist_test = torchvision.datasets.FashionMNIST(root="./data/FashionMNIST", train=False,transform=trans,download = False)
len(mnist_train), len(mnist_test)
(60000, 10000)
mnist_train[0][0].shape
torch.Size([1, 28, 28])
def get_fashion_mnist_labels(labels):
"""返回Fashion-MNIST数据集的文本标签。"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
#这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images),figsize=(12,12))
for f,img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axus.get_yaxis().set_visible(False)
plt.show()
def show_image(imgs, num_row=1, num_col=1, titles=None, scale=1.5):
# 设置图片大小
figsize = (num_col * scale, num_row * scale)
# 这里的 _ 表示忽略不使用的变量、即fig
_, axes = d2l.plt.subplots(num_row, num_col, figsize=figsize)
for i, (img, label) in enumerate(zip(imgs, titles)):
# 计算图片的位置、需要用到整除和除余
xloc, yloc = i // num_col, i % num_col
# 判断传入的图片是否为张量
if torch.is_tensor(img):
axes[xloc, yloc].imshow(img.reshape((28, 28)).numpy())
else:
axes[xloc, yloc].imshow(img)
# 设置标题并取消横纵坐标上的刻度
axes[xloc, yloc].set_title(label)
axes[xloc, yloc].set(xticks=[], yticks=[])
batch_size = 18
# data.DataLoader( )函数的作用在于根据传入的数据集和批量大小来返回小批量数据集
X, y = next(iter(data.DataLoader(mnist_train, batch_size=batch_size)))
show_image(X.reshape(batch_size, 28, 28), num_row=2, num_col=9,
titles=get_fashion_mnist_labels(y))
batch_size = 256
def get_dataloader_workers():
"""使用4个进程来读取的数据。"""
return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())
timer = d2l.Timer()
for X,y in train_iter:
continue
f'{timer.stop():.2f} sec'