深度学习图像分类数据集Fashion-MNIST

加载数据集时,由于网速实在太慢,直接使用书中代码记载不出来,可以先将Fashion-MNIST数据集下载下来,使用这两行代码,前边root是存放数据集的路径

mnist_train = gdata.vision.FashionMNIST(root=r'C:/Users/Wu/AppData/Roaming/mxnet/datasets/fashion-mnist/',train=True)
mnist_test = gdata.vision.FashionMNIST(root=r'C:/Users/Wu/AppData/Roaming/mxnet/datasets/fashion-mnist/',train=False)

这里注意,下载一定要下载官网的数据集,不然验证不通过,具体下载方法:使用书中代码运行后,pycharm的运行界面会显示下载网址,将网址在浏览器打开即可。

完整代码及注释如下

import d2lzh as d2l
import matplotlib.pyplot as plt
from mxnet.gluon import data as gdata
import sys
import time

#下载FashionMNIST数据集
mnist_train = gdata.vision.FashionMNIST(root=r'C:/Users/Wu/AppData/Roaming/mxnet/datasets/fashion-mnist/',train=True)
mnist_test = gdata.vision.FashionMNIST(root=r'C:/Users/Wu/AppData/Roaming/mxnet/datasets/fashion-mnist/',train=False)

print(mnist_train,len(mnist_train))

feature, label = mnist_train[0:9]
print(feature.shape,feature.dtype)
print(label,type(label),label.dtype)

#数值标签转文本
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

print(get_fashion_mnist_labels(label))

def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    #_表示不使用的变量,1行len(images)个大小12*12的子图
    _, figs = d2l.plt.subplots(1, len(images), figsize=(12, 12))
    #zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
    #a = [1,2,3],b = [4,5,6],zipped = zip(a,b) 后[(1, 4), (2, 5), (3, 6)]
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.reshape((28, 28)).asnumpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()#必须加这一句才会显示图像

X, y = mnist_train[0:9]
show_fashion_mnist(X, get_fashion_mnist_labels(y))

#小批量读取,使用DataLoader
# 通过ToTensor实例将图像数据从uint8格式变换成32位浮点数格式,并除以255使得所有像素的数值均在0到1之间。
# 同时,ToTensor实例还将图像通道从最后一维移到最前一维。通过数据集的transform_first函数,
# 将ToTensor的变换应用在每个数据样本(图像和标签)的第一个元素,即图像之上
batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
    num_workers = 0 #表示不用额外进程来加速读取数据
else:
    num_workers = 4

train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size, shuffle=True,num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size, shuffle=False,num_workers=num_workers)

#读取一遍数据花的时间
start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

 

你可能感兴趣的:(深度学习图像分类数据集Fashion-MNIST)