KNN(K Nearest Neighbors,K近邻 )算法是机器学习所有算法中理论最简单,最好理解的。KNN是一种基于实例的学习,通过计算新数据与训练数据特征值之间的距离,然后选取K(K>=1)个距离最近的邻居进行分类判断(投票法)或者回归。如果K=1,那么新数据被简单分配给其近邻的类。KNN算法算是监督学习还是无监督学习呢?首先来看一下监督学习和无监督学习的定义。对于监督学习,数据都有明确的label(分类针对离散分布,回归针对连续分布),根据机器学习产生的模型可以将新数据分到一个明确的类或得到一个预测值。对于非监督学习,数据没有label,机器学习出的模型是从数据中提取出来的pattern(提取决定性特征或者聚类等)。例如聚类是机器根据学习得到的模型来判断新数据“更像”哪些原数据集合。KNN算法用于分类时,每个训练数据都有明确的label,也可以明确的判断出新数据的label,KNN用于回归时也会根据邻居的值预测出一个明确的值,因此KNN属于监督学习。
一、 简单、有效。
二、 重新训练的代价较低(基本不需要训练)。
三、 计算时间和空间线性于训练集的规模(在一些场合不算太大),样本过大识别时间会很长。
四、 k值比较难以确定。
mnist是一个手写数字的库,包含数字从0-9,每个图像大小为32*32,详细介绍和数据下载见这里
用到了PIL,numpy这两个python库,没有安装的可以参照我的另外一篇博客去配置安装,这就不多说了
代码是我修改的大牛的原始代码生成的,参见下面的参考文献,我也已经上传CSDN,一份是大牛的原始代码,一份是新的
我们需要使用KNN算法去识别mnist手写数字,具体步骤如下:
首先需要将手写数字做成0 1串,将原图中黑色像素点变成1,白色为0,写成TXT文件;
python代码:
def img2vector(impath,savepath):
''' convert the image to an numpy array Black pixel set to 1,white pixel set to 0 '''
im = Image.open(impath)
im = im.transpose(Image.ROTATE_90)
im = im.transpose(Image.FLIP_TOP_BOTTOM)
rows = im.size[0]
cols = im.size[1]
imBinary = zeros((rows,cols))
for row in range(0,rows):
for col in range(0,cols):
imPixel = im.getpixel((row,col))[0:3]
if imPixel == (0,0,0):
imBinary[row,col] = 0
#save temp txt like 1_5.txt whiich represent the class is 1 and the index is 5
fp = open(savepath,'w')
for x in range(0,imBinary.shape[0]):
for y in range(0,imBinary.shape[1]):
fp.write(str(int(imBinary[x,y])))
fp.write('\n')
fp.close()
结果大概像这样:
将所有的TXT文件中的0 1串变成行向量
python代码:
def vectorOneLine(filename):
rows = 32
cols = 32
imgVector = zeros((1, rows * cols))
fileIn = open(filename)
for row in xrange(rows):
lineStr = fileIn.readline()
for col in xrange(cols):
imgVector[0, row * 32 + col] = int(lineStr[col])
return imgVector
KNN识别
python代码:
def kNNClassify(testImput, TrainingDataSet, TrainingLabels, k):
numSamples = dataSet.shape[0] # shape[0] stands for the num of row
#calculate the Euclidean distance
diff = tile(newInput, (numSamples, 1)) - dataSet # Subtract element-wise
squaredDiff = diff ** 2 # squared for the subtract
squaredDist = sum(squaredDiff, axis = 1) # sum is performed by row
distance = squaredDist ** 0.5
#sort the distance vector
sortedDistIndices = argsort(distance)
#choose k elements
classCount = {} # define a dictionary (can be append element)
for i in xrange(k):
voteLabel = labels[sortedDistIndices[i]]
#initial the dict
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
#vote the label as final return
maxCount = 0
for key, value in classCount.items():
if value > maxCount:
maxCount = value
maxIndex = key
return maxIndex
*识别结果*
参考文献:
[1]大牛的博客:http://blog.csdn.net/zouxy09/article/details/16955347
[2]matlab 实现KNN: http://blog.csdn.net/rk2900/article/details/9080821
[3]分类算法的优缺点:http://bbs.pinggu.org/thread-2604496-1-1.html
[4]代码下载地址:http://download.csdn.net/detail/gavin__zhou/9208821
http://download.csdn.net/detail/gavin__zhou/9208827