基于KNN的手写字符识别

整理了一下自己之前做过的手写字符识别的资料,分享出来供大家学习交流,后续可能还会分享一些其他方法进行手写字符识别的资料,敬请期待~

  • 一、任务和设计思路
  • 二、KNN算法实现
    • 1、KNN算法简介
    • 2、简单的KNN代码
    • 3、使用sklearn的KNN分类器
    • 4、Kd_tree介绍
      • (1)Kd_tree的构造
      • (2)Kd_tree的查询

一、任务和设计思路

目的是要对手写字符的图片进行识别, 使用的是Chars74K 数据集,其中有0-9和a-z共36类的图片。设计思路如下:(1)先将彩色图片转换为二值图,再将其按照原来的像素形式(如32*32)存放在txt文件中。(2)将转换后的文件划分为训练组和测试组。(3)使用KNN算法对数据进行分类,通过实验结果选取合适的K值使识别率最高。

二、KNN算法实现

1、KNN算法简介

通过先前输入的训练数据确定了基本的类别,当新的数据输入时,通过计算与训练数据的广义上的距离,并通过设置临近的k个样本点与输入数据距离的权重,综合判断出输入数据的类别。所以关键在于距离的定义,k值得选择以及权重的设定。对于距离来说,我们经常使用的是欧式距离,k的选择则需要根据实验的结果进行选取,k太小容易造成过拟合,太大则会使分类结果过于模糊,效果不佳。权重的选择则要根据实际问题进行选择。

2、简单的KNN代码

首先分享一个很简单的KNN代码,直接上代码!

import csv
import random
# 读取
with open('Prostate_Cancer.csv', 'r') as file:
    reader = csv.DictReader(file)
    datas = [row for row in reader]  # csv读出的都是字符串
# 分组
n = len(datas) // 3
random.shuffle(datas)  # 貌似木有用
test_set = datas[0:n]
train_set = datas[n:]

# KNN 距离
def distance(d1, d2):

    res = 0
    for key in ("radius","texture","perimeter","area","smoothness","compactness","symmetry","fractal_dimension"):

        res += (float(d1[key])-float(d2[key])) ** 2

        return res ** 0.5

def knn(data):
    # 求距离
    res = [
        {"result": train["diagnosis_result"], "distance": distance(data, train)}
        for train in train_set
    ]
    # 和每一个训练集中的数据进行比较
    # 升序排序
    res = sorted(res, key=lambda item: item["distance"])

    # 取前K个
    res2 = res[0:K]

    # 加权平均
    result = {'B': 0, 'M': 0}
    # 总的距离
    sum = 0
    for r in res2:
        sum += r['distance'] + 0.000001
    for r in res2:
        result[r["result"]] += 1-r['distance']/sum
    # print(result)
    # print(data['diagnosis_result'])

    if result['B'] > result['M']:
        return 'B'
    else:
        return 'M'
# 测试阶段

for K in range(1, 20):
    correct = 0
    for test in test_set:
        result = test['diagnosis_result']
        result2 = knn(test)
        if result == result2:
            correct += 1
    print("K为", K, "时")
    print("准确率:{:.2f}%".format(100 * correct / len(test_set)))

这个代码是简单的两个类别的分类问题,只有100条数据,入手比较容易。
这个代码的步骤分为4步,使用的是欧氏距离,k的值可以自己设定,我尝试了k=1-20所对应的结果,方便对照选择最优的k值。
1、 计算距离(欧式距离或者其他距离)
2、 升序排列
3、 取前k个
4、 加权平均
这是我在B站上看到的一个视频中的代码,自己敲了一下,加了一个k的循环,方便直观了解正确率和k的关系,并且对KNN有了进一步的理解,大家如果感兴趣,可以去看看,讲的很好。链接: link.

3、使用sklearn的KNN分类器

那么有了对KNN的一些理解后,回头来看看怎么用分类器来实现多种类的分类,其实sklearn就是对2节的代码做了封装和拓展,基本思路不变(源代码都是出自大佬之手,难搞啊,我就是那个意思,你们懂的)。
首先,在使用KNN进行分类的时候,先将手写字符的图片按照像素划分为32 * 32的矩阵(这个在将图像二值化和转换为txt时可以自己决定矩阵的大小,如也可以变为 64 * 32),通过向量转换img2vector 变为一个一维数组,即将原图的矩阵按照行压缩成一行,将所有数据这样处理后,构成一个[m,n]的二维数组,记作向量图。m即为图片的数目,n=32*32。在压缩的过程中,将文件名中的类别按照顺序全部存放在一个一维数组中,记作标签。它与向量化后的二维数组的行一一对应,即标签中的第一个元素对应向量图中第一行元素。采用KNN的分类器进行训练,
其次,采用同样方法将测试集中的数据进行拆分,并使用训练的分类器对其类别进行预测,将预测结果与预测组标签中的类别进行验证,记录正确和错误,得到分类的准确度。
再次,我们明确了思路,再来看看分类器的使用方法
这个分类器的使用如下:

sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, weights=’uniform’,algorithm=’auto’, leaf_size=30, p=2, metric=’minkowski’, metric_params=None,n_jobs=None, **kwargs)

具体的参数可以去看这篇博客链接: link.,写的很全。
我们主要关注这个

algorithm=’auto’/‘brute’/‘kd_tree’/‘ball_tree’

‘brute’对应第一种线性扫描;即将想要预测的数据与训练数据中的所有数据求距离,在从中选取前k个进行加权计算,得到类别。若总共有N个结点,计算复杂度为O(N * k),k为选用的参考点的数目。
‘kd_tree’对应第二种kd树实现;这个算法可以大大减少计算的复杂度,若数据是随机的,N个节点的计算复杂度仅为O(logN)。
‘ball_tree’对应第三种的球树实现;这个未做了解,后续学到了会补上。
‘auto’则会在上面三种算法中做权衡,选择一个拟合最好的最优算法。
参数理解了,输入输出也明白了,那就上代码!!!

import numpy as np
from os import listdir
from sklearn.neighbors import KNeighborsClassifier as KNN


"""
函数说明:32x32的二进制图像转换为1x1024向量
"""
def img2vector(filename):
    # 创建1x1024零向量
    returnVect = np.zeros((1, 1024))
    # 打开文件
    fr = open(filename)
    # 按行读取
    for i in range(32):
        # 读一行数据
        lineStr = fr.readline()
        # 每一行的前32个元素依次添加到returnVect中
        for j in range(32):
            returnVect[0, 32*i+j] = int(lineStr[j])
    # 返回转换后的1x1024向量
    return returnVect


"""
函数说明:手写数字分类测试
"""
def handwritingClassTest():
    # 训练集的Labels
    hwLabels = []
    # 返回trainingDigits目录下的文件名
    trainingFileList = listdir('trainingDigits')
    # 返回文件夹下文件的个数
    m = len(trainingFileList)
    # 初始化训练的Mat矩阵,训练集
    trainingMat = np.zeros((m, 1024))
    # 从文件名中解析出训练集的类别
    for i in range(m):
        #获得文件的名字
        fileNameStr = trainingFileList[i]
        #获得分类的数字
        classNumber = int(fileNameStr.split('_')[0])
        #将获得的类别添加到hwLabels中
        hwLabels.append(classNumber)
        #将每一个文件的1x1024数据存储到trainingMat矩阵中
        trainingMat[i,:] = img2vector('trainingDigits/%s' % (fileNameStr))
    # 构建kNN分类器
    neigh =KNN(n_neighbors=3, algorithm='auto')
    # 拟合模型, trainingMat为训练矩阵,hwLabels为对应的标签
    neigh.fit(trainingMat, hwLabels)
    # 返回testDigits目录下的文件列表
    testFileList = listdir('testDigits')
    # 错误检测计数
    errorCount = 0.0
    # 测试数据的数量
    mTest = len(testFileList)
    # 从文件中解析出测试集的类别并进行 分类测试
    for i in range(mTest):
        # 获得文件的名字
        fileNameStr = testFileList[i]
        # 获得分类的数字
        classNumber = int(fileNameStr.split('_')[0])
        # 获得测试集的1x1024向量,用于训练
        vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
        # 获得预测结果
        classifierResult = neigh.predict(vectorUnderTest)
        print("分类返回结果为%d\t真实结果为%d" % (classifierResult, classNumber))
        if(classifierResult != classNumber):
            errorCount += 1.0
    print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount/mTest * 100))


"""
函数说明:main函数
"""
if __name__ =='__main__':
    handwritingClassTest()

trainingDigits是训练数据存储的文件夹,trainingDigits是测试数据存储的文件夹,文件存储形式为0_1.txt,0代表类别,1代表0类的第一章图片,以此类推。这个在博客中能找的,时间太久忘了原来的发布源了,找到了会附上源链接的。

4、Kd_tree介绍

(1)Kd_tree的构造

Kd_tree的构造:如数据为n维1
在这里插入图片描述

选取方差最大的属性的维度作为参考属性,按照此属性进行排序,如第一层使用X2作为划分依据,那么就按照X2排序找到中位数,由中位数和属性两个关键词确定的那个数就是根节点,这一层的划分属性记为X2。选定属性值大于中位数的放到右边,小于的放到左边,再对这两组数据进行迭代即可,可根据侧面标注此层的结点进行定位。元素个数为偶数时,中位数选大的。如使用下面一组数据构造Kd-tree。
在这里插入图片描述
基于KNN的手写字符识别_第1张图片
基于KNN的手写字符识别_第2张图片

(2)Kd_tree的查询

构造好之后,平面会被划分成许多块儿,如二维的就是矩形块儿,高维的就是超矩形,先查询目标点所在的超矩形块儿中的元素s0,计算出最短的距离,并以此最短距离做出一个球状空间(维度为3时),当查询完所在矩形块儿中的元素时,便向上寻找最小距离元素的父节点计算距离,之后查询此父节点之下,与s0同为兄弟节点的元素s1,若s1所在的超矩形与上文中所求的球状空间有重叠,则计算其中的节点与目标节点的距离。
基于KNN的手写字符识别_第3张图片
好了,KNN的相关问题介绍到这里了,新手上路难免有些错误,如有错误,还请大家多多包涵哈~


  1. 李航. 统计学习方法[M].北京:清华大学出版社,2012.3: 37-45 ↩︎

你可能感兴趣的:(机器学习)