python3.5《机器学习实战》学习笔记(三):k近邻算法scikit-learn实战手写体识别

转载请注明作者和出处:http://blog.csdn.net/u013829973
系统版本:window 7 (64bit)
我的GitHub:https://github.com/weepon
python版本:python 3.5
IDE:Spyder (一个比较方便的办法是安装anaconda,那么Spyder和jupyter以及python几个常用的包都有了,甚至可以方便的安装TensorFlow等,安装方法链接)

在前面学习笔记(一)、(二)我们主要介绍了k近邻的基本原理及一步步实现,但是实际上用的时候,不用自己从头编写,我们只要使用scikit-learn中的k近邻函数就可以了,看下图的各种knn算法实现:链接在此

python3.5《机器学习实战》学习笔记(三):k近邻算法scikit-learn实战手写体识别_第1张图片

要想实现上篇文章中手写体数字识别,只需调用KNeighborsClassifier函数,再指定相应参数就可以了
python3.5《机器学习实战》学习笔记(三):k近邻算法scikit-learn实战手写体识别_第2张图片

各种参数:

  • n_neighbors:默认为5,kNN的k的值。

  • weights:默认是uniform。uniform是均等的权重,就说所有的邻近点的权重都是相等的。distance是不均等的权重,距离近的点比距离远的点的影响大。用户自定义的函数,接收距离的数组,返回一组维数相同的权重。

  • algorithm:快速k近邻搜索算法,默认参数为auto。用户可以指定搜索算法ball_tree、kd_tree、brute方法进行搜索,brute是蛮力搜索,也就是线性扫描,当训练集很大时,计算非常耗时。kd_tree,构造kd树存储数据以便对其进行快速检索的树形数据结构,kd树也就是数据结构中的二叉树。以中值切分构造的树,每个结点是一个超矩形,在维数小于20时效率高。ball tree是为了克服kd树高纬失效而发明的,其构造过程是以质心C和半径r分割样本空间,每个节点是一个超球体。

  • leaf_size:默认是30,这个是构造的kd树和ball树的大小。这个值的设置会影响树构建的速度和搜索速度,同样也影响着存储树所需的内存大小。需要根据问题的性质选择最优的大小。

  • metric:用于树的距离度量。默认的度量是闵可夫斯基度量,p = 2等价于标准的欧几里得度量。

  • p:距离度量选择。除了欧式距离,还有曼哈顿距离等这个参数默认为2,也就是默认使用欧式距离公式进行距离度量。也可以设置为1,使用曼哈顿距离公式进行距离度量。

  • metric_params:距离公式的其他关键参数,默认None。

  • n_jobs:并行处理设置。默认为1,临近点搜索并行工作数。如果为-1,那么CPU的所有cores都用于并行工作。

手写题数字识别的完整python代码:

'''
Created on Sep 10, 2017

kNN: k近邻(k Nearest Neighbors)
实战:手写识别系统

author:weepon
'''
import numpy as np
import operator
from os import listdir
from sklearn.neighbors import KNeighborsClassifier

'''
函数功能:将32x32的二进制图像转换为1x1024向量

Input:     filename :文件名
Output:    二进制图像的1x1024向量

'''
def img2vector(filename):
    returnVect = np.zeros((1,1024))            #创建空numpy数组
    fr = open(filename)                         #打开文件
    for i in range(32):
        lineStr = fr.readline()                #读取每一行内容
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])#将每行前32个字符值存储在numpy数组中
    return returnVect

'''
函数功能:手写数字分类测试
'''    
def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')    #加载训练集
    m = len(trainingFileList)                     #计算文件夹下文件的个数,因为每一个文件是一个手写体数字
    trainingMat = np.zeros((m,1024))            #初始化训练向量矩阵
    for i in range(m):
        fileNameStr = trainingFileList[i]        #获取文件名
        fileStr = fileNameStr.split('.')[0]     #从文件名中解析出分类的数字
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    #构建kNN分类器
    neigh = KNeighborsClassifier(n_neighbors = 3, algorithm = 'auto')
    #拟合模型, trainingMat为测试矩阵,hwLabels为对应的标签
    neigh.fit(trainingMat, hwLabels)
    testFileList = listdir('testDigits')        #加载测试集
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]          #从文件名中解析出测试样本的类别
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = neigh.predict(vectorUnderTest) #开始分类
        print ('the classifier came back with: %d, the real answer is: %d' % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0            #计算分错的样本数
    print ('\nthe total number of errors is: %d' % errorCount)
    print ('\nthe total error rate is: %f' % (errorCount/float(mTest)))

'''
主函数
'''    
if __name__ == '__main__':
    handwritingClassTest()

运行结果:

python3.5《机器学习实战》学习笔记(三):k近邻算法scikit-learn实战手写体识别_第3张图片

你可能感兴趣的:(机器学习,机器学习,k近邻算法,sklearn实现,python3-5)