mnist数据集是一个手写数字识别库,用于机器学习和深度学习的分类问题,同大多数标准化图像数据库一样,官网提供的文件并不是原始图像,而是经过数值化的二进制文件。比如:cifar10库的二进制文件解析出来后是一个字典,字典中又包含了代表图片特征的二维数组。cifar10库的解释详见:点击查看博客。将图像特征数字化保存到数组里面有助于 提高ML/DL框架的计算效率。
一般的,数据集的官网都会详细讲解数据集的结构,和解析方式。
同样的,mnist数据集也有结构,其解析方式官网已经解释的非常清楚了,详见官网。我们就根据其提供的结构使用python语言将其数据文件读入到numpy数组中。其实,很多框架(如caffe,tensorflow)在将mnist数据集作为入门案例时,读数据的思路大致相同,只不过代码的实现方式不同。
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
'''
使用python解析二进制文件
'''
import numpy as np
import struct
def loadImageSet(filename):
binfile = open(filename, 'rb') # 读取二进制文件
buffers = binfile.read()
head = struct.unpack_from('>IIII', buffers, 0) # 取前4个整数,返回一个元组
offset = struct.calcsize('>IIII') # 定位到data开始的位置
imgNum = head[1]
width = head[2]
height = head[3]
bits = imgNum * width * height # data一共有60000*28*28个像素值
bitsString = '>' + str(bits) + 'B' # fmt格式:'>47040000B'
imgs = struct.unpack_from(bitsString, buffers, offset) # 取data数据,返回一个元组
binfile.close()
imgs = np.reshape(imgs, [imgNum, width * height]) # reshape为[60000,784]型数组
return imgs,head
def loadLabelSet(filename):
binfile = open(filename, 'rb') # 读二进制文件
buffers = binfile.read()
head = struct.unpack_from('>II', buffers, 0) # 取label文件前2个整形数
labelNum = head[1]
offset = struct.calcsize('>II') # 定位到label数据开始的位置
numString = '>' + str(labelNum) + "B" # fmt格式:'>60000B'
labels = struct.unpack_from(numString, buffers, offset) # 取label数据
binfile.close()
labels = np.reshape(labels, [labelNum]) # 转型为列表(一维数组)
return labels,head
if __name__ == "__main__":
file1= 'E:/pythonProjects/dataSets/mnist/train-images.idx3-ubyte'
file2= 'E:/pythonProjects/dataSets/mnist/train-labels.idx1-ubyte'
imgs,data_head = loadImageSet(file1)
print('data_head:',data_head)
print(type(imgs))
print('imgs_array:',imgs)
print(np.reshape(imgs[1,:],[28,28])) #取出其中一张图片的像素,转型为28*28,大致就能从图像上看出是几啦
print('----------我是分割线-----------')
labels,labels_head = loadLabelSet(file2)
print('labels_head:',labels_head)
print(type(labels))
print(labels)