knn算法python实现

K最近邻(k-Nearest Neighbor,KNN)分类算法思路

        如果一个样本在特征空间中的k个最相似的样本中的大多数属于某一个类别,则该样本也属于这个类别。

KNN算法还可用于回归。方法是通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成正比(组合函数)。

优点:

(1)简单、易实现、易理解、无需参数估计及训练;

(2)适用于对稀有时间进行分类;

(3)特别适用于多分类问题(multi-modal,分类对象具有多个类别标签),比SVM表现要好。

缺点:(分类)

(1)当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。 改进方法:采用权值,与该样本距离小的邻居权值大;

(2)计算量较大,对每一个待分类的样本都要计算它与全体已知样本的距离进行排序继而求得K个最近邻点。目前常用的改进方法,事先对已知样本点进行剪辑,去除对分类作用不大的样本。该改进算法比较适用于容量大的类域,而容量较小的类域容易产生误分。

伪代码:

(1)计算已知类别数据集中的点与当前点之间的距离;
(2)按照距离递增次序排序;
(3)选取与当前点距离最小的k个点;
(4)确定前k个点所在的类别的出现频率;
(5)返回前k个点出现频率最高的类别作为当前点的预测分类

Python实现(未调用operator库):

(1)knn.py

from numpy import *
import pdb

def getDist(dataSet,Sample):
    return sum(power(dataSet - Sample,2))
    
def sortDist(dist,labels):
    n = shape(dist)[0]
    for i in range(1,n):
        for j in range(n-i):
            if (dist[j] > dist[j+1]):
                temp1 = dist[j]
                dist[j] = dist[j+1]
                dist[j+1] = temp1;
                temp2 = labels[j]
                labels[j] = labels[j+1]
                labels[j+1] = temp2                 
    return dist,labels
    
def countLabels(labels,k):
    labelCount = zeros(2)#assume that the amount of categories is 2
    for i in range(k):
        labelCount[labels[i]-1] = labelCount[labels[i]-1] + 1
    maxcount = -1
    for i in range(2):
        if(labelCount[i] > maxcount):
            maxcount = labelCount[i]
            label = i
    return label + 1 
           
def doKnn(Sample,dataSet,labels,k):
    n,d = dataSet.shape
    dist = zeros(n)
    for i in range(n):
        dist[i] = getDist(dataSet[i],Sample)
    #sort
    #pdb.set_trace()
    dist,labels = sortDist(dist,labels)
    #compute the count of each label
    #pdb.set_trace()
    label = countLabels(labels,k)
    print label
    print "Done!"
(2)test_Knn:

from numpy import *
import knn
import pdb
#read data
dataSet = mat([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = [1,1,2,2]
knn.doKnn([1.0,1.1],dataSet,labels,3)

(2)Python实现(利用operator库)

knnOperator.py

from numpy import *
import pdb
import operator

def knnOperator(Sample,dataSet,labels,k):
    n = dataSet.shape[0]
    sddist = tile(Sample,(n,1)) - dataSet
    sddist = sddist**2
    dist = sddist.sum(axis =1)
    dist = dist**2
    sortDistIndicies = dist.argsort()
    classCount = {}
    for i in range(k):
        voteLabel = labels[sortDistIndicies[i]]
        classCount[voteLabel] = classCount.get(voteLabel,0) + 1
        sortCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True)
    print "Done!"
    return sortCount[0][0]

testKnnOperator.py

from numpy import *
import knnOperator
#read data
d = 3  
filename = "datingtestset.txt"
fr = open(filename)
lines = fr.readlines()
n = len(lines)
dataSet = zeros((n,d))
labels = []
i = 0
for line in lines:
    line = line.strip().split('\t')
    dataSet[i,1:d] = line[1:d]
    labels.append(line[-1])
    i = i + 1

Sample = dataSet[0,:]
classifiedLabel = knnOperator.knnOperator(Sample,dataSet,labels,10) 
print labels[0]
print classifiedLabel   
注:datingtestset.txt文件保存的是测试数据,每行为一个样本,前三维为特征属性,最后一维为其类别标签。

程序在处理数据的时候,是将特征属性和类别标签分类处理。

数据归一化:当不同特征的取值范围不同或者差别很大时,进行数值归一化,方法:

newValue = (oldValue - min) / (max - min)

from numpy import *
import knnOperator

def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    range = maxVals - minVals
    normdataSet = zeros(shape(dataSet))
    n = dataSet.shape[0]
    normdataSet = (dataSet - tile(minVals,(n,1)))/tile(range,(n,1))
    return  normdataSet
    
#read data
d = 3  
filename = "datingtestset.txt"
fr = open(filename)
lines = fr.readlines()
n = len(lines)
dataSet = zeros((n,d))
labels = []
i = 0
for line in lines:
    line = line.strip().split('\t')
    dataSet[i,0:d] = line[0:d]
    labels.append(line[-1])
    i = i + 1
dataSet = autoNorm(dataSet)
Sample = dataSet[0,:]
classifiedLabel = knnOperator.knnOperator(Sample,dataSet,labels,10) 
print labels[0]
print classifiedLabel   









你可能感兴趣的:(机器学习,Python,C++)