k近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法。
给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。
输入:训练数据集
k近邻法的模型对应于特征空间的划分。模型由三个基本要素–距离度量、k值的选择和分类决策规则决定。
当上述三个要素确定后,对于任何一个新的输入实例,所属的类唯一地确定。
特征空间中,对每个训练实例点 xi x i ,距离该点比其他店更近的所有点组成的一个区域,叫做单元(cell)。每个训练实例点拥有一个单元,所有训练实例点的单元构成对特征空间的一个划分。最近邻法将实例 xi x i 的类 yi y i 作为其单元中所有点的类标记(class label)。则每个单元的实例点的类别是确定的。
特征空间中两个实例点的距离是这两个实例点相似度的反映。k近邻模型的特征空间一般是n维实数向量空间 Rn R n 。使用的距离是欧式距离,也可以是其他距离,比如更一般的 LP L P 距离( LP L P distance)或Minkowski距离。
设特征空间X是n维实数向量空间 Rn R n , xi,xj∈X,xi=(x(1)i,x(2)i,…,x(n)i)T x i , x j ∈ X , x i = ( x i ( 1 ) , x i ( 2 ) , … , x i ( n ) ) T ,
xj=(x(1)j,x(2)j,...,x(n)j)T x j = ( x j ( 1 ) , x j ( 2 ) , . . . , x j ( n ) ) T , xi,xj的LP距离定义为: x i , x j 的 L P 距 离 定 义 为 :
k值的选择会对k近邻法的结果产生重大影响。
如果选择较小的k值,相当于用较小的邻域中的训练实例进行预测,学习的近似误差(approximation error)会减小,只有与输入实例较近的(相似的)训练实例才会对预测结果起作用。但缺点是学习的估计误差(estimation error)会增大,预测结果会对近邻的实例点非常敏感。k值的减小意味着整体模型变得负责,容易发生过拟合。
如果k值较大,相当于用较大邻域里的训练实例进行预测。优点是可以减少学习的估计误差,但缺点就是会增大近似误差。这是与输入实例较远的(不太相似)的训练实例也会对预测起作用,是预测发生错误。k值的增大意味着模型变得更简单。
如果k=N,则将输入实例预测为训练实例中最多的类。即模型过于简单,完全忽略了训练实例中的大量有用信息,不可取。
实际应用中,k值一般去一个比较小的数值。通常常采用交叉验证法(将原始数据(dataset)进行分组,一部分做为训练集(train set),另一部分做为验证集(validation set or test set),首先用训练集对分类器进行训练,再利用验证集来测试训练得到的模型(model),以此来做为评价分类器的性能指标)来选取最优值。
多用是多数表决,即由输入实例的k个近邻的训练实例中的多数类决定输入实例的类。多数表决的规则等价于经验风险最小化。
实现的过程中,主要的问题是如何对训练数据进行快速k近邻搜索。这在特征空间的维数大,及训练数据容量大时尤其必要。
最简单的方法是线性扫描(linear scan)。需要计算输入实例与每一个训练实例的距离。当训练集很大时,计算非常耗时,不可取。
为了改善,可以使用特殊的结构存储训练数据,比如kd树(kd tree)。
例:给定一个二维空间的数据集:
输入:已构造的kd树,目标点x;
输出:x的最近邻
kd树更适用于训练实例远大于空间维数时的k近邻搜索。当空间维数接近训练实例数时,效率会迅速下降,几乎接近线性扫描。
import numpy as np
import operator
def Dataset():
np.random.seed(13)
dataList=np.random.randint(1,10,8)
print('dataList',dataList)
data=np.array(dataList).reshape(4,2)
print('data',data)
lables=['A','B','A','B']
return data,lables
def classfy(target,dataset,labels,k):
dataSize=dataset.shape[0]
#compute Euclidean distance=sqrt(sum of all the difference between tartget and dataSet)
minus=np.tile(target,(dataSize,1))-dataset
temp=minus**2
temp1=temp.sum(axis=1) # sum of each row
distance=temp1**0.5
sortedDistIdx=distance.argsort()# return the indcies of sorted ele,emts
count={}
#count labels
for i in range(k):
theLabel=labels[sortedDistIdx[i]]
print('label={},i={}'.format(theLabel,i))
count[theLabel]=count.get(theLabel,0)+1
sortedCount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)
return sortedCount[0][0]
data,label=Dataset()
target=[3,2]
className=classfy(target,data,label,3)
print('target is class:',className)