动手学深度学习 图像分类数据系列:
Fashion-MNIST在书中多次使用,本文的内容是讲解如何获取并查看此数据集
使用torchvision.datasets
来下载数据集
root
用来指定下载后保存的位置(如果已经存在则不会下载)download
表示是否要下载train
表示获取训练数据集或测试数据集transform
代表对图像的操作, 这里仅仅使用了ToTensor()
把图像数据转换为Tensor
类型更多transform的操作可以点击这篇文章来查看
书本原话:
注意:由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括
transforms.ToTensor() 在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错
但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成
uint8,避免不必要的bug。
import torchvision
import torchvision.transforms as transforms
mnist_train = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=True, download=True,
transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root=r'D:\Source\Datasets\FashionMNIST', train=False, download=True,
transform=transforms.ToTensor())
对训练集切片查看一下数据类型和标签类型
这里的标签已经转换为数值型数据来存储
所以我们可以编写一个函数将其转换为 图像数据集原本对应的标签
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
先提取出其中的一张图片与标签来查看
img, label = mnist_train[0]
title = get_fashion_mnist_labels([label])[0] # 获取标签
plt.imshow(img.view((28,28)).numpy()) # 数据格式转换
plt.title(title) # 设置标题
plt.savefig('test.jpg') # 存储图片
import matplotlib.pyplot as plt
def show_fashion_mnist(images, labels):
# 这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
X, y = [], []
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
plt.show()
使用DataLoader
它可以允许多线程来加速数据读取
具体的可以看下面链接中的文章,有对DataLoader
和Dataset
的详细介绍
Pytorch 快速详解如何构建自己的Dataset完成数据预处理(附详细过程)
from torch.utils.data import DataLoader
import sys
batch_size = 256
if sys.platform.startswith('win'):
# 0表示不用额外的进程来加速读取数据
num_workers = 0
else:
num_workers = 4
train_iter = DataLoader(mnist_train,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
test_iter = DataLoader(mnist_test,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers)
DataLoader
是个可遍历的对象
start = time()
for X, y in train_iter:
continue
print('%.2f sec' % (time() - start))
可以通过上述代码来查看读取一遍训练集需要的时间
本文内容来自吴振宇博士的Github项目
对中文版《动手学深度学习》中的代码进行整理,并用Pytorch实现
【深度学习】李沐《动手学深度学习》的PyTorch实现已完成