动手学深度学习——图像分类数据集Fashion-MNIST

Fashion-MNIST是⼀个10类服饰分类数据集。 

torchvision 包:它是服务于 PyTorch 深度学习框架的,主要⽤来构建计算机视觉模型。
torchvision 主要由以下⼏部分构成:
  •  torchvision.datasets : ⼀些加载数据的函数及常⽤的数据集接⼝;
  •  torchvision.models : 包含常⽤的模型结构(含预训练模型),例如AlexNetVGG、 ResNet等;
  •  torchvision.transforms : 常⽤的图⽚变换,例如裁剪、旋转等;
  •  torchvision.utils : 其他的⼀些有⽤的⽅法。
1、加载常用包
import os
import matplotlib.pyplot as plt
# 其中matplotlib包可用于作图,且设置成嵌入式
import torch
import torchvision #torchvision这个库 它是一Pytorch对于个计算机识别一些模型实现的一个库。
import torchvision.transforms as transforms #对数据进行操作的一个模具。
import matplotlib.pyplot as plt
import time
import sys
from IPython import display

2、获取数据集

"""
通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中
通过ToTensor实例将图像数据从PLL类型变换成32位浮点数格式,最简单的一个预处理 transform=transforms.ToTensor()
"""

# 训练数据集
#从torchvision中的datasets中将Fashion-MNIST数据集拿到;root是目录;train=True表示下载的是训练数据集;download=True表示确定从网上下载;
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',
                                                train=True, download=True,
                                                transform=transforms.ToTensor())

# 测试数据集
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',
                                               train=False, download=True,
                                               transform=transforms.ToTensor())
# 上⾯的 mnist_train 和 mnist_test 都是 torch.utils.data.Dataset 的⼦类,所以我们可以⽤ len() 来获取该数据集的⼤⼩,还可以⽤下标来获取具体的⼀个样本。
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))


# 我们可以通过下标来访问任意一个样本
feature, label = mnist_train[0]
print(feature.shape, label)
# 变量feature对应的高和宽均为28像素的图像,输出显示的第一维是通道数,因为数据集是灰度图像,所以通道数为1,后面两维分别是图像的宽和高。

动手学深度学习——图像分类数据集Fashion-MNIST_第1张图片

3、 输出训练集中的10个样本的图像内容和文本标签

"""
Fashion-MNIST中⼀共包括了10个类别,分别为t-shirt(T恤)、trouser(裤⼦)、pullover(套衫)、
dress(连⾐裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、
bag(包)和ankle boot(短靴)。以下函数可以将数值标签转成相应的⽂本标签。
"""
def get_fashion_mnist_labels(labels):
    text_lables = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_lables[int(i)] for i in labels]
#定义绘图函数
def use_svg_display():
    # 用矢量图显示
    display.set_matplotlib_formats('svg')

def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    # 设置图的尺寸
    plt.rcParams['figure.figsize'] = figsize
# 下⾯定义⼀个可以在⼀⾏⾥画出多张图像和对应标签的函数。
def show_fashion_mnist(images, labels):
    use_svg_display()
    # 这⾥的_表示我们忽略(不使⽤)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
# 输出训练集中的10个样本的图像内容和文本标签
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

动手学深度学习——图像分类数据集Fashion-MNIST_第2张图片

4、读取小批量

#num_work来设置4个进程读取数据
batch_size=256
if sys.platform.startswith('win'):
    num_workers=0 #0表示不用额外的进程来加速读取数据
else:
    num_workers=4
train_iter=torch.utils.data.DataLoader(mnist_train,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers)
test_iter=torch.utils.data.DataLoader(mnist_test,
                                      batch_size=batch_size,
                                      shuffle=True,#随机
                                      num_workers=num_workers) #短进程
#最后查看读取一遍训练数据需要的时间
start=time.time()
for X,y in train_iter:
    continue
print('%.2f sec'% (time.time()- start))

 动手学深度学习——图像分类数据集Fashion-MNIST_第3张图片

 

你可能感兴趣的:(动手学深度学习,线性代数,深度学习,人工智能,pytorch)