简单地说,K-近邻算法(K-Nearest-Neighbors Classification)采用测量不同特征值之间的距离方法进行分类。
不要小看这个K值的选择问题,因为它对K近邻算法的结果会产生重大影响。正如李航博士在《统计学习方法》一书中所说:
这里有4组数据,且(1,1.1)和(1,1)定义为A类,(0,0)和(0,0.1)为B类。下面对(0.5,0.5)进行分类,判断其为A、B哪一类。
算法过程:
(1)计算已知类别数据集中的点与当前点之间的距离;
(2)按照距离递增次序排序;
(3)选取与当前点距离最小的K个点;
(4)确定前K个点所在类别的出现频率;
(5)返回前K个点出现频率最高的类别作为当前点的预测分类。
具体代码:
from numpy import *
import operator
from os import listdir
def classify0(inX, dataSet, labels, k):
    """Classify `inX` by majority vote among its k nearest training points.

    Args:
        inX: the input vector to classify (sequence of feature values).
        dataSet: numpy array of shape (n_samples, n_features) with training points.
        labels: sequence of class labels, one per row of `dataSet`.
        k: number of nearest neighbors to vote.

    Returns:
        The label that occurs most often among the k nearest neighbors.
    """
    dataSetSize = dataSet.shape[0]
    # Broadcast inX against every training row and take the Euclidean distance.
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5  # distance to every training point
    sortedDistIndicies = distances.argsort()
    classCount = {}
    # Count label occurrences among the k closest points.
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # Py3 fix: dict.iteritems() was removed; items() works on both 2 and 3.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]  # label with the highest vote count
def createDataSet():
    """Return the toy training data: four 2-D points and their class labels.

    (1, 1.1) and (1, 1) belong to class 'A'; (0, 0) and (0, 0.1) to class 'B'.
    """
    points = [[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]]
    return array(points), ['A', 'A', 'B', 'B']
# Demo: build the toy dataset and classify the point (0.5, 0.5) with k=3.
group,labels=createDataSet()
classify0([0.5,0.5],group,labels,3)
输出:
'B'
from sklearn import neighbors  # module containing the KNN algorithm
from sklearn import datasets   # built-in example datasets
knn = neighbors.KNeighborsClassifier()  # instantiate the classifier
iris = datasets.load_iris()  # load the iris data
# Classes iris setosa / versicolor / virginica are encoded as 0 / 1 / 2.
# Py3 fix: `print x` statements are a syntax error; use print() calls.
print(iris)
knn.fit(iris.data, iris.target)  # fit the model on all 150 samples
predictedLabel = knn.predict([[0.1, 0.2, 0.3, 0.4]])  # predict class of a new sample
print(predictedLabel)
结果:
[0]
以上就是如何使用python里面的sklearn库来进行KNN算法的调用。
接下来介绍如何通过自己编写程序来实现KNN算法。
基本步骤:
import csv#用于读取数据
import random
import math
import operator
# Load data
def loadDataset(filename, split, trainingSet=None, testSet=None):
    """Load a CSV dataset and randomly split rows into training and test sets.

    Each row is expected to have 4 numeric feature columns followed by a
    label column. Every row goes to `trainingSet` with probability `split`,
    otherwise to `testSet`.

    Args:
        filename: path of the CSV file to read.
        split: fraction in [0, 1]; probability a row lands in the training set.
        trainingSet: optional list to append training rows to (out-param).
        testSet: optional list to append test rows to (out-param).

    Returns:
        The (trainingSet, testSet) pair, for convenience when no lists
        were passed in.
    """
    # Bug fix: the original used mutable default arguments ([]), which are
    # shared across calls and silently accumulate rows.
    if trainingSet is None:
        trainingSet = []
    if testSet is None:
        testSet = []
    # Py3 fix: csv.reader needs a text-mode file opened with newline=''
    # (the original opened in 'rb', which fails on Python 3).
    with open(filename, 'r', newline='') as csvfile:
        dataset = list(csv.reader(csvfile))
    # Bug fix: the original iterated range(len(dataset)-1), dropping the last
    # data row; instead process every row and skip blank lines.
    for row in dataset:
        if not row:
            continue
        for y in range(4):
            row[y] = float(row[y])
        if random.random() < split:
            trainingSet.append(row)
        else:
            testSet.append(row)
    return trainingSet, testSet
# Compute distance
def euclideanDistance(instance1, instance2, length):
    """Euclidean distance between the first `length` coordinates of two instances.

    Args:
        instance1, instance2: indexable sequences of numeric features.
        length: number of leading dimensions to compare.

    Returns:
        The Euclidean distance as a float.
    """
    squared = sum((instance1[i] - instance2[i]) ** 2 for i in range(length))
    return math.sqrt(squared)
# Return the K nearest neighbors
def getNeighbors(trainingSet, testInstance, k):
    """Return the k training instances closest to `testInstance`.

    The last element of `testInstance` is assumed to be the class label and
    is excluded from the distance computation.

    Args:
        trainingSet: list of instances (features followed by a label).
        testInstance: the instance to find neighbors for.
        k: how many nearest neighbors to return.

    Returns:
        List of the k closest training instances, nearest first.
    """
    numFeatures = len(testInstance) - 1  # last column is the label
    # sorted() is stable, so ties keep their original training-set order,
    # exactly like the original tuple-sort implementation.
    ranked = sorted(
        trainingSet,
        key=lambda inst: euclideanDistance(testInstance, inst, numFeatures),
    )
    return ranked[:k]
# Vote among the neighbors and return the most common class
def getResponse(neighbors):
    """Majority vote over the neighbors' labels (last element of each instance).

    Args:
        neighbors: list of instances whose final element is the class label.

    Returns:
        The label with the highest vote count among the neighbors.
    """
    classVotes = {}
    for neighbor in neighbors:
        response = neighbor[-1]
        classVotes[response] = classVotes.get(response, 0) + 1
    # Py3 fix: dict.iteritems() was removed; items() works on both 2 and 3.
    sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)
    return sortedVotes[0][0]
# Compute the classification accuracy
def getAccuracy(testSet, predictions):
    """Percentage of test instances whose true label matches the prediction.

    Args:
        testSet: list of instances; the last element of each is the true label.
        predictions: list of predicted labels, aligned with `testSet`.

    Returns:
        Accuracy as a percentage in [0.0, 100.0].
    """
    correct = sum(
        1 for i, instance in enumerate(testSet) if instance[-1] == predictions[i]
    )
    return correct / float(len(testSet)) * 100.0
def main():
    """Run the full KNN workflow: load, split, predict, and report accuracy."""
    # prepare data
    trainingSet = []  # out-params filled by loadDataset
    testSet = []
    split = 0.67  # roughly 2/3 of the rows go to training, 1/3 to test
    loadDataset(r'/home/duxu/exercise/iris.csv', split, trainingSet, testSet)
    # Py3 fix: `print x` statements are a syntax error; use print() calls.
    print('Train set: ' + repr(len(trainingSet)))
    print('Test set: ' + repr(len(testSet)))
    # generate predictions
    predictions = []
    k = 3
    for x in range(len(testSet)):
        neighbors = getNeighbors(trainingSet, testSet[x], k)
        result = getResponse(neighbors)
        predictions.append(result)
        print('> predicted=' + repr(result) + ', actual=' + repr(testSet[x][-1]))
    accuracy = getAccuracy(testSet, predictions)
    print('Accuracy: ' + repr(accuracy) + '%')


main()
输出:
Train set: 100
Test set: 49
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-setosa', actual='Iris-setosa'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-virginica', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-versicolor', actual='Iris-versicolor'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
> predicted='Iris-virginica', actual='Iris-virginica'
Accuracy: 97.95918367346938%
从结果看出:训练集有100个实例,测试集有49个实例;接着打印出来了测试集的预测结果和实际分类;最后计算出了预测的正确率约为98%,比较理想。