机器学习(二)——knn算法深度解析

一:引言

        上一篇中对knn算法进行了相对的了解,并通过一个十分简单的例子大致了解了knn算法的工作原理和工作流程。本篇将继续对knn算法进行深度的解析,帮助理解knn算法

二:算法解析

        1.先简单回顾一下大致的步骤,knn算法可以大致为三个步骤,第一步就是确定k值,就是找与测试点前k个相似的数据,一般取值为不大于20的整数。第二步就是求测试点与样本点的距离。第三部就是按照样本规则来确定测试点属于哪一类。下面是网上找的一张图片。

        2.k的取值分析

        下面是网上找的一张图片。当k=3时,确定的样本点在图中圆圈的范围内,这时候就自然而然按照统计的规则判断测试点为蓝三角形。

        同样k的取值不同也会让结果产生误差,下图是网上找到一张有代表性的图

        机器学习(二)——knn算法深度解析_第1张图片

        其中当k为3的时候,按照规则测试点应该划分到红三角的范围,但当k=5的时候测试点又划分到蓝方块的范围内。因此k值的取值应该进行思考。当范围内两种或以上类型的数据点数量相同时,就可以按照距离来划分,下面来讲解距离测试的方法。 

        3.距离测试

        对于knn算法中距离的测试一般用下面的测试方法,测试点用(x,y)表示

        

        这个公式就是欧氏距离公式,将测试点与目标点的距离求出,如果样本点较多就多加上几个点的距离, 可以对数据按从小到大的次序排序,然后确定测试点的主要分类。在这里值得一提的是,对于样本点knn算法是不需要进行训练的,只需要对测试点和样本点进行距离求和,所以训练过程为O(1),测试过程对于n组,每组m个为O(mn)

3.实验数据分析

        首先编写一段函数img2vector,将图像转换成向量,创建1*1021的NumPy数组,将32*32的图像字符值存在数组中。

def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

        然后就是识别体的测试代码


def handwritingClassTest():
    hwLabels = []
    trainingFileList = os.listdir("digits/trainingDigits")
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('digits/trainingDigits/%s' %(fileNameStr))
    #读取测试数据
    testFileList = os.listdir('digits/testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    errfile = []
    #循环测试每个测试数据文件
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('digits/testDigits/%s' %(fileNameStr))
        classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels ,3)    
        #输出k-近邻算法分类结果和真实的分类
        print('the classifier came back with: %d,the real answer is:%d' %(classifierResult,classNumStr))
        #判断k-近邻算法是否准确
        if(classifierResult != classNumStr):
            errorCount +=1.0
            errfile.append(fileNameStr)
    print('\n the total number of errors is: %d' %(errorCount))         
    print('\n the total error rate is: %f' %(errorCount/float(mTest)))   

结果展示

在这里插入图片描述

        该函数的输出结果依赖于机器速度,加载数据集也需要很多数据,具有一定的错误率,改变k的值也会对算法的错误率产生影响。

你可能感兴趣的:(算法,自动驾驶,机器学习)