mnist torch加载fashion_PyTorch - 02 - fashion-MNIST数据集的使用

FashionMNIST数据集

Fashion-MNIST是一个10类服饰分类数据集, 我们可以使用它来检验不同算法的表现, 这是MNIST数据集不能做到的(原因在这里,想了解的可以看看介绍)。

torchvision的结构

torchvision包包含了很多图像相关的数据集以及处理方法, 并且有常用的模型结构。

torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

torchvision.datasets: 一些加载数据的函数及常用的数据集接口;

torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;

torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;

torchvision.utils: 其他的一些有用的方法。

# 导入需要的包

import torch

import torchvision

import torchvision.transforms as transforms

from torchvision.datasets import FashionMNIST

import matplotlib.pyplot as plt

加载数据

设置数据的缓存目录为 root_dir

随后获得训练集和测试集数据,第一次运行的时候会下载 FashionMNIST 数据集到指定的目录下

将Fashion-MNIST/ data / fashion的四个压缩文件解压到指定的目录,不要删除原来的压缩包文件,因此数据集总共有八个文件。

# 通过标签得到描述语句

def get_f_mnist_labels(labels):

"""

:param labels: 图片对应的标签(0-9的数字)

:return: 标签对应的描述

"""

text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

return [text_labels[int(i)] for i in labels]

def show_fashion_mnist(images, labels):

"""

:param images: 读取的图片

:param labels: 图片对应的标签

:return: None, 输出图片,并且在图片上方对应标签给出描述

"""

_, figs = plt.subplots(1, len(images), figsize=(12, 2))

for f, img, lbl in zip(figs, images, labels):

f.imshow(img.view((28, 28)))

f.set_title(lbl)

f.axes.get_xaxis().set_visible(False)

f.axes.get_yaxis().set_visible(False)

plt.show()

root_dir = "./torchvision/data/"

f_mnist_train = FashionMNIST(root=root_dir, train=True, download=True, transform=transforms.ToTensor())

f_mnist_test = FashionMNIST(root=root_dir, train=False, download=True, transform=transforms.ToTensor())

print("f_mnist_train length:", len(f_mnist_train), end='\n')

print("f_mnist_test length:", len(f_mnist_test), end='\n')

x, y = [], []

for i in range(10):

x.append(f_mnist_train[i][0])

y.append(f_mnist_train[i][1])

show_fashion_mnist(x, get_f_mnist_labels(y))

f_mnist_train length: 60000

f_mnist_test length: 10000

读取小批量数据

from torch.utils.data import DataLoader

batch_size = 256

train_iter = DataLoader(f_mnist_train, batch_size, shuffle=True, num_workers = 0)

# 计算加载数据的时间

import time

start = time.time()

for X, y in train_iter:

continue

print("read train data cost %.4f seconds" % (time.time()-start))

read train data cost 4.9213 seconds

注意

本章的介绍思路来自 Apple Store的 “Python AI” app, 作为学习目的使用, 以及在此文章中记录学习过程(如有侵权,请联系作者删除。)

你可能感兴趣的:(mnist,torch加载fashion)