本文是学习《动手学深度学习(pytorch)》“3.5 图像分类数据集(Fashion-MNIST)” 的笔记,具体解释请参考原文。
使用到的包主要是torchvision
,它主要由以下几部分构成:
torchvision.datasets
:一些加载数据的函数及常用的数据集接口;torchvision.models
:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;torchvision.transforms
:常用的图片变换,例如裁剪、旋转等;torchvision.utils
:其他的一些有用的方法。1、下载数据
当不使用transform=torchvision.transforms.ToTensor()
时,获取到的数据是尺寸为(H×W×C)且数据位于[0, 255]之间的 PIL 图像或者数据类型为 unit8 的 Numpy 数组。
该语句将上述类型的数据,转换为尺寸为(C×H×W)且数据类型为 torch.float32 且位于[0.0, 1.0]之间的 Tensor 。
import torchvision
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=torchvision.transforms.ToTensor())
2、获取数据标签
下载完数据后,需要能够找到数据对应的标签。以下函数可以将数值标签转成相应的文本标签。
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
torch.utils
中data
的一个方法DataLoader
能够很方便的读取 batch_size
大小的数据,三个常用的三个参数分别是dataset、batch_size、shuffle(是否不按顺序读取数据)
import torch.utils.data as Data
batch_size = 256
train_iter = Data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_iter = Data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True)
下面定义一个可以在一行里画出多张图像和对应标签的函数。
# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
# 这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()