1 目标
通过KNN算法对手写数字图像进行识别。
2 思路
(1) 删除train文件夹中标签与图像明显错误的样本;
(2) 样本格式多种多样,将样本统一处理成 28*28 位深度为8;
(3) 将图像样本转换为txt格式;
(4) 将所用样本转换为数组矩阵,形成训练样本;
(5) 提取文件名的首字符,形成label;
(6) 基于KNN算法,在训练集上训练,并保存训练好的模型“knn.pkl”;
(7) 基于训练好的模型,在测试集上测试
3 理论基础
3.1 KNN 算法
KNN作为一种有监督分类算法,是最简单的机器学习算法之一,顾名思义,其算法主体思想就是根据距离相近的邻居类别,来判定自己的所属类别。算法的前提是需要有一个已被标记类别的训练数据集,具体的计算步骤分为一下三步:
A、计算测试对象与训练集中所有对象的距离,可以是欧式距离、余弦距离等,比较常用的是较为简单的欧式距离;
B、找出上步计算的距离中最近的K个对象,作为测试对象的邻居;
C、找出K个对象中出现频率最高的对象,其所属的类别就是该测试对象所属的类别。
4 实践过程
4.1 预处理
4.1.1 去除文件名与图片明显不对应的图片
如果在训练样本中有明显标记错误的样本,会对结果产生比较大的干扰,所以要先把异样样本剔除,经剔除后,train文件中剩余7643个样本。部分异常样本如下:
图1 部分异常样本
4.1.2 统一图像格式
图像格式多种多样,进行预处理,统一处理为28*28像素,单通道8位灰度(0-255)的图像格式,编写代码将原始数据转为可训练数据。
代码在 ResizeImg.py中:
# This Python file uses the following encoding: utf-8
from PIL import Image
import os
def file_name(file_dir):
for root, dirs, files in os.walk(file_dir):
count = 1
print(files)
for i in files:
im = Image.open('train/'+i)
#filename = os.listdir(file_dir)
#print(filename)
out = im.resize((28,28))#Output Image 28*28pixels
out = out.convert("L") #Convert Image to 8 gray
out.save('data/ResizedImg/'+ i,'PNG')
count += 1
print(i)
if __name__ == "__main__":
file_name('train/')#Input Image
完成后,将预处理完成的图片存储到data/ResizedImg文件夹下
4.2 将图片转换为数组矩阵
训练数据集与测试数据集都是标准化后的数组矩阵,而我们的试验对象是手写体数字图片,首先需要将图片进行一下数据化处理。
处理思路为:
通过提取像素值,并将黑色的像素用“1”表示,将白色的像素用“0”表示,再保存为txt格式的文档即可,保存在 data/txtdata 文件夹下。
图2 转换后的txt文件夹
处理后,txt文件结果如下:
图3 处理后的数字 “2”和“9”样本
然后,将所有的txt文件拼接成一个数组,保存在trainarr.txt中,作为最终的训练样本输入。
def data2array(fname):
arr = []
fh = open(fname)
for i in range(0,28):
thisline = fh.readline()
for j in range(0,28):
arr.append(int(thisline[j]))
return arr
4.3 制作label
通过提起文件名中的首字符,制作label,形成labels.txt
label = “图片路径”.split(“.”)[0].split(“-”)[0]
labels.txt中结果如下:
4.4 基于KNN算法进行训练
基于sklearn中的KNN 算法进行手写数字识别,并基于交叉检验进行超参数整定,选择最优的邻居数k,并将训练好的模型保存为knn.pkl
经训练,k=3时,模型效果最优,但是模型正确率仅为 31.2%。
from maketraindata import traindata
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import pandas as pd
from sklearn.model_selection import GridSearchCV
import joblib
trainarr,labels = traindata()
# 特征工程
# 建立模型
knn = KNeighborsClassifier() # 在这边不设置超参数,将超参数在网格搜索和交叉验证中设置
# 设置网格搜索和交叉验证, cv(cross validation 交叉验证)
gc= GridSearchCV(estimator = knn, param_grid={'n_neighbors':[2,3,5],'weights':['uniform','distance']},cv=5)
# 模型训练
gc.fit(trainarr,labels)
print("交叉验证与网格搜索最好的结果",gc.best_estimator_)
print('交叉验证最好的训练分数',gc.best_score_)
# 保存模型
joblib.dump(gc,"data/knn.pkl")
4.5 在测试集上进行测试
在测试集上,调用训练好的模型进行预测,并将测试结果保存在test.txt文件中,程序test.py如下:
import joblib
from os import listdir
from PIL import Image
import numpy as np
import pandas as pd
import time
def test():
# 1 加载模型
KNN = joblib.load('data/knn.pkl')
pngpath = "test/"
pnglist = listdir(pngpath)
txtpath = "data/testtxt/"
labels = []
results = []
for i in pnglist:
#加载图像
im = Image.open(pngpath+i)
#图像预处理
out = im.resize((28, 28)) # Output Image 20*20pixels
out = out.convert("L")
#图像转换为txt
fh = open(txtpath + i + '.txt', "a") # 打开待保存的文档
trainpath = txtpath + i + '.txt'
for m in range(0, 28): # 此32为图片的像素高度,也可以用width = im.size()[1]提取
for n in range(0, 28): # 此32为图片的像素宽度,也可以用width = im.size()[0]提取
pix = im.getpixel((n, m)) # 提取像素
#print(pix)
# pixs = pix[0]+pix[1]+pix[2]
# pixs = pix[0] + pix[1]
if pix ==0: # RGB数值相加为0,表示黑色(0,0,0)
fh.write("1") # 黑色用1表示
else:
fh.write("0") # 白色用1表示
fh.write("\n") # 每一行的末尾输入换行符
fh.close()
# 图像转换为ndarry
#trainfile = open(txtpath + i + '.txt')
trainarr = np.zeros((1, 784))
label = i.split(".")[0].split("-")[0]
trainarr= data2array(trainpath) # 将训练数据写入0矩阵
trainarr = np.array(trainarr).reshape(1,-1) # 在最新版本的sklearn中,所有的数据都应该是二维矩阵 ,list不能使用reshape,需要将其转化为array,然后就可以使用reshape
#print(trainarr)
# 测试
result=KNN.predict(trainarr)
# 将result 写入本地文件
fh1 = open('data/test.txt', "a") # 打开待保存的文档
fh1.write(str(result)+'\n')
fh1.close()
print(label,result)
labels.append(int(label))
results.append(int(result))
return labels, results
def data2array(fname):
arr = []
fh = open(fname)
for i in range(0,28):
thisline = fh.readline()
for j in range(0,28):
arr.append(int(thisline[j]))
return arr
def text_save(filename, data):#filename为写入CSV文件的路径,data为要写入数据列表.
file = open(filename,'a')
for i in range(len(data)):
s = str(data[i]).replace('[','').replace(']','')#去除[],这两行按数据不同,可以选择
s = s.replace("'",'').replace(',','') +'\n' #去除单引号,逗号,每行末尾追加换行符
file.write(s)
file.close()
print("保存文件成功")
参考:[添加链接描述](https://zhuanlan.zhihu.com/p/51326751)