PyTorch—— 图像分类数据集(Fashion-MNIST)

PyTorch—— 图像分类数据集(Fashion-MNIST)

    • 0、前言
    • 一、获取数据集
    • 二、读取小批量数据
    • 三、画图

本文是学习《动手学深度学习(pytorch)》“3.5 图像分类数据集(Fashion-MNIST)” 的笔记,具体解释请参考原文。

0、前言

使用到的包主要是torchvision,它主要由以下几部分构成:

  1. torchvision.datasets:一些加载数据的函数及常用的数据集接口;
  2. torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  3. torchvision.transforms:常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils:其他的一些有用的方法。

一、获取数据集

1、下载数据
当不使用transform=torchvision.transforms.ToTensor()时,获取到的数据是尺寸为(H×W×C)且数据位于[0, 255]之间的 PIL 图像或者数据类型为 unit8 的 Numpy 数组

该语句将上述类型的数据,转换为尺寸为(C×H×W)且数据类型为 torch.float32 且位于[0.0, 1.0]之间的 Tensor

import torchvision

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=torchvision.transforms.ToTensor())

2、获取数据标签

下载完数据后,需要能够找到数据对应的标签。以下函数可以将数值标签转成相应的文本标签。

def get_fashion_mnist_labels(labels):
	text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
	return [text_labels[int(i)] for i in labels]

二、读取小批量数据

torch.utilsdata的一个方法DataLoader能够很方便的读取 batch_size 大小的数据,三个常用的三个参数分别是dataset、batch_size、shuffle(是否不按顺序读取数据)

import torch.utils.data as Data

batch_size = 256
train_iter = Data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_iter = Data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True)

三、画图

下面定义一个可以在一行里画出多张图像和对应标签的函数。

# 本函数已保存在d2lzh包中方便以后使用
def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

你可能感兴趣的:(PyTorch,学习历程)