pytorch实现图像分类任务-手写数字识别(一)

Pytorch手写数字识别

Minst数据集介绍

Size: 28×28 灰度手写数字图像
Num: 训练集 60000 和 测试集 10000,一共70000张图片
Classes: 0,1,2,3,4,5,6,7,8,9

pytorch实现图像分类任务-手写数字识别(一)_第1张图片
一共包含四个文件夹:

train-images-idx3-ubyte.gz:训练集图像(9912422 字节)55000张训练集 + 5000张验证集;
train-labels-idx1-ubyte.gz:训练集标签(28881 字节)训练集对应的标签;
t10k-images-idx3-ubyte.gz:测试集图像(1648877 字节)10000张测试集;
t10k-labels-idx1-ubyte.gz:测试集标签(4542 字节)测试集对应的标签;

读取数据集并进行绘制

PyTorch加载数据:Dataset和Dataloader

  • dataset:取数据以及其对应的label,并未其添加索引

    提供一种方式去获取数据及其label
    实现:
    (1)如何获取每一个数据及其label
    (2)告诉我们总共有多少的数据

  • dataloader:为后面的网络提供不同的数据形式

  1. datasets.MNIST是Pytorch的内置函数torchvision.datasets.MNIST,通过这个可以导入数据集。

  2. train=True 代表我们读入的数据作为训练集(如果为true则从training.pt创建数据集,否则从test.pt创建数据集download=True则是当我们的根目录(root)下没有数据集时,便自动下载。

  3. transform则是读入我们自己定义的数据预处理操作

  4. 如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:

下载测试集和训练集

from torchvision import datasets, transforms

train_data = datasets.MNIST(root="./MNIST", 
                            train=True, 
                            transform=transforms.ToTensor(), 
                            download=True)

test_data = datasets.MNIST(root="./MNIST", 
                           train=False, 
                           transform=transforms.ToTensor(), 
                           download=True)
print(train_data)
print(test_data)

pytorch实现图像分类任务-手写数字识别(一)_第2张图片

matplotlib进行数据可视化

DataLoader前置知识

使用 DataLoader 为训练准备数据

Dataset每次检索数据集中的一组特征和标签。然而,在训练模型时,通常希望

以“小批量”的形式传递样本,这种小批量集合了数个样本(比如16个、64个

等);并且在每个回合( epoch)都能打乱数据,以减少模型过拟合,并使用

Python的multiprocessing加速数据检索。在这种需求的引导下,DataLoader应

运而生。PyTorch提供了一个简单易用的API——DataLoader,能实现上述数据
进入训练的复杂过程的高度抽象化。

DataLoader其实可以理解成一个高效率的迭代器(iterater),并被PyTorch很好的封装。他的使用非常简单:

from torch.utils.data import DataLoader

#训练数据的DataLoder,小批量数量64,打乱数据可选项为True
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

#测试数据的DataLoder,小批量数量64,打乱数据可选项为True
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
遍历 DataLoader

将该数据集加载到DataLoader后,可以根据需要对数据集进行迭代和遍历。下面的每次迭代都会返回一批train_特征和train_标签(分别包含batch_size=64个特征和标签)。因为已经指定了shuffle=True,所以在迭代所有批处理之后,数据会被打散。如需要更细粒度地控制数据加载顺序,可以使用Samplers)。

迭代和显示特征与标签(image and label)

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

数据可视化具体实现

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

train_data = datasets.MNIST(root="./MNIST",
                            train=True,
                            transform=transforms.ToTensor(),
                            download=False)

train_loader = DataLoader(dataset=train_data,
                          batch_size=64,
                          shuffle=True)

for num, (image, label) in enumerate(train_loader):
    image_batch = torchvision.utils.make_grid(image, padding=2)
    plt.imshow(np.transpose(image_batch.numpy(), (1, 2, 0)), vmin=0, vmax=255)
    plt.show()
    print(label)

pytorch实现图像分类任务-手写数字识别(一)_第3张图片

pytorch实现图像分类任务-手写数字识别(一)_第4张图片
总结:通过循环遍历训练集的形式:可以查看数据集的图片形式和label张量的形式。

你可能感兴趣的:(人工智能算法,pytorch,分类,python)