(1) The MNIST Dataset: Introduction and Loading

The MNIST dataset, available at http://yann.lecun.com/exdb/mnist/, is a collection of handwritten digits compiled by Yann LeCun. It is split into the following four files:

Dataset               File name                    Contents
Training set images   train-images-idx3-ubyte.gz   pixel values of 60,000 samples
Training set labels   train-labels-idx1-ubyte.gz   60,000 labels
Test set images       t10k-images-idx3-ubyte.gz    pixel values of 10,000 samples
Test set labels       t10k-labels-idx1-ubyte.gz    10,000 labels
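All four files use the IDX format described on the same MNIST page. Each file starts with a 4-byte magic number: the first two bytes are zero, the third is a data-type code (0x08 = unsigned byte, which is what MNIST uses), and the fourth is the number of dimensions. One big-endian 32-bit size per dimension follows, and then the raw data. The sketch below parses such a header. The helper name parse_idx_header is hypothetical, and the header bytes are packed by hand rather than read from a real file.

```python
import struct

# IDX data-type codes listed on the MNIST page; MNIST itself only uses 0x08.
IDX_TYPES = {0x08: "unsigned byte", 0x09: "signed byte", 0x0B: "short",
             0x0C: "int", 0x0D: "float", 0x0E: "double"}

def parse_idx_header(header: bytes):
    """Parse an IDX header: 4-byte magic, then one big-endian uint32 per dimension."""
    zero1, zero2, type_code, ndim = struct.unpack(">BBBB", header[:4])
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX file: the first two magic bytes must be 0")
    dims = struct.unpack(">" + "I" * ndim, header[4:4 + 4 * ndim])
    return IDX_TYPES[type_code], dims

# Header of train-images-idx3-ubyte: magic 2051 (0x00000803), then 60000, 28, 28
example = struct.pack(">IIII", 2051, 60000, 28, 28)
print(parse_idx_header(example))  # ('unsigned byte', (60000, 28, 28))
```

The label files use magic 2049 (0x00000801), i.e. one dimension, which is why read_data unpacks only two integers (">II") from the label header but four (">IIII") from the image header.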

Reading the data

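import gzip
import struct

import numpy as np

def read_data(label_url, image_url):
    # Label file header: magic number 2049, then the number of labels (big-endian uint32)
    with gzip.open(label_url, 'rb') as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        if magic != 2049:
            raise ValueError("unexpected magic number in label file")
        # np.fromstring is deprecated; frombuffer + copy() gives a writable uint8 array
        label = np.frombuffer(flbl.read(), dtype=np.uint8).copy()
    # Image file header: magic number 2051, then image count, rows, cols
    with gzip.open(image_url, 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != 2051:
            raise ValueError("unexpected magic number in image file")
        image = np.frombuffer(fimg.read(), dtype=np.uint8).copy().reshape(num, rows, cols)
    return (label, image)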
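To check read_data without downloading anything, one option is to write a tiny fake IDX pair with the same header layout and read it back. This is a minimal sketch under that assumption: write_fake_idx and the fake-*.gz file names are hypothetical, and the three 2x2 "images" stand in for real 28x28 digits.

```python
import gzip
import os
import struct
import tempfile

import numpy as np

def read_data(label_url, image_url):
    """Read a gzipped IDX label file and a gzipped IDX image file."""
    with gzip.open(label_url, "rb") as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        if magic != 2049:
            raise ValueError("unexpected magic number in label file")
        label = np.frombuffer(flbl.read(), dtype=np.uint8).copy()
    with gzip.open(image_url, "rb") as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != 2051:
            raise ValueError("unexpected magic number in image file")
        image = np.frombuffer(fimg.read(), dtype=np.uint8).copy().reshape(num, rows, cols)
    return label, image

def write_fake_idx(label_path, image_path, labels, images):
    """Hypothetical test helper: write labels (N,) and images (N, rows, cols) as gzipped IDX."""
    with gzip.open(label_path, "wb") as f:
        f.write(struct.pack(">II", 2049, len(labels)))
        f.write(labels.astype(np.uint8).tobytes())
    n, rows, cols = images.shape
    with gzip.open(image_path, "wb") as f:
        f.write(struct.pack(">IIII", 2051, n, rows, cols))
        f.write(images.astype(np.uint8).tobytes())

labels = np.array([7, 0, 9], dtype=np.uint8)
images = np.arange(12, dtype=np.uint8).reshape(3, 2, 2)  # three tiny 2x2 "images"

with tempfile.TemporaryDirectory() as tmp:
    lbl_path = os.path.join(tmp, "fake-labels-idx1-ubyte.gz")
    img_path = os.path.join(tmp, "fake-images-idx3-ubyte.gz")
    write_fake_idx(lbl_path, img_path, labels, images)
    lbl, img = read_data(lbl_path, img_path)

print(lbl, img.shape)  # [7 0 9] (3, 2, 2)
```

Because the pixel bytes are stored row by row with no padding, reshaping the flat buffer to (num, rows, cols) recovers each image exactly.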

Getting the Train and Test sets

The input is the ohe flag (whether to one-hot encode the labels); the output is a tuple of the pixel values and the label values.

import numpy as np
from keras.utils import to_categorical  # keras.utils.np_utils is not available in recent Keras releases

NUM_CLASSES = 10  # MNIST has ten digit classes (0-9)

def get_train(ohe=True):
    (train_lbl, train_img) = read_data('DataSet/Mnist/train-labels-idx1-ubyte.gz','DataSet/Mnist/train-images-idx3-ubyte.gz')
    train_img = train_img.reshape((*train_img.shape, 1))  # add a channel dimension: (N, 28, 28) -> (N, 28, 28, 1)
    train_img = preprocessing_img(train_img)  # normalization; preprocessing_img is not shown in this section
    if ohe:
        # A fixed class count (instead of len(np.unique(...))) keeps train and test encodings the same width
        train_lbl = to_categorical(train_lbl, num_classes=NUM_CLASSES)  # one-hot encode the labels
    return train_img, train_lbl

def get_test(ohe=True):
    (val_lbl, val_img) = read_data('DataSet/Mnist/t10k-labels-idx1-ubyte.gz','DataSet/Mnist/t10k-images-idx3-ubyte.gz')
    val_img = val_img.reshape((*val_img.shape, 1))  # add a channel dimension
    val_img = preprocessing_img(val_img)  # same normalization as the training set
    if ohe:
        val_lbl = to_categorical(val_lbl, num_classes=NUM_CLASSES)  # one-hot encode the labels
    return val_img, val_lbl
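preprocessing_img is called above but not defined in this section, and to_categorical needs Keras. As a hedged, NumPy-only sketch of the same two steps: the preprocessing_img below is a hypothetical stand-in that assumes the normalization is plain scaling by 255 (the author's helper may differ, e.g. mean/std standardization), and one_hot is a hypothetical helper meant to produce the same rows as to_categorical(labels, num_classes=10).

```python
import numpy as np

def preprocessing_img(img):
    """Hypothetical stand-in: scale uint8 pixels from [0, 255] to float32 in [0, 1]."""
    return img.astype(np.float32) / 255.0

def one_hot(labels, num_classes=10):
    """NumPy-only one-hot encoding: row i has a single 1.0 in column labels[i]."""
    return np.eye(num_classes, dtype=np.float32)[np.asarray(labels, dtype=np.int64)]

# Fake batch of two 28x28 uint8 images with labels 3 and 9
imgs = np.zeros((2, 28, 28), dtype=np.uint8)
imgs[1, 0, 0] = 255
lbls = np.array([3, 9], dtype=np.uint8)

x = preprocessing_img(imgs.reshape((*imgs.shape, 1)))  # add channel dim, then normalize
y = one_hot(lbls)

print(x.shape, x.dtype, x.max())  # (2, 28, 28, 1) float32 1.0
print(y.argmax(axis=1))           # [3 9]
```

Indexing an identity matrix with the label array picks out one unit row per label, which is why the class count must be fixed up front: with len(np.unique(...)), a split that happened to lack a digit would get a narrower encoding.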
