Numpy学习(3):将mnist数据文件读入到数据结构(numpy数组)中

前言:

mnist数据集是一个手写数字识别库,用于机器学习和深度学习的分类问题,同大多数标准化图像数据库一样,官网提供的文件并不是原始图像,而是经过数值化的二进制文件。比如:cifar10库的二进制文件解析出来后是一个字典,字典中又包含了代表图片特征的二维数组。cifar10库的解释详见:点击查看博客。将图像特征数字化保存到数组里面有助于 提高ML/DL框架的计算效率。


一般的,数据集的官网都会详细讲解数据集的结构,和解析方式。


同样的,mnist数据集也有结构,其解析方式官网已经解释的非常清楚了,详见官网。我们就根据其提供的结构使用python语言将其数据文件读入到numpy数组中。其实,很多框架(如caffe,tensorflow)在将mnist数据集作为入门案例时,读数据的思路大致相同,只不过代码的实现方式不同。


正文:

1,数据集的内容:
下载文件有4个,都是二进制文件
train-images-idx3-ubyte.gz:  training set images (9912422 bytes) 
train-labels-idx1-ubyte.gz:  training set labels (28881 bytes) 
t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes) 
t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)

这些二进制文件的格式是:IDX,官网做如下解释:

Numpy学习(3):将mnist数据文件读入到数据结构(numpy数组)中_第1张图片

说的很明显啦,一个数据文件固定分3部分,可以根据第2部分定义的维度,将第3部分的数据解析成相应的numpy高维数组。

2,训练集数据文件的结构( train-images-idx3-ubyte ),如下图:

Numpy学习(3):将mnist数据文件读入到数据结构(numpy数组)中_第2张图片

如上图所示,这个二进制文件的内容按照以‘位’为单位,前16字节(128位)是4个整形数字,每个数字占4个字节(32位),根据这4个数字的解释我们知道分别是:魔数(可以理解为文件id),图片数,图片纵向像素数,图片横向像素数。接下来,每一个字节(8位)代表一个像素值(0-255,8位二进制数正好可以表示0-255之间的十进制 数),然后每784个字节就构成了一张图片的全部像素点。

3,训练集标签文件的结构( train-labels-idx1-ubyte ),如下图:

Numpy学习(3):将mnist数据文件读入到数据结构(numpy数组)中_第3张图片

如上图所示,跟数据文件一样,标签文件具有相似的结构,60000张图像对应60000个标签。

4,python代码解析二进制文件
解析时用到了python的内置库:struct库,该库可以按照指定的格式将常用数据类型转换成二进制,也可以按指定格式从二进制中解析出相应信息。代码如下:
'''
    使用python解析二进制文件
'''
import numpy as np
import struct

def loadImageSet(filename):

    binfile = open(filename, 'rb') # 读取二进制文件
    buffers = binfile.read()

    head = struct.unpack_from('>IIII', buffers, 0) # 取前4个整数,返回一个元组

    offset = struct.calcsize('>IIII')  # 定位到data开始的位置
    imgNum = head[1]
    width = head[2]
    height = head[3]

    bits = imgNum * width * height  # data一共有60000*28*28个像素值
    bitsString = '>' + str(bits) + 'B'  # fmt格式:'>47040000B'

    imgs = struct.unpack_from(bitsString, buffers, offset) # 取data数据,返回一个元组

    binfile.close()
    imgs = np.reshape(imgs, [imgNum, width * height]) # reshape为[60000,784]型数组

    return imgs,head


def loadLabelSet(filename):

    binfile = open(filename, 'rb') # 读二进制文件
    buffers = binfile.read()

    head = struct.unpack_from('>II', buffers, 0) # 取label文件前2个整形数

    labelNum = head[1]
    offset = struct.calcsize('>II')  # 定位到label数据开始的位置

    numString = '>' + str(labelNum) + "B" # fmt格式:'>60000B'
    labels = struct.unpack_from(numString, buffers, offset) # 取label数据

    binfile.close()
    labels = np.reshape(labels, [labelNum]) # 转型为列表(一维数组)

    return labels,head


if __name__ == "__main__":
    file1= 'E:/pythonProjects/dataSets/mnist/train-images.idx3-ubyte'
    file2= 'E:/pythonProjects/dataSets/mnist/train-labels.idx1-ubyte'

    imgs,data_head = loadImageSet(file1)
    print('data_head:',data_head)
    print(type(imgs))
    print('imgs_array:',imgs)
    print(np.reshape(imgs[1,:],[28,28])) #取出其中一张图片的像素,转型为28*28,大致就能从图像上看出是几啦

    print('----------我是分割线-----------')

    labels,labels_head = loadLabelSet(file2)
    print('labels_head:',labels_head)
    print(type(labels))
    print(labels)


你可能感兴趣的:(Python)