手写数字识别python_Python手写数字识别(KNN算法)

前言

手写数字识别作为机器学习中一个比较有趣的内容,结合 K210 芯片强大的 KPU 算力,可以蹦出很多好玩的点子。本次以通俗易懂的方式记录一波玩耍手写数字识别的案例。

目录

  1. 相关理论
  2. 基于 KNN 分类算法的模型搭建(代码展示)
  3. 图片导入与分类预测(代码展示)

1、相关理论

1.1 图像二值化

手写数字识别python_Python手写数字识别(KNN算法)_第1张图片

如上图:是一个红色的数字4,像素尺寸为 32*32,即一共有1024个像素点,每个像素点的色彩值范围是0-255,其中0代表黑色,255代表白色,故本文以128为阈值进行划分。

令单个像素点的色彩值为 x ,则:x<=128 时,取 x = 1,反之 x = 0

经过上述变换,整张图片的色彩值就只有 0,1 两种了,此过程就是图片的二值化过程。

1.2 二值化数据的存储

经过二值化的图片,仍然有1024个像素点,只是其值仅有0,1两种,那么,怎样将图片的信息和模型之间建立联系呢?

下图是一份包含多张手写图像信息的 EXCEL 表格。从第一行来看,其第一列代表图片中的数字,其它列为分别为1024个像素点的值。这样,每一行就包含了一张图片的信息。

手写数字识别python_Python手写数字识别(KNN算法)_第2张图片

所以后面需要做的工作就是训练模型,通过 1024 个像素点的信息推断出该行第一列所对应的信息。

当然啦,上图中的训练数据均有结果对应,用于训练模型和检验模型的有效性;而手写的数字只有1024个像素点信息,所以需要通过训练好的模型进行预测,推断出图片中对应的数字。

2、KNN 分类算法

KNN 分类算法就比较简单了。盯着下图,最中间的这个圆当中有2个三角形和一个方块,它们各属于一个阵营。突然新来了一个五角星,那怎么判断五角星是属于哪个阵营呢?

很明显,找出 K 个距离五角星最近的形状,比如 K=3,那么就选中了刚刚提到的这2个三角形和1个方块。其中,三角形比正方形多,显然三角形更牛逼,所以五角星跟它同阵营。这就是 KNN 分类的原理。

手写数字识别python_Python手写数字识别(KNN算法)_第3张图片

至于距离的计算,高中数学里的欧式距离公式就能搞定:

832c327bcc871a79d7e4f56448d29c82.png

推广到多维变量,则有:

d679a08933f19d41c83e8e11720955e8.png

3、图片导入及分类预测

3.1 导入 EXCEL 中的训练数据

import pandas as pd
df = pd.read_excel('data.xlsx')
x = df.drop(columns = '对应数字')  # 抛弃第一列的因变量数据,保留自变量数据
y = df['对应数字']                 # 第一列的因变量数据

3.2 训练 KNN 模型

from sklearn.model_selection import train_test_split
xtrain,xtest,ytrain,ytest = train_test_split(x,y,test_size = 0.2,random_state = 123)
from sklearn.neighbors import KNeighborsClassifier as KNN
knn = KNN(n_neighbors=5)           # 算法 K 参数设置 [1:8]
model = knn.fit(xtrain,ytrain)
print(model)

3.3 对预测模型进行评价

from sklearn.metrics import accuracy_score
y_pred = knn.predict(xtest)
score2 = knn.score(xtest,ytest)    # 计算 knn 分类预测得分
print('KNN预测模型评分为: ' + str(score2))

3.4 图片数字导入及模型验证

from PIL import Image
img = Image.open('数字6.png')
img = img.resize ((32,32))
img = img.convert('L')  # 转换为灰度图
# 图片二值化处理
import numpy as np
img_new = img.point (lambda x: 0 if x >128 else 1)   # 像素点控制,0为黑,255为白,所以取中间128作为阈值,对像素点二值化处理
arr = np.array(img_new)                              # arr 为 32*32 像素点矩阵
# 打印二值化图形
for i in range(arr.shape[0]):                        # shape[0]代表行数,1代表列数
    print(arr[i])
arr1 = arr.reshape(1,-1)                             # arr 转换为 1 行,(-1,1)为 1 列
result = knn.predict(arr1)
print('识别的数字为:' + str(result))

4、效果展示

4.1 KNN 模型评分情况

从下图可以看出, KNN 模型的预测精度 score 表现还是非常不错的!

16e3e9a9d8221aa1be2289f972d3c6ee.png

4.2 图片二值化处理后的像素数据

不难看出,1代表黑色部分,0代表白色部分;在像素矩阵中可以明显地体现 6 这个数字。

手写数字识别python_Python手写数字识别(KNN算法)_第4张图片

4.3 预测结果

64d887655163cfba344c3e61f79df48b.png

其实,从 0 到 9 均测试了一遍,只有数字 8 比较刁钻,好几次预测成了 9;最后发现是截图的方式有点小问题,纠正以后,所有的数字均能够精准识别。

5、结论

最后,接 4.3 涉及的问题,有一个注意事项:截图的时候尽可能小,让数字充满整张图片,这样才能够有效保证识别的准确度。

你可能感兴趣的:(手写数字识别python)