MNIST 手写数字数据集解析

转载请注明出处:https://blog.csdn.net/EatAppleS/article/details/90172847

数据集的下载地址:http://yann.lecun.com/exdb/mnist/

共有 4 个文件,其中训练集的图片和标签为:

train-images.idx3-ubyte

train-labels.idx1-ubyte

测试集的图片和标签为:

t10k-images.idx3-ubyte

t10k-labels.idx1-ubyte

数据集及解析代码的下载地址如下:https://download.csdn.net/download/eatapples/11175660

 

解析代码如下:

import numpy as np
import struct
import cv2
# Locations of the four MNIST IDX files, relative to the current working
# directory.

# Training split: images (IDX3) and labels (IDX1).
train_images_idx3_ubyte_file = "./train-images.idx3-ubyte"
train_labels_idx1_ubyte_file = "./train-labels.idx1-ubyte"

# Test split ("t10k" files): images (IDX3) and labels (IDX1).
test_images_idx3_ubyte_file = "./t10k-images.idx3-ubyte"
test_labels_idx1_ubyte_file = "./t10k-labels.idx1-ubyte"


def decode_images(imgPath):
    """Decode an MNIST IDX3 image file into a NumPy array.

    The file starts with a big-endian header of four 32-bit ints (magic
    number, image count, rows, cols), followed by one unsigned byte per
    pixel, images stored back to back in row-major order.

    Args:
        imgPath: Path to an ``*-images.idx3-ubyte`` file.

    Returns:
        A float64 array of shape ``(num_images, num_rows, num_cols)`` holding
        the raw pixel values (0-255).

    Raises:
        ValueError: If the magic number is not 2051 (IDX3, unsigned byte),
            a header dimension is negative, or the file holds fewer pixel
            bytes than the header announces.
        struct.error: If the file is shorter than the 16-byte header.
    """
    header_fmt = '>iiii'
    # Read everything up front and close the handle right away.
    with open(imgPath, 'rb') as f:
        bin_data = f.read()

    magic_number, num_images, num_rows, num_cols = struct.unpack_from(header_fmt, bin_data, 0)
    if magic_number != 2051:
        raise ValueError('%s: bad IDX3 magic number %d (expected 2051)' % (imgPath, magic_number))
    if num_images < 0 or num_rows < 0 or num_cols < 0:
        raise ValueError('%s: negative dimension in header' % imgPath)
    print('img num %d img rows %d img cols %d' % (num_images, num_rows, num_cols))

    offset = struct.calcsize(header_fmt)
    pixel_count = num_images * num_rows * num_cols
    if len(bin_data) - offset < pixel_count:
        raise ValueError('%s: truncated file, expected %d pixel bytes, found %d'
                         % (imgPath, pixel_count, len(bin_data) - offset))

    # One vectorised read replaces the per-image struct.unpack loop.
    pixels = np.frombuffer(bin_data, dtype=np.uint8, count=pixel_count, offset=offset)
    # Keep float64 so callers see the same dtype as before; astype also
    # copies, giving a writable array (frombuffer over bytes is read-only).
    return pixels.reshape((num_images, num_rows, num_cols)).astype(np.float64)

def decode_labels(labelPath):
    """Decode an MNIST IDX1 label file into a NumPy array.

    The file starts with a big-endian header of two 32-bit ints (magic
    number, label count), followed by one unsigned byte per label.

    Args:
        labelPath: Path to an ``*-labels.idx1-ubyte`` file.

    Returns:
        A 1-D float64 array of length ``num_labels`` holding the label bytes.

    Raises:
        ValueError: If the magic number is not 2049 (IDX1, unsigned byte),
            the count is negative, or the file holds fewer label bytes than
            the header announces.
        struct.error: If the file is shorter than the 8-byte header.
    """
    header_fmt = '>ii'
    # Read everything up front and close the handle right away.
    with open(labelPath, 'rb') as f:
        bin_data = f.read()

    magic_number, num_labels = struct.unpack_from(header_fmt, bin_data, 0)
    if magic_number != 2049:
        raise ValueError('%s: bad IDX1 magic number %d (expected 2049)' % (labelPath, magic_number))
    if num_labels < 0:
        raise ValueError('%s: negative label count in header' % labelPath)

    offset = struct.calcsize(header_fmt)
    if len(bin_data) - offset < num_labels:
        raise ValueError('%s: truncated file, expected %d label bytes, found %d'
                         % (labelPath, num_labels, len(bin_data) - offset))

    # One vectorised read replaces the per-byte struct.unpack loop.
    labels = np.frombuffer(bin_data, dtype=np.uint8, count=num_labels, offset=offset)
    # float64 matches the dtype the original np.empty-based version returned.
    return labels.astype(np.float64)


def _export_split(images, labels, out_dir, list_path):
    """Write every image as ``<out_dir><i>.jpg`` plus an index of ``path label`` lines.

    Args:
        images: Array of images indexed along axis 0 with pixel values in
            0-255, e.g. the output of ``decode_images``.
        labels: Sequence of labels, one per image, e.g. the output of
            ``decode_labels``.
        out_dir: Directory prefix ending in a separator (e.g. './trainData/');
            it is created if it does not exist.
        list_path: Text file that receives one "<image path> <label>" line
            per image.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
        OSError: If OpenCV reports that an image could not be written.
    """
    import os

    if len(images) != len(labels):
        raise ValueError('got %d images but %d labels' % (len(images), len(labels)))
    # cv2.imwrite returns False (instead of raising) when the folder is
    # missing, so make sure it exists before writing.
    os.makedirs(out_dir, exist_ok=True)
    with open(list_path, 'w') as f:
        for i, (image, label) in enumerate(zip(images, labels)):
            # No trailing space in the file name: the path written to disk
            # must be exactly the path recorded in the index file.
            img_path = '%s%d.jpg' % (out_dir, i)
            # Pixels are whole numbers 0-255; convert explicitly to the
            # 8-bit depth JPEG expects.
            if not cv2.imwrite(img_path, image.astype(np.uint8)):
                raise OSError('failed to write %s' % img_path)
            f.write(img_path + ' ' + str(int(label)) + '\n')


if __name__ == '__main__':
    train_images = decode_images(train_images_idx3_ubyte_file)
    train_labels = decode_labels(train_labels_idx1_ubyte_file)
    _export_split(train_images, train_labels, './trainData/', './train.txt')

    test_images = decode_images(test_images_idx3_ubyte_file)
    test_labels = decode_labels(test_labels_idx1_ubyte_file)
    _export_split(test_images, test_labels, './testData/', './test.txt')

    print('ok')

 

你可能感兴趣的:(深度学习)