《机器学习实战》代码记录--knn--手写数字识别

该书代码及数据http://www.manning-source.com/books/pharrington/MLiA_SourceCode.zip

文件目录及样本数据:

testDigits目录下为测试数据,trainingDigits目录下为训练数据,文件名形如[0-9]_[0-200].txt,即有0至9的各200个左右不同的样本,例如9_9.txt样本内容如下:

《机器学习实战》代码记录--knn--手写数字识别_第1张图片

问题描述:

对testDigits下的样本进行分类并统计错误率


输出样例:

《机器学习实战》代码记录--knn--手写数字识别_第2张图片

代码(knn.py):

from numpy import *
import operator
import sys
from numpy import array
import os

def img2vector(filename):
        returnVect=zeros((1,1024))
        fr=open(filename)
        for i in range(32):
                lineStr=fr.readline()
                for j in range(32):
                        returnVect[0,32*i+j]=int(lineStr[j])
        return returnVect

def classify(inX,dataSet,labels,k):
        dataSetSize=dataSet.shape[0]
        diffMat=tile(inX,(dataSetSize,1))-dataSet
        sqDiffMat=diffMat**2
        sqDistances=sqDiffMat.sum(axis=1)
        distances=sqDistances**0.5
        sortedDistIndices=distances.argsort()
        classCount={}
        for i in range(k):
                voteIlabel=labels[sortedDistIndices[i]]
                classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
        sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
        return sortedClassCount[0][0]

def handwritingClassTest():
        k=3
        hwLabels=[]
        trainingFileList=os.listdir('trainingDigits')
        m=len(trainingFileList)
        trainingMat=zeros((m,1024))
        for i in range(m):
                fileNameStr=trainingFileList[i]
                fileStr=fileNameStr.split('.')[0]
                classNumStr=int(fileStr.split('_')[0])
                hwLabels.append(classNumStr)
                trainingMat[i:,]=img2vector('trainingDigits/%s'%fileNameStr)
        testFileList=os.listdir('testDigits')
        errorCount=0.0
        mTest=len(testFileList)
        for i in range(mTest):
                fileNameStr=testFileList[i]
                fileStr=fileNameStr.split('.')[0]
                classNumStr=int(fileStr.split('_')[0])
                vectorUnderTest=img2vector('trainingDigits/%s'%fileNameStr)
                classifierResult=classify(vectorUnderTest,trainingMat,hwLabels,k)
                print "the classifier came back with :%d,the real answer is :%d"%(classifierResult,classNumStr)
                if (classifierResult!=classNumStr):
                        errorCount+=1.0
        print "\nthe total number of error is : %d"%errorCount
        print "\nthe total error rate is : %f"%(errorCount/float(mTest))




if __name__=='__main__':
        handwritingClassTest()




































你可能感兴趣的:(python,knn,机器学习实战)