加载数据集时,由于网速实在太慢,直接使用书中代码记载不出来,可以先将Fashion-MNIST数据集下载下来,使用这两行代码,前边root是存放数据集的路径
mnist_train = gdata.vision.FashionMNIST(root=r'C:/Users/Wu/AppData/Roaming/mxnet/datasets/fashion-mnist/',train=True)
mnist_test = gdata.vision.FashionMNIST(root=r'C:/Users/Wu/AppData/Roaming/mxnet/datasets/fashion-mnist/',train=False)
这里注意,下载一定要下载官网的数据集,不然验证不通过,具体下载方法:使用书中代码运行后,pycharm的运行界面会显示下载网址,将网址在浏览器打开即可。
完整代码及注释如下
import d2lzh as d2l
import matplotlib.pyplot as plt
from mxnet.gluon import data as gdata
import sys
import time
#下载FashionMNIST数据集
mnist_train = gdata.vision.FashionMNIST(root=r'C:/Users/Wu/AppData/Roaming/mxnet/datasets/fashion-mnist/',train=True)
mnist_test = gdata.vision.FashionMNIST(root=r'C:/Users/Wu/AppData/Roaming/mxnet/datasets/fashion-mnist/',train=False)
print(mnist_train,len(mnist_train))
feature, label = mnist_train[0:9]
print(feature.shape,feature.dtype)
print(label,type(label),label.dtype)
#数值标签转文本
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
print(get_fashion_mnist_labels(label))
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
#_表示不使用的变量,1行len(images)个大小12*12的子图
_, figs = d2l.plt.subplots(1, len(images), figsize=(12, 12))
#zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
#a = [1,2,3],b = [4,5,6],zipped = zip(a,b) 后[(1, 4), (2, 5), (3, 6)]
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.reshape((28, 28)).asnumpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()#必须加这一句才会显示图像
X, y = mnist_train[0:9]
show_fashion_mnist(X, get_fashion_mnist_labels(y))
#小批量读取,使用DataLoader
# 通过ToTensor实例将图像数据从uint8格式变换成32位浮点数格式,并除以255使得所有像素的数值均在0到1之间。
# 同时,ToTensor实例还将图像通道从最后一维移到最前一维。通过数据集的transform_first函数,
# 将ToTensor的变换应用在每个数据样本(图像和标签)的第一个元素,即图像之上
batch_size = 256
transformer = gdata.vision.transforms.ToTensor()
if sys.platform.startswith('win'):
num_workers = 0 #表示不用额外进程来加速读取数据
else:
num_workers = 4
train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size, shuffle=True,num_workers=num_workers)
test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size, shuffle=False,num_workers=num_workers)
#读取一遍数据花的时间
start = time.time()
for X, y in train_iter:
continue
print('%.2f sec' % (time.time() - start))