该书代码及数据下载地址:http://www.manning-source.com/books/pharrington/MLiA_SourceCode.zip
文件目录及样本数据:
testDigits目录下为测试数据,trainingDigits目录下为训练数据。文件名形如“<数字>_<序号>.txt”(数字取0至9,序号约为0至200),即0至9每个数字各有200个左右不同的样本;每个样本是一个32行×32列、由字符0和1组成的文本图像。例如9_9.txt样本内容如下:
问题描述:
对testDigits下的样本进行分类并统计错误率
输出样例:
代码(knn.py):
import operator
import os
import sys

from numpy import *
from numpy import array


def img2vector(filename):
    """Read a 32x32 text image into a 1x1024 row vector.

    The file is expected to hold at least 32 lines, each starting with at
    least 32 digit characters. Character ``j`` of line ``i`` is stored at
    column ``32 * i + j`` of the result.

    Args:
        filename: Path of the text image to read.

    Returns:
        A NumPy array of shape ``(1, 1024)`` holding the digits as floats.

    Raises:
        IndexError: If one of the first 32 lines has fewer than 32 characters.
        ValueError: If a character is not a decimal digit.
    """
    returnVect = zeros((1, 1024))
    # The context manager closes the file even when a malformed line raises.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def classify(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its k nearest neighbours.

    Distance is Euclidean. On a tie in vote counts the winner is the first
    tied label in the dict's iteration order.

    Args:
        inX: Feature vector to classify; must broadcast against a row of
            ``dataSet``.
        dataSet: 2-D NumPy array with one training sample per row.
        labels: Sequence where ``labels[i]`` is the class of ``dataSet[i]``.
        k: Number of nearest neighbours that vote. If ``k`` exceeds the
            number of training samples, every sample votes.

    Returns:
        The label with the most votes.

    Raises:
        IndexError: If no neighbour votes (``k`` is 0 or ``dataSet`` is empty).
    """
    # Broadcasting subtracts inX from every row without materialising a
    # tiled copy of it.
    diffMat = dataSet - inX
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDistIndices = distances.argsort()

    classCount = {}
    for neighbourIndex in sortedDistIndices[:k]:
        voteIlabel = labels[neighbourIndex]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1),
                              reverse=True)
    return sortedClassCount[0][0]


def _classFromFileName(fileNameStr):
    """Return the integer class encoded before the '_' in a sample file name."""
    fileStr = fileNameStr.split('.')[0]
    return int(fileStr.split('_')[0])


def handwritingClassTest():
    """Classify every sample in ``testDigits`` and report the error rate.

    Training samples are loaded from the ``trainingDigits`` directory and
    test samples from the ``testDigits`` directory, both relative to the
    current working directory. File names must look like
    ``<class>_<index>.txt``. One line is printed per test sample, followed
    by the total error count and the error rate.

    Returns:
        The error rate as a float in ``[0, 1]``; ``0.0`` when ``testDigits``
        contains no files.
    """
    k = 3
    hwLabels = []
    trainingFileList = os.listdir('trainingDigits')
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_classFromFileName(fileNameStr))
        # Assign exactly row i; the slice ``[i:,]`` would rewrite every
        # remaining row on each iteration.
        trainingMat[i, :] = img2vector(
            os.path.join('trainingDigits', fileNameStr))

    testFileList = os.listdir('testDigits')
    errorCount = 0
    mTest = len(testFileList)
    for fileNameStr in testFileList:
        classNumStr = _classFromFileName(fileNameStr)
        # Test vectors must come from testDigits, not trainingDigits.
        vectorUnderTest = img2vector(os.path.join('testDigits', fileNameStr))
        classifierResult = classify(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with :%d,the real answer is :%d"
              % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1

    # Guard against an empty test directory instead of dividing by zero.
    errorRate = errorCount / float(mTest) if mTest else 0.0
    print("\nthe total number of error is : %d" % errorCount)
    print("\nthe total error rate is : %f" % errorRate)
    return errorRate


if __name__ == '__main__':
    handwritingClassTest()