【深度学习】Fashion-MNIST数据集简介

文章目录

    • 数据集简介
    • 操作
      • 下载数据集
      • 数据格式
      • 可视化显示
      • 读取小批量
    • 完整代码
    • 备注

数据集简介

不同于MNIST手写数据集,Fashion-MNIST数据集包含了10个类别的图像,分别是:t-shirt(T恤),trouser(牛仔裤),pullover(套衫),dress(裙子),coat(外套),sandal(凉鞋),shirt(衬衫),sneaker(运动鞋),bag(包),ankle boot(短靴)。

操作

下载数据集

这里通过mxnet中gluon的data包来下载这个数据集,第一次调用时将自动从网上下载获取数据。默认的保存位置为:C:\Users\Username\.mxnet\datasets\fashion-mnist\目录(Linux系统估计也在主目录下的类似目录中),共包含四个文件,分别为:

  • 训练数据图片train-images-idx3-ubyte
  • 训练数据标签train-labels-idx1-ubyte
  • 测试数据图片t10k-images-idx3-ubyte
  • 测试数据标签t10k-labels-idx1-ubyte

因为是国外的网站,没有代理的话下载速度很慢,所以可以通过手动下载并并将其移动至上述目录:

  1. 点击程序运行时控制台消息给出的链接进行下载;
  2. 点击下载我上传至CSDN的副本(CSDN资源至少需要1积分,没办法啦各位筒子)。

train指定获取训练数据集还是测试数据集。

from mxnet.gluon import data as gdata

# 通过gluon的data包下载数据集
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)

查阅mxnet的官方文档,可以看到vision中预定义了以下几种数据集:
【深度学习】Fashion-MNIST数据集简介_第1张图片

数据格式

训练数据集每个类别含有6000个样本,测试数据集每个类别含有1000个样本,一共有10个类别,故而训练数据集共60000个样本,测试数据集共10000个样本。

图像是一个28*28的像素数组,每个像素的值为0~255之间的8位无符号整数(uint8),使用三维NDArray存储,最后一维表示通道个数。由于为灰度图像,故通道数为1。

# 10个类别,训练集和测试集每个类别分别6000样本和1000样本
print(len(mnist_train), len(mnist_test))
# 数据集形状
feature, lable = mnist_train[0]
print(feature.shape, feature.dtype)
print(lable, lable.dtype)

运行结果如下:
在这里插入图片描述

可视化显示

首先定义一个函数,根据数值标签获取字符串标签:

# 将数据集的数值标签转换为文本标签, 参数是一个list
def get_fasion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

定义函数,在一行中绘制数据代表的可视化结果:

import gluonbook as gb
import pylab

# 同一行画出图片和标签
def show_fashion_mnist(images, labels):
    gb.use_svg_display()
    # _表示忽略不使用的变量
    _, figs = gb.plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.reshape((28, 28)).asnumpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_xaxis().set_visible(False)
    pylab.show()    # PyCharm 2018中绘图


# 画图示例
X, y = mnist_train[0:9]
show_fashion_mnist(X, get_fasion_mnist_labels(y))

结果如下:
【深度学习】Fashion-MNIST数据集简介_第2张图片

读取小批量

使用gluon中的DataLoader来读取数据,它可以允许使用多进程来加速数据的读取(windows暂时不支持)。下面通过num_workers设置读取数据的进程数。

ToTensor类将图像数据从uint8转换为32位浮点数格式,并且除以255使得所有的像素值分布于[0,1]。ToTensor类还将图像通道从最后一维移到最前面一维上来,这将方便在卷积神经网络中进行计算。通过数据集的transform_first函数,我们将ToTensor的变换应用在每个数据样本(图像和标签)的第一个元素,即图像之上。

import sys
import time

# 读取小批量数据
batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
    # 0表示不使用额外的进程读取数据
    num_workers = 0
else:
    num_workers = 4
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size, shuffle=True, num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size, shuffle=False, num_workers=num_workers)

start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

我电脑上的数据读取时间:
在这里插入图片描述
改变batch_size的值,看看数据的读取性能会有何影响?

完整代码

# coding=utf-8
# author: BebDong
# 2018/12/18

from mxnet.gluon import data as gdata
import gluonbook as gb
import pylab
import sys
import time


# 将数据集的数值标签转换为文本标签
def get_fasion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


# 同一行画出图片和标签
def show_fashion_mnist(images, labels):
    gb.use_svg_display()
    # _表示忽略不使用的变量
    _, figs = gb.plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.reshape((28, 28)).asnumpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_xaxis().set_visible(False)
    pylab.show()


# 通过gluon的data包下载数据集
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)

# 10个类别,训练集和测试集每个类别分别6000样本和1000样本
print(len(mnist_train), len(mnist_test))
# 数据集形状
feature, lable = mnist_train[0]
print(feature.shape, feature.dtype)
print(lable, lable.dtype)
print(get_fasion_mnist_labels([lable]))

# 画图示例
X, y = mnist_train[0:9]
# show_fashion_mnist(X, get_fasion_mnist_labels(y))

# 读取小批量数据
batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
    # 0表示不使用额外的进程读取数据
    num_workers = 0
else:
    num_workers = 4
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size, shuffle=True, num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size, shuffle=False, num_workers=num_workers)

start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

备注

本文为《动手学深度学习》的学习笔记,原书链接:http://zh.diveintodeeplearning.org/chapter_deep-learning-basics/fashion-mnist.html

代码中的gluonbook包是这本书籍封装的工具包,它把书中描述的将来会复用的所有方法进行了封装。如果您并未按照书中给出的配置文件搭建环境,那可以点击这里单独下载gluonbook包,并将其移动至恰当的位置。

你可能感兴趣的:(人工智能)