之前介绍的KNN算法使用的是线性扫描,计算复杂度和空间复杂度都很高。
事实上,实际数据集中的点一般时呈簇状分布的,所以,很多点我们是完全没有必要遍历的,索引树的方法就是对将要搜索的点进行空间划分,空间划分可能会有重叠,也可能没有重叠,kd-tree就是划分空间没有重叠的索引树
kd-tree是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。主要应用于多维空间关键数据的搜索(如:范围搜索和最近邻搜索)。
kd-tree每个节点中主要包含的数据结构如下
其Python定义为
class KD_node: def __init__(self, elt=None, split=None, LL = None, RR = None): """ elt:数据点 split:划分域 LL, RR:节点的左儿子跟右儿子 """ self.elt = elt self.split = split self.left = LL self.right = RR
样本集T由kd-tree的结点的集合表示,每个结点表示一个样本点,dom_elt就是表示该样本点的向量。该样本点根据结点的分割超平面将样本空间分为两个子空间。左子空间中的样本点集合由左子树left表示,右子空间中的样本点集合由右子树right表示。分割超平面是一个通过点dom_elt并且垂直于split所指示的方向轴的平面。
关于分裂维度split的选择,其实有很多方法,相较统计机器学习上提到的方法应用更广泛的是基于方差的选择方法:即尽可能将相似的点放在一颗子树里面,所以kd-tree采取的思想就是计算所有数据点在每个维度上的数值的方差,然后方差最大的维度就作为当前节点的划分维度。这样做的原理其实就是:方差越大,说明这个维度上的数据波动越大,也就说明了他们就越不可能属于同一个空间,需要在这个维度上对点进行划分,这就是kd-tree节点选择划分维度的原理。
从而Python创建kd-tree的程序可以写作
def createKDTree(root, data_list): """ root:当前树的根节点 data_list:数据点的集合(无序) return:构造的KDTree的树根 """ LEN = len(data_list) if LEN == 0: return #数据点的维度 dimension = len(data_list[0]) #方差 max_var = 0 #最后选择的划分域 split = 0; for i in range(dimension): ll = [] for t in data_list: ll.append(t[i]) var = computeVariance(ll) if var > max_var: max_var = var split = i #根据划分域的数据对数据点进行排序 data_list.sort(key=lambda x: x[split]) #选择下标为len / 2的点作为分割点 elt = data_list[LEN / 2] root = KD_node(elt, split) root.left = createKDTree(root.left, data_list[0:(LEN / 2)]) root.right = createKDTree(root.right, data_list[(LEN / 2 + 1):LEN]) return root def computeVariance(arrayList): """ arrayList:存放的数据点 return:返回数据点的方差 """ for ele in arrayList: ele = float(ele) LEN = float(len(arrayList)) array = numpy.array(arrayList) sum1 = array.sum() array2 = array * array sum2 = array2.sum() mean = sum1 / LEN #D[X] = E[x^2] - (E[x])^2 variance = sum2 / LEN - mean**2 return variance
基本的查找思路如下:
2.1 二叉查找从根节点开始进行查找,直到叶子节点;在这个过程中,记录最短的距离,和对应的数据点;同时维护一个栈,用来存储经过的节点
2.回溯查找通过计算查找点到分割平面的距离(这个距离比较的是分割维度上的值的差,并不是分割节点到分割平面上的距离,虽然两者的值是相等的)与当前最短距离进行比较,决定是否需要进入节点的相邻子空间进行查找。
def findNN(root, query): """ root:KDTree的树根 query:查询点 return:返回距离data最近的点NN,同时返回最短距离min_dist """ #初始化为root的节点 NN = root.elt min_dist = computeDist(query, NN) nodeList = [] temp_root = root ##二分查找建立路径 while temp_root: nodeList.append(temp_root) dd = computeDist(query, temp_root.elt) if min_dist > dd: NN = temp_root.elt min_dist = dd #当前节点的划分域 ss = temp_root.split if query[ss] <= temp_root.elt[ss]: temp_root = temp_root.left else: temp_root = temp_root.right ##回溯查找 while nodeList: #使用list模拟栈,后进先出 back_elt = nodeList.pop() ss = back_elt.split print "back.elt = ", back_elt.elt ##判断是否需要进入父亲节点的子空间进行搜索 if abs(query[ss] - back_elt.elt[ss]) < min_dist: if query[ss] <= back_elt.elt[ss]: temp_root = back_elt.right else: temp_root = back_elt.left if temp_root: nodeList.append(temp_root) curDist = computeDist(query, temp_root.elt) if min_dist > curDist: min_dist = curDist NN = temp_root.elt return NN, min_dist def computeDist(pt1, pt2): """ 计算两个数据点的距离 return:pt1和pt2之间的距离 """ sum = 0.0 for i in range(len(pt1)): sum = sum + (pt1[i] - pt2[i]) * (pt1[i] - pt2[i]) return math.sqrt(sum)