实现 k 近邻算法时,主要考虑的问题就是如何对训练数据进行快速 k 近邻搜索。
k 近邻法最简单的实现方法是线性扫描 linear scan,这需要计算输入实例与其他每个训练实例的距离,在训练集很大的时候,这种方法是不可取的(上述代码中我们使用的方法都是线性扫描)
为了提高 k 近邻搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数,比如 kd 树
kd 树是二叉树,表示对 k 维空间的一个划分,是一种对 k 维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。
构造 kd 树相当于不断地用垂直于坐标轴的超平面将 k 维空间切分,构成一些列的 k 维超矩形区域。kd 树的每一个结点对应于一个 k 维超矩形区域。
通常,依次选择坐标轴对空间切分,选择训练实例点在选定坐标轴上的中位数(median)为切分点,这样得到的 kd 树是平衡的。 注意,平衡的 kd 树搜索时的效率未必是最优的。
下面给出构造 kd 树的算法:
举例如下:
给定一个二维空间的数据集构造一个平衡 kd 树:
详细步骤如下:
首先,我们的数据集对应的特征空间如下:
根节点对应包含数据集 T 的矩形。选择 x ( 1 ) x^{(1)} x(1) 轴,6 个数据点的 x ( 1 ) x^{(1)} x(1) 坐标依次为: 2 , 5 , 9 , 4 , 8 , 7 2,5,9,4,8,7 2,5,9,4,8,7,对应的中位数为 7(中位数本该为 6,但是数据集中没有该数据点,故选 7),即根节点为 ( 7 , 2 ) (7,2) (7,2)
⭐ 通过与 x ( 1 ) x^{(1)} x(1) 垂直的轴即 y 轴进行切分( x ( 1 ) = 7 x^{(1)} = 7 x(1)=7), 将空间分为左右两个子矩形:
对应的,构造出的 kd 树如下:
接着,左矩形以 x ( 2 ) x^{(2)} x(2) 轴 的中位数进行划分。左矩形中拥有的实例点为: ( 4 , 5 ) , ( 2 , 3 ) , ( 5 , 4 ) (4,5),(2,3),(5,4) (4,5),(2,3),(5,4), 这 3 个数据点的 x ( 2 ) x^{(2)} x(2) 坐标分别为 5 , 3 , 4 5,3,4 5,3,4,中位数为 4,即根节点的左孩子为 ( 5 , 4 ) (5,4) (5,4)
通过与 x ( 2 ) x^{(2)} x(2) 垂直的轴即 x 轴进行切分( x ( 2 ) = 4 x^{(2)} = 4 x(2)=4), 将空间分为左右两个子矩形:
对应的,构造出的 kd 树如下:
同样的,根节点切分出来的右矩形也以 x ( 2 ) x^{(2)} x(2) 轴 的中位数进行划分。右矩形中拥有的实例点为: ( 8 , 1 ) , ( 9 , 6 ) (8,1),(9,6) (8,1),(9,6), 这 3 个数据点的 x ( 2 ) x^{(2)} x(2) 坐标分别为 1 , 6 1,6 1,6,中位数为 6,即根节点的右孩子为 ( 9 , 6 ) (9,6) (9,6)
通过与 x ( 2 ) x^{(2)} x(2) 垂直的轴即 x 轴进行切分( x ( 2 ) = 6 x^{(2)} = 6 x(2)=6), 将空间分为左右两个子矩形:
对应的,构造出的 kd 树如下:
OK,根节点划分出来的两个子区域处理完了,现在来看根节点的左孩子划分出来的两个子区域,每个子区域中只有一个数据点了: ( 2 , 3 ) (2,3) (2,3) 和 ( 4 , 5 ) (4,5) (4,5), 上面过程中我们依次按照 x ( 1 ) , x ( 2 ) x^{(1)}, x^{(2)} x(1),x(2) 进行选取,所以现在又回到了 x ( 1 ) x^{(1)} x(1) (如果有 $ x^{(3)}$ 则按照 $ x^{(3)}$ 进行选取)。
通过与 x ( 1 ) x^{(1)} x(1) 垂直的轴即 y 轴进行切分( x ( 1 ) = 2 x^{(1)} = 2 x(1)=2 和 x ( 1 ) = 4 x^{(1)} = 4 x(1)=4), 将空间分为左右两个子矩形:
对应的,构造出的 kd 树如下:
根节点的左孩子划分出来的两个子区域处理完了,现在来处理根节点的右孩子划分出来的两个子区域,只有一个子区域中有数据点 ( 8 , 1 ) (8,1) (8,1)。
通过与 x ( 1 ) x^{(1)} x(1) 垂直的轴即 y 轴进行切分( x ( 1 ) = 8 x^{(1)} = 8 x(1)=8), 将空间分为左右两个子矩形:
对应的,构造出的 kd 树如下:
至此,kd 树构造完毕
构造 kd 树算法的具体 Python 代码实现如下:
# kd-tree每个结点中主要包含的数据结构如下
class KdNode(object):
def __init__(self, dom_elt, split, left, right):
self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)
self.split = split # 整数(进行分割维度的序号)
self.left = left # 该结点分割超平面左子空间构成的kd-tree
self.right = right # 该结点分割超平面右子空间构成的kd-tree
class KdTree(object):
def __init__(self, data):
k = len(data[0]) # 数据维度
def CreateNode(split, data_set): # 按第split维划分数据集exset,并创建KdNode
if not data_set: # 数据集为空
return None
# key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较
# operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号
#data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序
data_set.sort(key=lambda x: x[split])
split_pos = len(data_set) // 2 # //为Python中的整数除法,求中位数的下标
median = data_set[split_pos] # 中位数分割点
split_next = (split + 1) % k # 循环选取下一个划分的维度
# 递归的创建kd树
return KdNode(
median,
split,
CreateNode(split_next, data_set[:split_pos]), # 创建左子树
CreateNode(split_next, data_set[split_pos + 1:])) # 创建右子树
self.root = CreateNode(0, data) # 从第0维分量开始构建kd树,返回根节点
# KDTree的前序遍历
def preorder(root):
print(root.dom_elt)
if root.left: # 节点不为空
preorder(root.left)
if root.right:
preorder(root.right)
测试一下上述代码:
data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
kd = KdTree(data)
preorder(kd.root)
下面介绍如何利用 kd 树进行 k 近邻搜索。
给定一个目标点,搜索其近邻。首先找到包含目标点的叶节点,然后从该叶节点出发,依次回退到父节点;不断查找与目标点最邻近的节点,当确定不可能存在更近的节点时终止。
显然,利用 kd 树可以省去对大部分数据点的搜索,从而减少搜索的计算量
下面通过一个例题来说明搜索方法:
给定一个如下图所示的 kd 树,根节点为 A,树上共存储 7 个实例点,另有一个输入目标实例点 S,求 S 的最近邻。
搜索 kd 树算法的具体 Python 代码实现如下:
# 对构建好的kd树进行搜索,寻找与目标点最近的样本点:
import math
from collections import namedtuple
# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple",
"nearest_point nearest_dist nodes_visited")
def find_nearest(tree, point):
k = len(point) # 数据维度
def travel(kd_node, target, max_dist):
if kd_node is None:
return result([0] * k, float("inf"),0) # python中用float("inf")和float("-inf")表示正负无穷
nodes_visited = 1
s = kd_node.split # 进行分割的维度
pivot = kd_node.dom_elt # 进行分割的“轴”
if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)
nearer_node = kd_node.left # 下一个访问节点为左子树根节点
further_node = kd_node.right # 同时记录下右子树
else: # 目标离右子树更近
nearer_node = kd_node.right # 下一个访问节点为右子树根节点
further_node = kd_node.left
temp1 = travel(nearer_node, target, max_dist) # 递归遍历找到包含目标点的区域
nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”
dist = temp1.nearest_dist # 更新最近距离
nodes_visited += temp1.nodes_visited
if dist < max_dist:
max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内
temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离
if max_dist < temp_dist: # 判断超球体是否与超平面相交
return result(nearest, dist, nodes_visited) # 不相交则可以直接返回,不用继续判断
#----------------------------------------------------------------------
# 计算目标点与分割点的欧氏距离
temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))
if temp_dist < dist: # 如果“更近”
nearest = pivot # 更新最近点
dist = temp_dist # 更新最近距离
max_dist = dist # 更新超球体半径
# 检查另一个子结点对应的区域是否有更近的点
temp2 = travel(further_node, target, max_dist)
nodes_visited += temp2.nodes_visited
if temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离
nearest = temp2.nearest_point # 更新最近点
dist = temp2.nearest_dist # 更新最近距离
return result(nearest, dist, nodes_visited)
return travel(tree.root, point, float("inf")) # 从根节点开始递归
测试一下上述代码:
ret = find_nearest(kd, [3,4.5])
print (ret)