手写字体识别(1) 准备数据集

目录

  • 数据下载及提取
  • 数据下载
    • ubyte文件及提取
    • csv文件及提取
  • 数据集预处理

github地址:
https://github.com/Huyf9/mnist_pytorch/

数据下载及提取

为了规范代码与数据集,因此我们按如下格式创建项目:

mnist
 |—— dataset
     |__ train
     |__ test

mnist
dataset
train
test

数据下载

ubyte文件及提取

ubyte格式数据下载地址:
http://yann.lecun.com/exdb/mnist/

我们需要下载这四个文件:

train-images-idx3-ubyte.gz #training set images
train-labels-idx1-ubyte.gz #training set labels
t10k-images-idx3-ubyte.gz #test set images
t10k-labels-idx1-ubyte.gz #test set labels

这个文件格式真麻烦!这辈子应该也不会再跟它打交道了,直接上代码:

from PIL import Image
import numpy as np
from tqdm import tqdm

def convert(image_path, label_path, n):

    f_images = open(image_path[0], 'rb')
    f_labels = open(label_path[0], 'rb')
    f_out = open(label_path[1], 'w')  # 标签路径

    f_images.read(16)
    f_labels.read(1)

    images = []
    labels = []
    for i in range(n):
        image = []
        labels.append(ord(f_labels.read(1)))
        for j in range(28*28):
            image.append(ord(f_images.read(1)))
        images.append(image)
    idx = 0
    for image in tqdm(images):
        img = Image.fromarray(np.array(image).reshape((28, 28))).convert('L')
        if idx >= 10000:
            img.save(image_path[1] + '00' + str(idx) + '.png')
        elif idx >= 1000:
            img.save(image_path[1] + '000' + str(idx) + '.png')
        elif idx >= 100:
            img.save(image_path[1] + '0000' + str(idx) + '.png')
        elif idx >= 10:
            img.save(image_path[1] + '00000' + str(idx) + '.png')
        else:
            img.save(image_path[1] + '000000' + str(idx) + '.png')

        idx += 1

    for label in labels:
        f_out.write(str(label) + '\n')

    f_images.close()
    f_labels.close()
    f_out.close()


train_image_path = ['train-images.idx3-ubyte', 'dataset\\train\\']  # 训练集图片读取路径;训练集图片保存路径
train_label_path = ['train-labels.idx1-ubyte', 'dataset\\train.txt']  # 训练集标签读取路径;训练集标签保存路径
test_image_path = ['t10k-images.idx3-ubyte', 'dataset\\test\\']  # 测试集图片读取路径;测试集图片保存路径
test_label_path = ['t10k-labels.idx1-ubyte', 'dataset\\test.txt']  # 测试集标签读取路径;测试集标签保存路径
convert(train_image_path, train_label_path, 60000)
print('Generate the train sets done!')
convert(test_image_path, test_label_path, 10000)
print('Generate the test sets done!')

csv文件及提取

csv格式数据下载地址:
https://pjreddie.com/projects/mnist-in-csv/
将train set与test set下载。csv文件中的图片格式为:

label, pix-11, pix-12, pix-13, …

第一列为图片的标签。其中pix-ij为第i行第j列的像素值。

我们从csv文件中读取label并保存为txt文件, 读取像素值转化为png图片并保存。

代码如下:

import numpy
import numpy as np
from tqdm import tqdm
from PIL import Image

'''
标签直接保存到txt文件中
图片先reshape成28x28大小,再将其名称长度统一为7,方便后续按顺序读取
名称格式为00....+idx.png
idx表示为第idx张图片,前面的0的数量为使其长度为7时所需的数量
'''

def convert(mnist, save_path):
    idx = 0

    f_tr_label = open(r'dataset\\train_label.txt', 'w')  # 保存train标签文件
    f_test_label = open(r'dataset\\test_label.txt', 'w')  # 保存test标签文件
    f_label = [f_tr_label, f_test_label]

    for i in range(len(mnist)):
        mnist_docs = open(mnist[i], 'r').readlines()
        for mnist_doc in tqdm(mnist_docs):
            mnist_doc = mnist_doc.strip().split(',')  # csv文件用逗号分隔数据,因此利用split()来以逗号划分成列表
            f_label[i].write(mnist_doc[0] + '\n')

            '''
            由于读取的pixel是字符型,需要将其转为整型
            利用map()函数来进行实现
            map()的语法规则为
            map(function, iterable, ...) --function为一个函数 --iterable为一个或多个序列
            因此使用map(int list_name)就可以将一个列表中的元素转为整型
            但是其返回的是迭代器,因此还需要用list()将其再转化为列表
            '''
            # print(mnist_doc[1:])
            img = list(map(int, mnist_doc[1:]))
            img = np.array(img).reshape((28, 28))  # mnist图片尺寸为28 x 28的
            img = Image.fromarray(img).convert('L')  # 将numpy矩阵转为image格式并且转化为灰度图,方便保存图片

            if idx >= 10000:
                img.save(save_path[i] + '00' + str(idx) + '.png')
            elif idx >= 1000:
                img.save(save_path[i] + '000' + str(idx) + '.png')
            elif idx >= 100:
                img.save(save_path[i] + '0000' + str(idx) + '.png')
            elif idx >= 10:
                img.save(save_path[i] + '00000' + str(idx) + '.png')
            else:
                img.save(save_path[i] + '000000' + str(idx) + '.png')

            idx += 1

        print('done!')


mnist = ['mnist_train.csv', 'mnist_test.csv']
save_path = [r'dataset\\train\\', r'dataset\\test\\']
convert(mnist, save_path)

无注释纯净版!

数据集预处理

在获取到mnist的图片与标签之后,我们需要对数据集进行一个预处理,来获取数据集的图片路径,方便后续进行数据加载。同时需要将训练集划分为训练部分和验证部分,每一轮训练结束后用验证部分来验证模型性能。

首先在根目录下创建mnist_annotation.py文件,这个文件运行结束后会在dataset目录下生成

train.txt # 保存训练图片路径
test.txt # 保存测试图片路径
val.txt # 保存验证图片路径
train_label.txt # 保存训练标签
test_label.txt # 保存测试标签
val_label.txt # 保存验证标签

代码如下:

import os
from tqdm import tqdm

from typing import List
# 训练集与验证集比例 = 9 : 1
# 由于训练集与测试集已经区分开,所以不需要再划分
trainval_percent = 0.9

def split_sets(train_path, test_path):
    f_train = open(r'dataset\train.txt', 'w')
    f_test = open(r'dataset\test.txt', 'w')
    f_val_label = open(r'dataset\val_label.txt', 'w')
    f_train_label = open(r'dataset\tr_label.txt', 'w')

    # 将 dataset\\test 下的图片用os.listdir()函数获取到文件路径并写入txt文件
    test_pic_paths = os.listdir(test_path)
    for test_pic_path in tqdm(test_pic_paths):
        f_test.write(test_pic_path + '\n')
    print('Generate test.txt done!')

    # 将 dataset\\train 下的图片与标签划分为训练与验证,并依次存入txt文件
    train_pic_path = os.listdir(train_path[0])
     
    num = len(train_pic_path)
    train_num = int(num * trainval_percent)

    for i in tqdm(train_pic_path[0:train_num]):
        f_train.write(i + '\n')
    print('Generate train.txt done!')

    with open(train_path[1], 'r+') as f:
        train_labels = f.readlines()
        for train_label in tqdm(train_labels[0:train_num]):
            f_train_label.write(train_label)
        print('Generate tr_label.txt done!')

        for val_label in train_labels[train_num:]:
            f_val_label.write(val_label)
        print('Generate val_label.txt done!')

    f.close()
    f_train.close()
    f_test.close()
    f_val_label.close()
    f_train_label.close()


train_path = ['dataset\\train', 'dataset\\train_label.txt']
test_path = 'dataset\\test'
split_sets(train_path, test_path)

你可能感兴趣的:(手写字体识别,python,机器学习,人工智能)