在 pytorch 中加载和使用图像分类数据集 Fashion-MNIST

  • 参考:《动手学深度学习》(Pytorch)版 3.5 节
  • 注:本文是 jupyter notebook 文档转换而来,部分代码可能无法直接复制运行!

文章目录

  • 1. 获取数据集
  • 2. 读取小批量

  • 图像分类数据集中最常用的是手写数字识别数据集MNIST,但大部分模型在MNIST上的分类精度都超过了95%,为了更直观地观察算法之间的差异,本文介绍一个图像内容更加复杂的数据集 Fashion-MNIST,这个数据集难度比 MNIST 高,但是尺寸并不大,只有几十M,没有GPU的电脑也能吃得消
  • 该数据集可以利用 torchvision 包来下载和处理,该包包含以下几个核心模块
    1. torchvision.datasets: 提供加载数据的函数及常用数据集接口;
    2. torchvision.models: 包含常用的模型结构(含预训练模型),如 AlexNet、VGG、ResNet 等;
    3. torchvision.transforms: 提供常用的图片变换方法,例如裁剪、旋转等;
    4. torchvision.utils: 提供其他的一些有用的方法
  • 开始介绍前,先导入包
    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import time
    import numpy as np
    from IPython import display
    

1. 获取数据集

  • 通过 torchvision.datasets.FashionMNIST 方法获取数据集

    mnist_train = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=True, transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(root='./Datasets/FashionMNIST', train=False, transform=transforms.ToTensor())
    

    参数说明

    1. root 参数指定数据集保存路径

    2. train 参数指定获取训练集还是测试集

    3. download 参数若设置为 True,则在发现 root 路径下没有数据集时自动从网上下载,若已有数据集则不动作

    4. transform = transforms.ToTensor() 使所有数据转换为 Tensor,如果不转换则返回的是 PIL 图片

      transforms.ToTensor() 将 “尺寸为 H × W × C H \times W \times C H×W×C 且数据位于 [ 0 , 255 ] [0, 255] [0,255] 的PIL图片” 或者 “数据类型为 np.uint8 的NumPy数组” 转换为 “尺寸为 C × H × W C \times H \times W C×H×W 且数据类型为 torch.float32 且位于 [0.0, 1.0] 的Tensor”

      注意 transforms.ToTensor() 在内的一些关于图片的函数默认输入为 uint8 类型,如果不是则可能得到不想要的结果,所以如果用 [ 0 , 255 ] [0,255] [0,255] 的像素值表示图片数据,则一律将其类型设置为 uint8,以免不必要的bug

  • 这里加载的 mnist_trainmnist_test 都是 torch.utils.data.Dataset 的子类,一些常用方法如下

    print(type(mnist_train))
    print(len(mnist_train), len(mnist_test)) # 用 len() 获取该数据集的大小
    
    feature, label = mnist_train[0]          # 通过下标来访问任意样本
    print(feature.shape, label)              # [Channel , Height , Width] label,注意由于数据集中都是灰度图,通道数为 1
    
    '''
    torchvision.datasets.mnist.FashionMNIST
    60000 10000
    torch.Size([1, 28, 28]) 9
    '''
    
  • Fashion-MNIST中一共包括了10个类别,分别为

    1. t-shirt(T恤)
    2. trouser(裤子)
    3. pullover(套衫)
    4. dress(连衣裙)
    5. coat(外套)
    6. sandal(凉鞋)
    7. shirt(衬衫)
    8. sneaker(运动鞋)
    9. bag(包)
    10. ankle boot(短靴)

    使用以下函数将数值标签列表转成相应的文本标签列表

    def get_fashion_mnist_labels(labels):
        text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                       'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
        return [text_labels[int(i)] for i in labels]
    
  • 使用以下函数在一行里绘制多个图像和对应的标签

    def show_fashion_mnist(images, labels):
        display.set_matplotlib_formats('svg')  # Use svg format to display plot in jupyter
        
        _, figs = plt.subplots(1, len(images), figsize=(12, 12))
        for f, img, lbl in zip(figs, images, labels):
            f.imshow(img.view((28, 28)).numpy())
            f.set_title(lbl)
            f.axes.get_xaxis().set_visible(False)
            f.axes.get_yaxis().set_visible(False)
        plt.show()
    
  • 随机显示 10 个样本

    X, y = [], []
    for i in np.random.randint(0,60000,size = 10).tolist():
        X.append(mnist_train[i][0])
        y.append(mnist_train[i][1])
    show_fashion_mnist(X, get_fashion_mnist_labels(y))
    

    这里我遇到一个报错,请参考 ‘OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program’,我删除了虚拟环境中的 libiomp5md.dll 解决此问题

在这里插入图片描述

2. 读取小批量

  • 在实践中,数据读取经常是训练的性能瓶颈,torch.utils 模块提供的 DataLoader 方法允许我们方便地使用多进程来加速数据读取

  • mnist_traintorch.utils.data.Dataset 的子类,所以我们可以将其传入 torch.utils.data.DataLoader 来创建一个读取小批量数据样本的DataLoader 实例,在创建时

    1. 通过参数 num_workers 来指定读取数据的进程数量
    2. 通过 shuffle 参数指定读取时是否打乱
    batch_size = 256
    if sys.platform.startswith('win'): # 判断操作系统为 windows
        num_workers = 4 # 使用 4 个进程同时读取
    else:
        num_workers = 0 # 0表示不用额外的进程来加速读取数据
    
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
  • 查看读取一遍数据的耗时

    start = time.time()
    for X, y in train_iter:
        continue
    print('%.2f sec' % (time.time() - start))
    

    经测试,我的笔记本电脑在不使用多进程加速时耗时 5.88s,使用后减少到 3.18s

你可能感兴趣的:(#,PyTorch,数据集,Fashion-MNIST,pytorch)