KD-tree可以看作二叉搜索树在多维空间上的推广:每个节点只沿一个维度对空间进行分割,分割维度逐层轮换:x-y-z-x-y-z…
1.kd-tree的构建
(1)节点定义
每个节点按我的理解其实就是单维度上的一片空间区域,该节点储存了该节点的分割维度axis,分割轴的坐标value,该节点区域内的点的索引point_indices,以分割轴分开的左右节点也是两个区域,被当做left、right两个节点存储起来。
from result_set import KNNResultsets, RadiusNNResultSet
import numpy as np
import math
import os
import random
# define the node
class Node:
    """One region of the kd-tree.

    Attributes (stored unchanged from the constructor arguments):
        axis: index of the dimension this node is split along.
        value: coordinate of the splitting plane on ``axis``; ``None`` means
            the node is a leaf and is not split any further.
        left: left child node, or ``None``.
        right: right child node, or ``None``.
        point_indices: indices of the points (rows of the data set) that
            belong to this region; ``__str__`` calls ``.tolist()`` on it, so
            it is expected to be a NumPy array.
    """

    def __init__(self, axis, value, left, right, point_indices):
        self.axis = axis
        self.value = value
        self.left = left
        self.right = right
        self.point_indices = point_indices

    def is_leaf(self):
        """Return True when the node carries no split value (a leaf)."""
        # Identity test on purpose: ``== None`` is dispatched to the value's
        # own ``__eq__`` (element-wise for NumPy arrays) instead of checking
        # whether the value is actually None.
        return self.value is None

    def __str__(self):
        """Return a one-line summary: axis, split value and point indices."""
        if self.value is None:
            split_text = 'split value: leaf, '
        else:
            split_text = 'split value: %.2f, ' % self.value
        return ''.join((
            'axis %d, ' % self.axis,
            split_text,
            'point_indices: ',
            str(self.point_indices.tolist()),
        ))
(2)kd树构建
kdtree_construction调用递归构造函数kdtree_recursive_builder构建kd-tree节点
def kdtree_construction(data, leaf_size):
    """Build a kd-tree over ``data`` and return its root node.

    Args:
        data: NumPy array of points. Both ``shape[0]`` (the number of
            points) and ``shape[1]`` are read, so it needs at least two
            dimensions.
        leaf_size: forwarded to ``kdtree_recursive_builder``, which splits a
            node only while it holds more points than this.

    Returns:
        The root node produced by ``kdtree_recursive_builder``.
    """
    num_points, _num_dims = data.shape[0], data.shape[1]
    # Start from an empty tree, every point index, and the first axis.
    all_indices = np.arange(num_points)
    return kdtree_recursive_builder(
        None, data, all_indices, axis=0, leaf_size=leaf_size)
kdtree_recursive_builder
def kdtree_recursive_builder(root, db, point_indices, axis, leaf_size):
    """Recursively build the kd-tree node covering ``point_indices``.

    Args:
        root: existing node for this region, or ``None`` to create a new one.
        db: N x D NumPy array holding all points.
        point_indices: 1-D NumPy array with the row indices of ``db`` that
            belong to this region.
        axis: dimension along which this node is split.
        leaf_size: a node is split only while it holds more points than this.

    Returns:
        The node for this region. It stays a leaf (``value`` is ``None``)
        when it is not split; otherwise ``value``, ``left`` and ``right``
        are filled in.
    """
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # A split places the plane between the two middle points, so it needs at
    # least two points. Without the floor of 1, a node with a single point
    # (or none) and leaf_size < 1 would index past the end of the sorted
    # indices below.
    if len(point_indices) > max(leaf_size, 1):
        # Order the point indices by their coordinate on the split axis.
        sorted_indices, _ = sort_key_by_value(
            point_indices, db[point_indices, axis])

        # Positions of the two middle points in the sorted order.
        middle_left_idx = math.ceil(sorted_indices.shape[0] / 2) - 1
        middle_right_idx = middle_left_idx + 1
        middle_left_value = db[sorted_indices[middle_left_idx], axis]
        middle_right_value = db[sorted_indices[middle_right_idx], axis]

        # The splitting plane sits midway between the two middle points.
        root.value = (middle_left_value + middle_right_value) * 0.5

        # Both children split on the next axis in the rotation.
        next_axis = axis_round_robin(axis, dim=db.shape[1])
        root.left = kdtree_recursive_builder(
            root.left, db, sorted_indices[0:middle_right_idx],
            next_axis, leaf_size)
        root.right = kdtree_recursive_builder(
            root.right, db, sorted_indices[middle_right_idx:],
            next_axis, leaf_size)
    return root
sort_key_by_value排序函数:
# sort_key_by_value
def sort_key_by_value(point_indeces, value):
    """Sort point indices by their associated values, in ascending order.

    Args:
        point_indeces: 1-D NumPy array of point indices.
        value: NumPy array of the same shape holding the sort key of each
            index.

    Returns:
        A ``(sorted_indices, sorted_values)`` tuple, both reordered by the
        ascending order of ``value``.

    Raises:
        AssertionError: if the shapes differ or the input is not 1-D.
    """
    # Both inputs must be 1-D and line up element for element.
    assert point_indeces.shape == value.shape
    assert len(point_indeces.shape) == 1
    order = np.argsort(value)
    return point_indeces[order], value[order]
axis_round_robin函数,轮换确定分割维度:0-->1,1-->2,2-->0
# axis_round_robin,改变分割维度
def axis_round_robin(axis, dim):
    """Return the split axis that follows ``axis``.

    The last axis (``dim - 1``) wraps around to 0; any other value is simply
    incremented by one.
    """
    return 0 if axis == dim - 1 else axis + 1
2.kd-tree的KNN查找
和二叉树的查找很相似,不同之处在于一开始要先判断当前节点是不是叶子节点:如果是,就直接计算该叶子节点内所有点到查询点的距离,并把这些点插入到结果集合中。
不是叶子节点的时候:也是先比较查询点和当前节点的value的大小,选择从左边找还是从右边找。
是否需要查找另一侧的节点区域,关键是比较当前最坏距离与|查询点在分割维度上的坐标 − 分割值|的大小:只有后者小于当前最坏距离时,另一侧区域才可能存在更近的点,才需要继续查找。
def kdtree_knn_search(root: Node, data: np.ndarray, result_set: KNNResultsets, query_point: np.ndarray):
    """Collect nearest-neighbour candidates for ``query_point`` into ``result_set``.

    Args:
        root: node of the kd-tree to search from; ``None`` is an empty subtree.
        data: the point array the tree was built on (indexed by the nodes'
            ``point_indices``).
        result_set: receives candidates through ``add_point(distance, index)``
            and reports the current pruning distance via ``worst_Dist()``.
        query_point: coordinates of the query point.

    Returns:
        Always ``False``; the results are delivered through ``result_set``.
    """
    if root is None:
        return False

    if root.is_leaf():
        # A leaf is not split any further, so every point stored in it is
        # compared with the query by brute force.
        leaf_points = data[root.point_indices, :]
        distances = np.linalg.norm(
            np.expand_dims(query_point, 0) - leaf_points, axis=1)
        for distance, point_index in zip(distances, root.point_indices):
            result_set.add_point(distance, point_index)
        return False

    # Descend first into the side of the splitting plane that holds the query.
    if query_point[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left
    kdtree_knn_search(near_child, data, result_set, query_point)

    # The other side is visited only when the splitting plane is closer to
    # the query than the worst distance reported by the result set.
    if math.fabs(query_point[root.axis] - root.value) < result_set.worst_Dist():
        kdtree_knn_search(far_child, data, result_set, query_point)
    return False
3.kd-tree的radius查找
def kdtree_radius_search(root: Node, data: np.ndarray, result_set: RadiusNNResultSet, query_point: np.ndarray):
    """Collect radius-search candidates for ``query_point`` into ``result_set``.

    Args:
        root: node of the kd-tree to search from; ``None`` is an empty subtree.
        data: the point array the tree was built on (indexed by the nodes'
            ``point_indices``).
        result_set: receives candidates through ``add_point(distance, index)``
            and reports the pruning distance via ``worst_Dist()``.
        query_point: coordinates of the query point.

    Returns:
        Always ``False``; the results are delivered through ``result_set``.
    """
    if root is None:
        return False

    if root.is_leaf():
        # A leaf is not split any further, so every point stored in it is
        # compared with the query by brute force.
        leaf_points = data[root.point_indices, :]
        distances = np.linalg.norm(
            np.expand_dims(query_point, 0) - leaf_points, axis=1)
        for distance, point_index in zip(distances, root.point_indices):
            result_set.add_point(distance, point_index)
        return False

    # Descend first into the side of the splitting plane that holds the query.
    if query_point[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left
    kdtree_radius_search(near_child, data, result_set, query_point)

    # The other side is visited only when the splitting plane is closer to
    # the query than the distance reported by the result set.
    if math.fabs(query_point[root.axis] - root.value) < result_set.worst_Dist():
        kdtree_radius_search(far_child, data, result_set, query_point)
    return False
main函数调用:
def main():
    """Demo: build a kd-tree over random points and run a radius search."""
    # Demo configuration.
    num_points = 64
    dim = 3
    leaf_size = 4
    k = 8  # capacity for the optional k-NN query mentioned below

    print('runing')
    db_np = np.random.rand(num_points, dim)
    root = kdtree_construction(data=db_np, leaf_size=leaf_size)
    query = np.asarray([0, 0, 0])

    # For a k-NN query instead, create KNNResultsets(capacity=k), pass it to
    # kdtree_knn_search with the same root / data / query, and print it.

    result_set = RadiusNNResultSet(radius=1)
    kdtree_radius_search(
        root, data=db_np, result_set=result_set, query_point=query)
    print(result_set)


if __name__ == '__main__':
    main()
python相关:
diff = np.linalg.norm(np.expand_dims(
query_point, 0) - leaf_points, axis=1)
#计算范数:axis=1表示对每一行(即每个点的各个坐标分量之差)求范数,得到的就是每个点和查询点之间的欧氏距离(空间距离)