mnist torch加载fashion_MNIST数据集详解及可视化处理(pytorch)

mnist torch加载fashion_MNIST数据集详解及可视化处理(pytorch)_第1张图片

MNIST 数据集已经是一个被”嚼烂”了的数据集, 作为机器学习在视觉领域的“hello world”,很多教程都会对它”下手”, 几乎成为一个 “典范”。 不过有些人可能对它还不是很了解, 下面来介绍一下。

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。

mnist torch加载fashion_MNIST数据集详解及可视化处理(pytorch)_第2张图片

在使用pytorch进行学习时,可以使用pytorch的处理图像视频的torchvision工具集直接下载MNIST的训练和测试图片,torchvision包含了一些常用的数据集、模型和转换函数等等,比如图片分类、语义切分、目标识别、实例分割、关键点检测、视频分类等工具。

from torchvision import datasets, transforms

#下载测试集
train_dataset = datasets.MNIST('./data', train=True, 
                                transfrom=transforms.ToTensor(), 
                                download=True)
test_dataset =  datasets.MNIST('./data', train=False, 
                                transform=transforms.ToTensor(),
                                download=True)

下载完成后的数据集如下图所示。

数据集的图片是以字节的形式进行存储,在我们进行训练和测试时可以直接使用 torch.utils.data.DataLoader 进行加载。

mnist torch加载fashion_MNIST数据集详解及可视化处理(pytorch)_第3张图片

虽然下载下来的数据集文件,其具体的存储格式我们暂时不用太过关心,但如何才能将这部分数据转换为可见的图片形式呢?我们可以利用pytorch自带的工具进行文件读取,并提取数据保存为可打开的jpg文件和txt文件。

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="D:/MNIST/data/MNIST/raw"
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
        )
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
        )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
    if(train):
        f=open(root+'train.txt','w')
        data_path=root+'/train/'
        if(not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
            img_path=data_path+str(i)+'.jpg'
            io.imsave(img_path,img.numpy())
            f.write(img_path+' '+str(label)+'n')
        f.close()
    else:
        f = open(root + 'test.txt', 'w')
        data_path = root + '/test/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
            img_path = data_path+ str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            f.write(img_path + ' ' + str(label) + 'n')
        f.close()

convert_to_img(True)#转换训练集
convert_to_img(False)#转换测试集

等待转换完成后,可以在MNIST训练集和测试集所在的文件夹内出现train和test两个文件夹。

mnist torch加载fashion_MNIST数据集详解及可视化处理(pytorch)_第4张图片

里面是已经转换完成的jpg格式的60000张训练数据和10000张测试数据。在test和train文件夹的上一层raw文件夹中也产生了相对应的两个txt文件,里面是每张jpg格式的图片所标注的数字。

mnist torch加载fashion_MNIST数据集详解及可视化处理(pytorch)_第5张图片

此时我们已经能看到所有的手写字图片和相对应的标签了,下一步可以使用pytorch来实现LeNet-5网络,利用训练数据集对卷积神经网络进行训练,训练完成后可以使用测试集对网络进行测试,以检验网络的训练结果。

当然,也可以自行任意选择图片输送到网络中进行测试,或者根据网络要求,自行绘制数字然后处理成网络要求的格式进行测试。

注:MNIST的网络处理要求28x28x1的图片输入,我们在数据集还原出来的jpg文件为28x28x3,因此在任选图片进行预测时,需要先将图片进行处理。

你可能感兴趣的:(mnist,torch加载fashion)