k近邻算法是一种基本的分类算法,它的思想非常的简单直观,即一个样本的类别应该和训练数据集中和它距离最近的k个样本中多数样本所属的类别相同,因此,k近邻法分类时没有显式的学习过程。k近邻法的模型实际是一种对特征空间的划分,模型由距离度量、k值的选择和决策规则决定,对于决策规则,我们一般使用多数表决的原则,因此模型的表现主要由距离度量和k值决定。
特征空间中两点间的距离可以看做是两个样本相似度的一种表现,在k近邻法中距离我们一般使用欧氏距离,它指的是两点间的真实距离,定义如下:
k值的选择对于模型的表现具有非常重要的影响。具体来说,如果选择一个较小的k值,就是用一个较小的邻域中的训练实例进行预测,这种情况下“学习”的近似误差会很小,但是“学习”的估计误差会增大。因为只使用一个较小的邻域去预测,训练集中的噪声点将会对结果造成很大的影响,考虑一个极端情况,当k=1时,预测样本的类别就等于训练集中与它距离最近的样本的类别,如果该点刚好是噪声点,那么预测将会发生错误。也就是说,k越小,模型越复杂,也就越容易发生过拟合。
相反的,当我们增大k值时,模型会变得简单,相应的容易出现欠拟合。考虑当k=n是,训练集中所有的点均是输入实例的邻域,对于任何输入实例其分类均为训练集中多数样本所属的类别。
因此在实际应用中,k通常取一个比较小的值,但同时也要通过交叉验证等方式确定k的具体取值。
对于k近邻算法,因为需要寻找与样本距离最近的k个点,因此一种简单直接的方法就是计算输入实例与所有样本间的距离,当训练集很大时,其时间复杂度O(n)会很大,因此,为了提高k近邻搜索的效率,可以考虑使用kd树的方法。kd树是一个二叉树,表示对k维空间(这里的k与上文提到的k近邻法中的k意义不同)的划分。构建kd树的过程相当于不断用垂直于坐标轴的超平面将k维空间切分,构成一系列的k维超矩形区域。kd树的每个节点对应于一个k维超矩形区域。以下是一颗kd树构建的过程。
这是一个由原始数据集生成kd树的一个简单例子(位于页面中部的右侧),可以帮助你更好的理解kd树的生成过程。
给定一个目标点,搜索其最近邻,首先找到包含目标点的叶节点,然后从该叶节点出发,依次退回到其父节点,不断查找是否存在比当前最近点更近的点,直到退回到根节点时终止,获得目标点的最近邻点。如果按照流程可描述如下:
1. 从根节点出发,若目标点x当前维的坐标小于切分点的坐标,则移动到左子节点,反之则移动到右子节点,直到移动到最后一层叶节点。
2. 以此叶结点为“当前最近点”
3. 递归的向上回退,在每个节点进行如下的操作:
a.如果该节点保存的实例点距离比当前最近点更小,则该点作为新的“当前最近点”
b.检查“当前最近点”的父节点的另一子节点对应的区域是否存在更近的点,如果存在,则移动到该点,接着,递归地进行最近邻搜索。如果不存在,则继续向上回退
4. 当回到根节点时,搜索结束,获得最近邻点
以上分析的是k=1是的情况,当k>1时,在搜索时“当前最近点”中保存的点个数<=k的即可。
相比于线性扫描,kd树搜索的平均计算复杂度为O(logN).但是当样本空间维数接近样本数时,它的效率会迅速降低,并接近线性扫描的速度。因此kd树搜索适合用在训练实例数远大于样本空间维数的情况。
以下关于kd树与k近邻法的Python实现来自于stefankoegl,在其代码的基础上,添加了相应的中文注释。
class Node(object):
"""初始化一个节点"""
def __init__(self, data=None, left=None, right=None):
self.data = data
self.left = left
self.right = right
class KDNode(Node):
"""初始化一个包含kd树数据和方法的节点"""
def __init__(self, data=None, left = None,right =None,axis = None,
sel_axis=None,dimensions=None):
"""为KD树创建一个新的节点
如果该节点在树中被使用,axis和sel_axis必须被提供。
sel_axis(axis)在创建当前节点的子节点中将被使用,
输入为父节点的axis,输出为子节点的axis"""
super(KDNode,self).__init__(data,left,right)
self.axis = axis
self.sel_axis = sel_axis
self.dimensions = dimensions
def create(point_list=None, dimensions=None, axis=0,sel_axis=None):
"""从一个列表输入中创建一个kd树
列表中的所有点必须有相同的维度。
如果输入的point_list为空,一颗空树将被创建,这时必须提供dimensions的值
如果point_list和dimensions都提供了,那么必须保证前者维度为dimensions
axis表示根节点切分数据的位置,sel_axis(axis)在创建子节点时将被使用,
它将返回子节点的axis"""
if not point_list and not dimensions:
raise ValueError('either point_list or dimensions should be provided')
elif point_list:
dimensions = check_dimensionality(point_list,dimensions)
#这里每次切分直接取下个一维度,而不是取所有维度中方差最大的维度
sel_axis = sel_axis or (lambda prev_axis:(prev_axis+1) % dimensions)
if not point_list:
return KDNode(sel_axis=sel_axis,axis = axis, dimensions=dimensions)
# 对point_list 按照axis升序排列,取中位数对用的坐标点
point_list = list(point_list)
point_list.sort(key = lambda point:point[axis])
median = len(point_list) // 2
loc = point_list[median]
left = create(point_list[:median],dimensions,sel_axis(axis))
right = create(point_list[median+1:],dimensions,sel_axis(axis))
return KDNode(loc, left,right,axis = axis,sel_axis=sel_axis,dimensions=dimensions)
def check_dimensionality(point_list,dimensions=None):
"""检查并返回point_list的维度"""
dimensions = dimensions or len(point_list[0])
for p in point_list:
if len(p) != dimensions:
raise ValueError('All Points in the point_list must have the same dimensionality')
return dimensions
以上就是kd树建立的建立过程,现在来描述如何利用kd树进行搜索实现k近邻算法
- 首先我们需要建立一个优先队列,用来保存搜索到的k个最近的点.关于优先队列的具体原理,可以参考其他的教程
class BoundedPriorityQueue:
"""优先队列(max heap)及相关实现函数"""
def __init__(self, k):
self.heap=[]
self.k = k
def items(self):
return self.heap
def parent(self,index):
"""返回父节点的index"""
return int(index / 2)
def left_child(self, index):
return 2*index + 1
def right_index(self,index):
return 2*index + 2
def _dist(self,index):
"""返回index对应的距离"""
return self.heap[index][3]
def max_heapify(self, index):
"""
负责维护最大堆的属性,即使当前节点的所有子节点值均小于该父节点
"""
left_index = self.left_child(index)
right_index = self.right_index(index)
largest = index
if left_index and self._dist(left_index) >self._dist(index):
largest = left_index
if right_index and self._dist(right_index) > self._dist(largest):
largest = right_index
if largest != index :
self.heap[index], self.heap[largest] = self.heap[largest], self.heap[index]
self.max_heapify(largest)
def propagate_up(self,index):
"""在index位置添加新元素后,通过不断和父节点比较并交换
维持最大堆的特性,即保持堆中父节点的值永远大于子节点"""
while index != 0 and self._dist(self.parent(index)) < self._dist(index):
self.heap[index], self.heap[self.parent(index)] = self.heap[self.parent(index)],self.heap[index]
index = self.parent(index)
def add(self, obj):
"""
如果当前值小于优先队列中的最大值,则将obj添加入队列,
如果队列已满,则移除最大值再添加,这时原队列中的最大值、
将被obj取代
"""
size = self.size()
if size == self.k:
max_elem = self.max()
if obj[1] < max_elem:
self.extract_max()
self.heap_append(obj)
else:
self.heap_append(obj)
def heap_append(self, obj):
"""向队列中添加一个obj"""
self.heap.append(obj)
self.propagate_up(self.size()-1)
def size(self):
return len(self.heap)
def max(self):
return self.heap[0][4]
def extract_max(self):
"""
将最大值从队列中移除,同时从新对队列排序
"""
max = self.heap[0]
data = self.heap.pop()
if len(self.heap)>0:
self.heap[0]=data
self.max_heapify(0)
return max
def _search_node(self,point,k,results,get_dist):
if not self:
return
nodeDist = get_dist(self)
#如果当前节点小于队列中至少一个节点,则将该节点添加入队列
#该功能由BoundedPriorityQueue类实现
results.add((self,nodeDist))
#获得当前节点的切分平面
split_plane = self.data[self.axis]
plane_dist = point[self.axis] - split_plane
plane_dist2 = plane_dist ** 2
#从根节点递归向下访问,若point的axis维小于且分点坐标
#则移动到左子节点,否则移动到右子节点
if point[self.axis] < split_plane:
if self.left is not None:
self.left._search_node(point,k,results,get_dist)
else:
if self.right is not None:
self.right._search_node(point,k,results,get_dist)
#检查父节点的另一子节点是否存在比当前子节点更近的点
#判断另一区域是否与当前最近邻的圆相交
if plane_dist2 < results.max() or results.size() < k:
if point[self.axis] < self.data[self.axis]:
if self.right is not None:
self.right._search_node(point,k,results,get_dist)
else:
if self.left is not None:
self.left._search_node(point,k,results,get_dist)
def search_knn(self,point,k,dist=None):
"""返回k个离point最近的点及它们的距离"""
if dist is None:
get_dist = lambda n:n.dist(point)
else:
gen_dist = lambda n:dist(n.data, point)
results = BoundedPriorityQueue(k)
self._search_node(point,k,results,get_dist)
#将最后的结果按照距离排序
BY_VALUE = lambda kv: kv[1]
return sorted(results.items(), key=BY_VALUE)
以上就是本文的所有内容,上述代码的完整版可在这里获得.