本文是《统计学习方法》的第三章,包含k近邻算法的原理与python实现。希望自己能坚持下去,完成整本书的学习
k近邻是一种基本的分类与回归方法。本文只讨论分类问题中的k近邻算法。k近邻算法的输入为实例的特征向量,对于输入的实例,可以取多类。分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方式进行预测。因此,K近邻法不具有显示的学习过程。k近邻实际上是利用训练集对特征空间进行划分,并作为其分类的“模型”。k值的选择、距离度量及分类决策规则是k近邻法的三个基本要素。
k值的选择会对k近邻算法的结果产生重大影响。
k较小,容易被噪声影响,发生过拟合。结果受临近的几个点的影响会很大,估计误差会增大。
k较大,学习的近似误差会增大,与输入实例距离较远的实例也会对预测起作用,使预测发生错误。k较大相当于模型变得简单。
k近邻算法最简单的实现方法是线性扫描,这时要计算输入实例与每一个训练实例的距离。当训练集很大的时候,计算非常耗时,这种方法是不可行的。
为了提高k近邻算法搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数。下面介绍这些方法中的一种,kd树。
对数据集T中的子集S初始化S=T,取当前节点node=root取维数的序数i=0,对S递归执行:
找出S的第i维的中位数对应的点,通过该点,且垂直于第i维坐标轴做一个超平面。该点加入node的子节点。该超平面将空间分为两个部分,对这两个部分分别重复此操作(S=S’,++i,node=current),直到不可再分。
下面是python代码实现:
T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
class node:
def __init__(self,point,split):
self.left=None
self.right=None
self.parent=None
self.point=point
self.split=split
pass
def set_left(self,node):
if node==None:
pass
self.left=node
node.parent=self
def set_right(self,node):
if node==None:
pass
self.right=node
node.parent=self
def median(data):
m=len(data)//2
return data[m],m
def build_kdtree(data, d):
data = sorted(data, key=lambda x: x[d])
p, m = median(data)
tree = node(p, d)
del data[m]
if m > 0: tree.set_left(build_kdtree(data[:m], not d))
if len(data) > 1: tree.set_right(build_kdtree(data[m:], not d))
return tree
def distance(a, b):
print (a, b)
return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5
def search_kdtree(tree, d, target, root):
if target[d] < tree.point[d]:
if tree.left != None:
return search_kdtree(tree.left, not d, target, root)
else:
if tree.right != None:
return search_kdtree(tree.right, not d, target, root)
def update_best(t, best):
if t == None: return
t = t.point
d = distance(t, target)
if d < best[1]:
best[1] = d
best[0] = t
return
best = [tree.point, distance(tree.point, target)]
while (tree.parent != None and tree != root):
split = tree.parent.split
if(best[1] > abs(target[split] - tree.parent.point[split])):
update_best(tree.parent, best)
tempBest = None
if(tree.point[split] < tree.parent.point[split]):
if(tree.parent.right != None):
tempBest = search_kdtree(tree.parent.right, tree.parent.right.split, target, tree.parent.right)
else:
if(tree.parent.left != None):
tempBest = search_kdtree(tree.parent.left, tree.parent.left.split, target, tree.parent.left)
if(tempBest != None and tempBest[1] < best[1]):
best = tempBest
tree = tree.parent
return best
kd_tree = build_kdtree(T, 0)
print (search_kdtree(kd_tree, 0, [9, 4], kd_tree))
搜索是一个递归的过程。先直接到叶节点,然后找到目标点的插入位置,然后往上走,逐步用自己到目标点的距离画个超球体,用超球体圈住的点来更新最近邻(或k最近邻)。
输出结果如下:
[8, 1] [9, 4]
[9, 6] [9, 4]
[[9, 6], 2.0]
图中仅用了两次搜索,便查出了距离最近的点,因此可以看出kd树是一个性能优越的数据结构。