MNIST 网站
http://yann.lecun.com/exdb/mnist/
四个文件
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
下载后解压:
$ gunzip *.gz
t10k-images-idx3-ubyte
train-images-idx3-ubyte
t10k-labels-idx1-ubyte
train-labels-idx1-ubyte
解压后会生成上面的四个文件
文件的格式
There are 4 files:
train-images-idx3-ubyte: training set images
train-labels-idx1-ubyte: training set labels
t10k-images-idx3-ubyte: test set images
t10k-labels-idx1-ubyte: test set labels
The training set contains 60000 examples, and the test set 10000 examples.
The first 5000 examples of the test set are taken from the original NIST training set. The last 5000 are taken from the original NIST test set. The first 5000 are cleaner and easier than the last 5000.
TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The labels values are 0 to 9.
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 10000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The labels values are 0 to 9.
TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 10000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
图像文件的前16个字节是文件头,依次为:4个字节的幻数(magic number),4个字节的图像数量,4个字节的单张图像行数,4个字节的单张图像列数。这4个字段都是大端序(MSB first)的32位无符号整数,所以下面的代码用 struct 格式串 ">4I" 来解析。
标记文件的前8个字节是文件头,依次为:4个字节的幻数,4个字节的标记数量。它们同样是大端序的32位无符号整数,用 ">2I" 解析;文件头之后每个字节就是一个标记。
下面读取并打印这四个文件的文件头:
from __future__ import division
from __future__ import print_function
#gunzip *.gz
#http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct
# Decompressed MNIST file names, addressed by position elsewhere:
#   [0] training images  [1] training labels  [2] test images  [3] test labels
file_list = """
train-images-idx3-ubyte
train-labels-idx1-ubyte
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte
""".split()
def create_path(path):
    """Create directory *path*, including missing parents, if it does not exist.

    Safe to call repeatedly. Raises OSError if *path* exists but is not a
    directory, or if it cannot be created.
    """
    try:
        os.makedirs(path)
    except OSError:
        # EAFP instead of isdir() followed by makedirs(): the directory may
        # appear between the check and the call. exist_ok= is Python 3 only,
        # and this file targets Python 2 as well.
        if not os.path.isdir(path):
            raise
def get_file_full_name(path, name):
    """Ensure directory *path* exists, then return the path of *name* inside it.

    A "/" separator is inserted only when *path* does not already end with one.
    """
    create_path(path)
    separator = "" if path.endswith("/") else "/"
    return path + separator + name
def read_mnist(file_name, file_path="/home/your/data/path"):
    """Open MNIST file *file_name* located in directory *file_path* for binary reading.

    *file_path* defaults to this tutorial's placeholder directory; pass the
    real data directory instead of editing the function. The caller owns the
    returned file object and must close it.
    """
    full_path = get_file_full_name(file_path, file_name)
    # 'rb' is required on Python 3 (to get bytes). It is also the correct mode
    # on Python 2, where plain 'r' would translate line endings on Windows.
    return open(full_path, 'rb')
def get_file_header_data(file_name, header_len, unpack_str):
    """Return the header of MNIST file *file_name*, unpacked with *unpack_str*.

    Opens the file via read_mnist and reads its first *header_len* bytes.
    The file is closed before unpacking. struct.error is raised if the file
    is shorter than the format requires.
    """
    with read_mnist(file_name) as f:
        raw_header = f.read(header_len)
    return struct.unpack(unpack_str, raw_header)
def show_images_file_header(file_name):
    """Print an image file's 16-byte header: (magic, count, rows, cols) as big-endian uint32."""
    show_file_header(file_name, header_len=16, unpack_str=">4I")
def show_labels_file_header(file_name):
    """Print a label file's 8-byte header: (magic, count) as big-endian uint32."""
    show_file_header(file_name, header_len=8, unpack_str=">2I")
def show_file_header(file_name, header_len, unpack_str):
    """Parse the header of *file_name* and print it as '<name> header data:<tuple>'."""
    print("%s header data:%s"
          % (file_name, get_file_header_data(file_name, header_len, unpack_str)))
def show_mnist_file_header():
    """Print the headers of all four MNIST files: both image files first, then both label files."""
    for image_name in (file_list[0], file_list[2]):  # training, test images
        show_images_file_header(image_name)
    for label_name in (file_list[1], file_list[3]):  # training, test labels
        show_labels_file_header(label_name)
def run():
    """Entry point of this listing: print the headers of the four MNIST files."""
    show_mnist_file_header()


# Executes immediately when this listing is run.
run()
输出
train-images-idx3-ubyte header data:(2051, 60000, 28, 28)
t10k-images-idx3-ubyte header data:(2051, 10000, 28, 28)
train-labels-idx1-ubyte header data:(2049, 60000)
t10k-labels-idx1-ubyte header data:(2049, 10000)
下面我们读取一张图片,显示这张图片并打印它的标记:
from __future__ import division
from __future__ import print_function
#gunzip *.gz
#http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# Decompressed MNIST file names, addressed by position elsewhere:
#   [0] training images  [1] training labels  [2] test images  [3] test labels
file_list = """
train-images-idx3-ubyte
train-labels-idx1-ubyte
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte
""".split()
def create_path(path):
    """Create directory *path*, including missing parents, if it does not exist.

    Safe to call repeatedly. Raises OSError if *path* exists but is not a
    directory, or if it cannot be created.
    """
    try:
        os.makedirs(path)
    except OSError:
        # EAFP instead of isdir() followed by makedirs(): the directory may
        # appear between the check and the call. exist_ok= is Python 3 only,
        # and this file targets Python 2 as well.
        if not os.path.isdir(path):
            raise
def get_file_full_name(path, name):
    """Ensure directory *path* exists, then return the path of *name* inside it.

    A "/" separator is inserted only when *path* does not already end with one.
    """
    create_path(path)
    separator = "" if path.endswith("/") else "/"
    return path + separator + name
def read_mnist(file_name, file_path="/home/your/data/path"):
    """Open MNIST file *file_name* located in directory *file_path* for binary reading.

    *file_path* defaults to this tutorial's placeholder directory; pass the
    real data directory instead of editing the function. The caller owns the
    returned file object and must close it.
    """
    full_path = get_file_full_name(file_path, file_name)
    # 'rb' is required on Python 3 (to get bytes). It is also the correct mode
    # on Python 2, where plain 'r' would translate line endings on Windows.
    return open(full_path, 'rb')
def get_file_header_data(file_obj, header_len, unpack_str):
    """Read *header_len* bytes from *file_obj* and return them unpacked with *unpack_str*.

    For example ">4I" yields (magic, count, rows, cols) for an image file.
    Raises struct.error if fewer bytes than the format needs were read.
    """
    return struct.unpack(unpack_str, file_obj.read(header_len))
def show_images_file_header(file_name):
    """Print an image file's 16-byte header: (magic, count, rows, cols) as big-endian uint32."""
    show_file_header(file_name, header_len=16, unpack_str=">4I")
def show_labels_file_header(file_name):
    """Print a label file's 8-byte header: (magic, count) as big-endian uint32."""
    show_file_header(file_name, header_len=8, unpack_str=">2I")
def show_file_header(file_name, header_len, unpack_str):
    """Open *file_name*, parse its header with *unpack_str*, and print it.

    The file is closed even if reading or unpacking raises. Previously the
    explicit close() was skipped on error.
    """
    with read_mnist(file_name) as file_obj:
        header_data = get_file_header_data(file_obj, header_len, unpack_str)
    show_file_header_data(file_name, header_data)
def show_mnist_file_header():
    """Print the headers of all four MNIST files: both image files first, then both label files."""
    for image_name in (file_list[0], file_list[2]):  # training, test images
        show_images_file_header(image_name)
    for label_name in (file_list[1], file_list[3]):  # training, test labels
        show_labels_file_header(label_name)
def read_a_image(file_object):
    """Read the next 28x28 image (784 unsigned bytes) from *file_object* and display it.

    Nothing is returned; the image is shown in grayscale with matplotlib.
    """
    pixel_count = 28 * 28
    pixels = struct.unpack(">784B", file_object.read(pixel_count))
    grid = np.asarray(pixels).reshape((28, 28))
    plt.imshow(grid, cmap=plt.cm.gray)
    plt.show()
def read_a_label(file_object):
    """Read the next one-byte label from *file_object* and print it."""
    (label,) = struct.unpack(">B", file_object.read(1))
    print("the label is :%s" % label)
def show_file_header_data(file_name, header_data):
    """Print a parsed header as '<file_name> header data:<header_data>'."""
    message = "%s header data:%s" % (file_name, header_data)
    print(message)
def show_a_image():
    """Print both training-file headers, display the first training image, and print its label.

    Both files are opened via read_mnist and closed on exit, even if an error
    occurs. Previously neither file was ever closed.
    """
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    with read_mnist(images_file_name) as images_file:
        header_data = get_file_header_data(images_file, 16, ">4I")
        show_file_header_data(images_file_name, header_data)
        with read_mnist(labels_file_name) as labels_file:
            header_data = get_file_header_data(labels_file, 8, ">2I")
            show_file_header_data(labels_file_name, header_data)
            # Both streams are now positioned just past their headers, so
            # the next image and the next label belong to the same example.
            read_a_image(images_file)
            read_a_label(labels_file)
def run():
    """Entry point of this listing: show the first training image and its label."""
    show_a_image()


# Executes immediately when this listing is run.
run()
输出
train-images-idx3-ubyte header data:(2051, 60000, 28, 28)
train-labels-idx1-ubyte header data:(2049, 60000)
the label is :5
嗯,图片上的数字和标记一样,都是5。
然后我们把代码修改成能自动生成批数据(batch)的形式:
from __future__ import division
from __future__ import print_function
#gunzip *.gz
#http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# Decompressed MNIST file names, addressed by position elsewhere:
#   [0] training images  [1] training labels  [2] test images  [3] test labels
file_list = """
train-images-idx3-ubyte
train-labels-idx1-ubyte
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte
""".split()
def show_images_file_header(file_name):
    """Print an image file's 16-byte header: (magic, count, rows, cols) as big-endian uint32."""
    show_file_header(file_name, header_len=16, unpack_str=">4I")
def show_labels_file_header(file_name):
    """Print a label file's 8-byte header: (magic, count) as big-endian uint32."""
    show_file_header(file_name, header_len=8, unpack_str=">2I")
def show_file_header(file_name, header_len, unpack_str):
    """Open *file_name*, parse its header with *unpack_str*, and print it.

    The file is closed even if reading or unpacking raises. Previously the
    explicit close() was skipped on error.
    """
    with read_mnist(file_name) as file_obj:
        header_data = get_file_header_data(file_obj, header_len, unpack_str)
    show_file_header_data(file_name, header_data)
def show_mnist_file_header():
    """Print the headers of all four MNIST files: both image files first, then both label files."""
    for image_name in (file_list[0], file_list[2]):  # training, test images
        show_images_file_header(image_name)
    for label_name in (file_list[1], file_list[3]):  # training, test labels
        show_labels_file_header(label_name)
def show_a_image(file_object):
    """Read the next 28x28 image from *file_object* and display it in grayscale.

    Previously this referenced the undefined names ``images_file`` and ``tp``.
    NOTE: a later ``def show_a_image`` in this listing rebinds the name, so
    calls made after the module is loaded reach that later definition.
    """
    pixels = read_a_image(file_object)  # flat tuple of 784 values
    image = np.asarray(pixels).reshape((28, 28))
    plt.imshow(image, cmap=plt.cm.gray)
    plt.show()
def show_a_lebel(file_object):
    """Read the next label from *file_object* and print 'the label is :<digit>'.

    read_a_label returns a 1-tuple, which the % operator consumes as its
    argument tuple, so the bare digit is printed. The misspelled function
    name is kept for existing callers.
    """
    message = "the label is :%s" % read_a_label(file_object)
    print(message)
def show_file_header_data(file_name, header_data):
    """Print a parsed header as '<file_name> header data:<header_data>'."""
    message = "%s header data:%s" % (file_name, header_data)
    print(message)
def show_a_image(file_object=None):
    """Print both training-file headers, display the first training image, and print its label.

    This ``def`` comes after the one-argument ``show_a_image(file_object)`` in
    this listing and therefore replaces it. To keep that form working, it
    accepts an optional *file_object*: when one is given, only the next 28x28
    image is read from it and displayed. Previously the internal
    ``show_a_image(images_file)`` call raised TypeError.

    In the no-argument form both files are closed on exit, even on error.
    """
    if file_object is not None:
        image = np.asarray(read_a_image(file_object)).reshape((28, 28))
        plt.imshow(image, cmap=plt.cm.gray)
        plt.show()
        return
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    with read_mnist(images_file_name) as images_file:
        header_data = get_file_header_data(images_file, 16, ">4I")
        show_file_header_data(images_file_name, header_data)
        with read_mnist(labels_file_name) as labels_file:
            header_data = get_file_header_data(labels_file, 8, ">2I")
            show_file_header_data(labels_file_name, header_data)
            show_a_image(images_file)  # the one-argument form above
            # read_a_label returns a 1-tuple, consumed by % as its arguments.
            print("the label is :%s" % read_a_label(labels_file))
def create_path(path):
    """Create directory *path*, including missing parents, if it does not exist.

    Safe to call repeatedly. Raises OSError if *path* exists but is not a
    directory, or if it cannot be created.
    """
    try:
        os.makedirs(path)
    except OSError:
        # EAFP instead of isdir() followed by makedirs(): the directory may
        # appear between the check and the call. exist_ok= is Python 3 only,
        # and this file targets Python 2 as well.
        if not os.path.isdir(path):
            raise
def get_file_full_name(path, name):
    """Ensure directory *path* exists, then return the path of *name* inside it.

    A "/" separator is inserted only when *path* does not already end with one.
    """
    create_path(path)
    separator = "" if path.endswith("/") else "/"
    return path + separator + name
def read_mnist(file_name, file_path="/home/your/data/path"):
    """Open MNIST file *file_name* located in directory *file_path* for binary reading.

    *file_path* defaults to this tutorial's placeholder directory; pass the
    real data directory instead of editing the function. The caller owns the
    returned file object and must close it.
    """
    full_path = get_file_full_name(file_path, file_name)
    # 'rb' is required on Python 3 (to get bytes). It is also the correct mode
    # on Python 2, where plain 'r' would translate line endings on Windows.
    return open(full_path, 'rb')
def get_file_header_data(file_obj, header_len, unpack_str):
    """Read *header_len* bytes from *file_obj* and return them unpacked with *unpack_str*.

    For example ">4I" yields (magic, count, rows, cols) for an image file.
    Raises struct.error if fewer bytes than the format needs were read.
    """
    return struct.unpack(unpack_str, file_obj.read(header_len))
def read_a_image(file_object, rows=28, cols=28):
    """Read the next ``rows`` x ``cols`` image from *file_object*.

    Returns a flat tuple of ``rows * cols`` pixel values (0-255), row-major.
    The defaults match MNIST's 28x28; other sizes can be passed for
    IDX files with different dimensions (see the rows/cols header fields).
    Raises struct.error on a short read, which callers treat as end of data.
    """
    pixel_count = rows * cols
    raw_img = file_object.read(pixel_count)
    return struct.unpack(">%dB" % pixel_count, raw_img)
def read_a_label(file_object):
    """Read the next label byte from *file_object* and return it as a 1-tuple ``(digit,)``.

    Raises struct.error at end of file (short read).
    """
    return struct.unpack(">B", file_object.read(1))
def generate_a_batch(images_file_name, labels_file_name, batch_size=8):
    """Yield (images, labels) batches read sequentially from an MNIST image/label file pair.

    Each image is the tuple returned by read_a_image and each label the
    1-tuple returned by read_a_label. Batches hold at most *batch_size*
    examples; previously batch_size was ignored and 100 was used. The last
    batch may be shorter.

    End of data is signalled the way the callers in this listing expect:
    once the files are exhausted they are closed and the generator yields
    ``([], [])`` from then on (it never raises StopIteration).

    Raises ValueError if batch_size < 1, or if the files do not start with
    the IDX magic numbers 2051 (images) / 2049 (labels).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))
    with read_mnist(images_file_name) as images_file:
        with read_mnist(labels_file_name) as labels_file:
            images_header = get_file_header_data(images_file, 16, ">4I")
            labels_header = get_file_header_data(labels_file, 8, ">2I")
            if images_header[0] != 2051 or labels_header[0] != 2049:
                raise ValueError(
                    "not an MNIST image/label pair: magic numbers %d and %d"
                    % (images_header[0], labels_header[0]))
            exhausted = False
            while not exhausted:
                images = []
                labels = []
                for _ in range(batch_size):
                    try:
                        image = read_a_image(images_file)
                        label = read_a_label(labels_file)
                    except struct.error:
                        # A short read (fewer bytes than the format needs)
                        # means the end of a file was reached.
                        exhausted = True
                        break
                    images.append(image)
                    labels.append(label)
                if images:
                    yield images, labels
    # Files are closed at this point; keep yielding the empty-batch sentinel.
    while True:
        yield [], []
def get_train_data_generator():
    """Return a batch generator over the training set (file_list[0] images, file_list[1] labels).

    Previously ``return gennerator-`` was a SyntaxError that prevented this
    listing from running.
    """
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    return generate_a_batch(images_file_name, labels_file_name)
def get_test_data_generator():
    """Return a batch generator over the test set (file_list[2] images, file_list[3] labels)."""
    return generate_a_batch(file_list[2], file_list[3])
def get_test_data_generator():
    """Return a batch generator over the test set (file_list[2] images, file_list[3] labels).

    This duplicates the previous definition of the same name and, being
    later, is the one in effect. Previously ``return gennerator-`` was a
    SyntaxError that prevented this listing from running.
    """
    images_file_name = file_list[2]
    labels_file_name = file_list[3]
    return generate_a_batch(images_file_name, labels_file_name)
def get_a_batch(data_generator):
    """Return the next (images, labels) pair produced by *data_generator*.

    Uses the next() builtin, which works on Python 2.6+ and Python 3. This
    replaces the string comparison on sys.version that chose between
    ``.next()`` and ``.__next__()``.
    """
    batch_img, batch_labels = next(data_generator)
    return batch_img, batch_labels
def generate_test_batch():
    """Walk the test set batch by batch, printing each batch's array shapes and index.

    Stops at the first batch whose image and label lists are both empty,
    which is what the batch generator produces once the data is used up.
    """
    batches = get_test_data_generator()
    batch_no = 1
    while True:
        images, labels = get_a_batch(batches)
        if not (images or labels):
            break
        print("img shape:%s label shape:%s count:%s"
              % (np.array(images).shape, np.array(labels).shape, batch_no))
        batch_no += 1
def generate_train_batch():
    """Make 10 passes (epochs) over the training set, printing each batch's shapes.

    A fresh generator is created for every epoch, so each pass starts at the
    beginning of the files. A pass ends at the first batch whose image and
    label lists are both empty.
    """
    for epoch in range(1, 11):
        batches = get_train_data_generator()
        batch_no = 1
        while True:
            images, labels = get_a_batch(batches)
            if not (images or labels):
                break
            print("epoch:%s img shape:%s label shape:%s count:%s"
                  % (epoch, np.array(images).shape, np.array(labels).shape, batch_no))
            batch_no += 1
def run():
    """Entry point of this listing: iterate the training batches, then the test batches."""
    generate_train_batch()
    generate_test_batch()


# Executes immediately when this listing is run.
run()
上面的代码里有很多没有用的代码(还有重复的函数定义,以及 `return gennerator-` 这样的语法错误),把它们删掉并修正后,我们得到:
from __future__ import division
from __future__ import print_function
#gunzip *.gz
#http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# Decompressed MNIST file names, addressed by position elsewhere:
#   [0] training images  [1] training labels  [2] test images  [3] test labels
file_list = """
train-images-idx3-ubyte
train-labels-idx1-ubyte
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte
""".split()
def create_path(path):
    """Create directory *path*, including missing parents, if it does not exist.

    Safe to call repeatedly. Raises OSError if *path* exists but is not a
    directory, or if it cannot be created.
    """
    try:
        os.makedirs(path)
    except OSError:
        # EAFP instead of isdir() followed by makedirs(): the directory may
        # appear between the check and the call. exist_ok= is Python 3 only,
        # and this file targets Python 2 as well.
        if not os.path.isdir(path):
            raise
def get_file_full_name(path, name):
    """Ensure directory *path* exists, then return the path of *name* inside it.

    A "/" separator is inserted only when *path* does not already end with one.
    """
    create_path(path)
    separator = "" if path.endswith("/") else "/"
    return path + separator + name
def read_mnist(file_name, file_path="/home/your/data/path"):
    """Open MNIST file *file_name* located in directory *file_path* for binary reading.

    *file_path* defaults to this tutorial's placeholder directory; pass the
    real data directory instead of editing the function. The caller owns the
    returned file object and must close it.
    """
    full_path = get_file_full_name(file_path, file_name)
    # 'rb' is required on Python 3 (to get bytes). It is also the correct mode
    # on Python 2, where plain 'r' would translate line endings on Windows.
    return open(full_path, 'rb')
def get_file_header_data(file_obj, header_len, unpack_str):
    """Read *header_len* bytes from *file_obj* and return them unpacked with *unpack_str*.

    For example ">4I" yields (magic, count, rows, cols) for an image file.
    Raises struct.error if fewer bytes than the format needs were read.
    """
    return struct.unpack(unpack_str, file_obj.read(header_len))
def read_a_image(file_object, rows=28, cols=28):
    """Read the next ``rows`` x ``cols`` image from *file_object*.

    Returns a flat tuple of ``rows * cols`` pixel values (0-255), row-major.
    The defaults match MNIST's 28x28; other sizes can be passed for
    IDX files with different dimensions (see the rows/cols header fields).
    Raises struct.error on a short read, which callers treat as end of data.
    """
    pixel_count = rows * cols
    raw_img = file_object.read(pixel_count)
    return struct.unpack(">%dB" % pixel_count, raw_img)
def read_a_label(file_object):
    """Read the next label byte from *file_object* and return it as a 1-tuple ``(digit,)``.

    Raises struct.error at end of file (short read).
    """
    return struct.unpack(">B", file_object.read(1))
def generate_a_batch(images_file_name, labels_file_name, batch_size=8):
    """Yield (images, labels) batches read sequentially from an MNIST image/label file pair.

    Each image is the tuple returned by read_a_image and each label the
    1-tuple returned by read_a_label. Batches hold at most *batch_size*
    examples; previously batch_size was ignored and 100 was used. The last
    batch may be shorter.

    End of data is signalled the way the callers in this listing expect:
    once the files are exhausted they are closed and the generator yields
    ``([], [])`` from then on (it never raises StopIteration).

    Raises ValueError if batch_size < 1, or if the files do not start with
    the IDX magic numbers 2051 (images) / 2049 (labels).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))
    with read_mnist(images_file_name) as images_file:
        with read_mnist(labels_file_name) as labels_file:
            images_header = get_file_header_data(images_file, 16, ">4I")
            labels_header = get_file_header_data(labels_file, 8, ">2I")
            if images_header[0] != 2051 or labels_header[0] != 2049:
                raise ValueError(
                    "not an MNIST image/label pair: magic numbers %d and %d"
                    % (images_header[0], labels_header[0]))
            exhausted = False
            while not exhausted:
                images = []
                labels = []
                for _ in range(batch_size):
                    try:
                        image = read_a_image(images_file)
                        label = read_a_label(labels_file)
                    except struct.error:
                        # A short read (fewer bytes than the format needs)
                        # means the end of a file was reached.
                        exhausted = True
                        break
                    images.append(image)
                    labels.append(label)
                if images:
                    yield images, labels
    # Files are closed at this point; keep yielding the empty-batch sentinel.
    while True:
        yield [], []
def get_train_data_generator():
    """Return a batch generator over the training set (file_list[0] images, file_list[1] labels)."""
    return generate_a_batch(file_list[0], file_list[1])
def get_test_data_generator():
    """Return a batch generator over the test set (file_list[2] images, file_list[3] labels)."""
    return generate_a_batch(file_list[2], file_list[3])
def get_a_batch(data_generator):
    """Return the next (images, labels) pair produced by *data_generator*.

    Uses the next() builtin, which works on Python 2.6+ and Python 3. This
    replaces the string comparison on sys.version that chose between
    ``.next()`` and ``.__next__()``.
    """
    batch_img, batch_labels = next(data_generator)
    return batch_img, batch_labels
def generate_test_batch():
    """Walk the test set batch by batch, printing each batch's array shapes and index.

    Stops at the first batch whose image and label lists are both empty,
    which is what the batch generator produces once the data is used up.
    """
    batches = get_test_data_generator()
    batch_no = 1
    while True:
        images, labels = get_a_batch(batches)
        if not (images or labels):
            break
        print("img shape:%s label shape:%s count:%s"
              % (np.array(images).shape, np.array(labels).shape, batch_no))
        batch_no += 1
def generate_train_batch():
    """Make 10 passes (epochs) over the training set, printing each batch's shapes.

    A fresh generator is created for every epoch, so each pass starts at the
    beginning of the files. A pass ends at the first batch whose image and
    label lists are both empty.
    """
    for epoch in range(1, 11):
        batches = get_train_data_generator()
        batch_no = 1
        while True:
            images, labels = get_a_batch(batches)
            if not (images or labels):
                break
            print("epoch:%s img shape:%s label shape:%s count:%s"
                  % (epoch, np.array(images).shape, np.array(labels).shape, batch_no))
            batch_no += 1
def run():
    """Entry point of this listing: iterate the training batches, then the test batches."""
    generate_train_batch()
    generate_test_batch()


# Executes immediately when this listing is run.
run()
输出太长,这里就不贴了。以上代码同时兼容 Python 2 和 Python 3。