# Node
class Node:
    """Node of a binary search tree ordered by ``key``.

    :param key: value used to order the tree
    :param value: payload stored with the key (defaults to -1)
    """

    def __init__(self, key, value=-1):
        self.left = None
        self.right = None
        self.key = key
        self.value = value


# insert
def insert(root, key, value=-1):
    """Insert ``key``/``value`` into the BST rooted at ``root``.

    Keys already present are left untouched (the existing node keeps its
    value). The walk is iterative, so a degenerate tree (e.g. built from
    sorted keys) cannot exceed Python's recursion limit.

    :param root: root of the tree, or None for an empty tree
    :param key: key to insert
    :param value: payload for a newly created node
    :return: the root of the tree (a new node when ``root`` was None)
    """
    if root is None:
        return Node(key, value)

    node = root
    while True:
        if key < node.key:
            if node.left is None:
                node.left = Node(key, value)
                return root
            node = node.left
        elif key > node.key:
            if node.right is None:
                node.right = Node(key, value)
                return root
            node = node.right
        else:
            # don't insert if key already exists in the tree
            return root
# search_recursive
def search_recursive(root, key):
    """Find the node holding ``key`` in the BST rooted at ``root``.

    :param root: subtree to search (may be None)
    :param key: key to look for
    :return: the matching node, or None when the key is absent
    """
    if root is None:
        return None
    if root.key == key:
        return root
    if key < root.key:
        return search_recursive(root.left, key)
    if key > root.key:
        return search_recursive(root.right, key)
    return None
# search_iterative
def search_iterative(root, key):
    """Find the node holding ``key`` in the BST rooted at ``root`` without recursion.

    :param root: tree to search (may be None)
    :param key: key to look for
    :return: the matching node, or None when the key is absent
    """
    current_node = root
    while current_node is not None:
        if key == current_node.key:
            return current_node
        # Anything not smaller goes right; this also guarantees progress for
        # keys that compare unordered (e.g. NaN), which previously looped forever.
        if key < current_node.key:
            current_node = current_node.left
        else:
            current_node = current_node.right
    return None
class KNNResultSet:
    """Fixed-capacity list of the ``capacity`` nearest results, sorted by distance.

    ``dist_index_list`` is pre-filled with placeholder entries at distance
    1e10, so ``worstDist()`` stays 1e10 until the set is full and then equals
    the distance of the k-th best result.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.count = 0
        self.worst_dist = 1e10
        self.dist_index_list = [DistIndex(self.worst_dist, 0) for _ in range(capacity)]
        # number of add_point calls, i.e. candidate comparisons made
        self.comparison_counter = 0

    def size(self):
        return self.count

    def full(self):
        return self.count == self.capacity

    def worstDist(self):
        return self.worst_dist

    def add_point(self, dist, index):
        """Offer a candidate; keep it if it beats the current worst distance.

        :param dist: distance of the candidate to the query
        :param index: identifier of the candidate point
        """
        self.comparison_counter += 1
        if dist > self.worst_dist:
            return

        if self.count < self.capacity:
            self.count += 1

        # Insertion-sort step: shift worse entries one slot toward the tail.
        # When the set is full, the entry in the last slot is dropped.
        i = self.count - 1
        while i > 0 and self.dist_index_list[i - 1].distance > dist:
            self.dist_index_list[i] = self.dist_index_list[i - 1]
            i -= 1

        # A fresh object avoids aliasing the entry that was shifted out of slot i.
        self.dist_index_list[i] = DistIndex(dist, index)
        self.worst_dist = self.dist_index_list[self.capacity - 1].distance


class DistIndex:
    """A (distance, index) pair ordered by distance."""

    def __init__(self, distance, index):
        self.distance = distance
        self.index = index

    def __lt__(self, other):
        return self.distance < other.distance
def knn_search(root: Node, result_set: KNNResultSet, key):
    """
    k-nearest-neighbour search for a scalar ``key`` in a binary search tree.

    :param root: subtree to search (may be None)
    :param result_set: collects (distance, value) candidates via add_point
    :param key: query key
    :return: True when the search may stop early because the result set's
             worst distance has dropped to 0, otherwise False
    """
    if root is None:
        return False

    # compare the root itself
    gap = math.fabs(root.key - key)
    result_set.add_point(gap, root.value)
    if result_set.worstDist() == 0:
        return True

    # Descend first into the side that contains ``key``; the other side is
    # only visited when the root's key is closer than the current worst result.
    if root.key >= key:
        near_side, far_side = root.left, root.right
    else:
        near_side, far_side = root.right, root.left

    if knn_search(near_side, result_set, key):
        return True
    if gap < result_set.worstDist():
        return knn_search(far_side, result_set, key)
    return False
def add_point(self, dist, index):
    """Record ``(dist, index)`` when it lies within ``self.radius``.

    Method of the radius-NN result set; its ``class`` header is not part of
    this snippet. Every call bumps ``comparison_counter``. Candidates whose
    distance equals the radius are kept, and only those beyond it are dropped.
    """
    self.comparison_counter += 1
    outside_radius = dist > self.radius
    if outside_radius:
        return
    self.count += 1
    entry = DistIndex(dist, index)
    self.dist_index_list.append(entry)
def radius_search(root: Node, result_set: RadiusNNResultSet, key):
    """
    Radius search for a scalar ``key`` in a binary search tree.

    :param root: subtree to search (may be None)
    :param result_set: collects (distance, value) candidates via add_point;
                       worstDist() is presumably the search radius (confirm)
    :param key: query key
    :return: True only if a recursive call reported True (the leaf case
             returns False), so in practice False
    """
    if root is None:
        return False

    # compare the root itself
    gap = math.fabs(root.key - key)
    result_set.add_point(gap, root.value)

    # Search the side containing ``key`` first; cross to the other side only
    # while the root's key lies strictly within the current worst distance.
    if root.key >= key:
        near_side, far_side = root.left, root.right
    else:
        near_side, far_side = root.right, root.left

    if radius_search(near_side, result_set, key):
        return True
    if gap < result_set.worstDist():
        return radius_search(far_side, result_set, key)
    return False
class Node:
    """Node of a k-d tree; a node whose split ``value`` is None is a leaf."""

    def __init__(self, axis, value, left, right, point_indices):
        self.axis = axis  # dimension this node splits on
        self.value = value  # split coordinate on ``axis``; None for a leaf
        self.left = left
        self.right = right
        self.point_indices = point_indices  # indices into db of the points under this node

    def is_leaf(self):
        return self.value is None
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    """
    Recursively build a k-d tree over the points ``db[point_indices]``.

    :param root: node to fill in, or None to create a new one
    :param db: N×D array of points
    :param point_indices: M indices into ``db`` of the points owned by this node
    :param axis: dimension this node splits on (scalar)
    :param leaf_size: a node with at most this many points stays a leaf (scalar)
    :return: the (possibly newly created) node
    """
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # Choosing a median pair needs at least two points, so leaf_size < 1
    # must not force a single point to split (that raised IndexError).
    if len(point_indices) > max(leaf_size, 1):
        # --- get the split position ---
        # sort the points in this node along `axis`; presumably
        # sort_key_by_vale returns point_indices ordered by db[., axis]
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])  # M
        middle_left_idx = math.ceil(point_indices_sorted.shape[0] / 2) - 1
        middle_right_idx = middle_left_idx + 1
        middle_left_point_value = db[point_indices_sorted[middle_left_idx], axis]
        middle_right_point_value = db[point_indices_sorted[middle_right_idx], axis]
        # split halfway between the two median coordinates
        root.value = (middle_left_point_value + middle_right_point_value) * 0.5
        # === get the split position ===

        next_axis = axis_round_robin(axis, dim=db.shape[1])
        root.left = kdtree_recursive_build(root.left,
                                           db,
                                           point_indices_sorted[0:middle_right_idx],
                                           next_axis,
                                           leaf_size)
        root.right = kdtree_recursive_build(root.right,
                                            db,
                                            point_indices_sorted[middle_right_idx:],
                                            next_axis,
                                            leaf_size)
    return root
def axis_round_robin(axis, dim):
    """Return the axis after ``axis``, wrapping from ``dim - 1`` back to 0."""
    return 0 if axis == dim - 1 else axis + 1
def knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """
    k-nearest-neighbour search of ``query`` in a k-d tree.

    :param root: subtree to search (may be None)
    :param db: array of points; rows are indexed by the nodes' point_indices
    :param result_set: collects (distance, point index) candidates
    :param query: query point
    :return: always False (no early termination is signalled)
    """
    if root is None:
        return False

    # Leaf: compare the query to every point inside it and feed the result set.
    if root.is_leaf():
        leaf_points = db[root.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for dist, point_idx in zip(dists, root.point_indices):
            result_set.add_point(dist, point_idx)
        return False

    # Descend into the half containing the query first; the other half can
    # only contain a closer point if the splitting plane is nearer than the
    # current worst distance.
    axis_gap = math.fabs(query[root.axis] - root.value)
    if query[root.axis] <= root.value:
        near_side, far_side = root.left, root.right
    else:
        near_side, far_side = root.right, root.left

    knn_search(near_side, db, result_set, query)
    if axis_gap < result_set.worstDist():
        knn_search(far_side, db, result_set, query)
    return False
# RadiusNN
def radius_search(root: Node, db: np.ndarray, result_set: RadiusNNResultSet, query: np.ndarray):
    """
    Radius search of ``query`` in a k-d tree.

    Same traversal as the k-d tree knn_search: leaves compare every point;
    inner nodes visit the query's half first. ``result_set.worstDist()`` is
    presumably the fixed search radius (confirm against RadiusNNResultSet).

    :param root: subtree to search (may be None)
    :param db: array of points; rows are indexed by the nodes' point_indices
    :param result_set: collects (distance, point index) candidates within the radius
    :param query: query point
    :return: always False
    """
    if root is None:
        return False

    # Leaf: compare the query to every point inside it.
    if root.is_leaf():
        leaf_points = db[root.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for dist, point_idx in zip(dists, root.point_indices):
            result_set.add_point(dist, point_idx)
        return False

    # The radius result set keeps points at exactly the radius, so the far
    # half is pruned only when the splitting plane is strictly farther away.
    axis_gap = math.fabs(query[root.axis] - root.value)
    if query[root.axis] < root.value:
        radius_search(root.left, db, result_set, query)
        if axis_gap <= result_set.worstDist():
            radius_search(root.right, db, result_set, query)
    else:
        radius_search(root.right, db, result_set, query)
        if axis_gap <= result_set.worstDist():
            radius_search(root.left, db, result_set, query)
    return False
class Octant:
    """A cubic cell of an octree."""

    def __init__(self, children, center, extent, point_indices, is_leaf):
        self.children = children  # list of 8 child octants (None where a child is empty)
        self.center = center  # center of the cube
        self.extent = extent  # 0.5 * edge length of the cube
        self.point_indices = point_indices  # indices into db of the points inside this octant
        self.is_leaf = is_leaf


def octree_recursive_build(root, db, center, extent, point_indices, leaf_size, min_extent):
    """
    Recursively build an octree over the 3-D points ``db[point_indices]``.

    :param root: octant to fill in, or None to create a new one
    :param db: array of points; row ``i`` holds the coordinates of point ``i``
    :param center: center of this octant's cube
    :param extent: half the edge length of this octant's cube
    :param point_indices: indices into ``db`` of the points inside this octant
    :param leaf_size: an octant with at most this many points stays a leaf
    :param min_extent: an octant with extent at or below this stays a leaf
    :return: the octant, or None when ``point_indices`` is empty
    """
    if len(point_indices) == 0:
        return None

    if root is None:
        root = Octant([None] * 8, center, extent, point_indices, is_leaf=True)

    # determine whether to split this octant
    if len(point_indices) <= leaf_size or extent <= min_extent:
        root.is_leaf = True
        return root

    root.is_leaf = False
    children_point_indices = [[] for _ in range(8)]

    # Morton code: bit 1/2/4 is set when the point lies above the center on
    # x/y/z. This must agree with the child-center layout below (set bit ->
    # +0.5 * extent) and with the query routing in octree_knn_search (``>``).
    for point_idx in point_indices:
        point_db = db[point_idx]
        morton_code = 0
        if point_db[0] > center[0]:
            morton_code |= 1
        if point_db[1] > center[1]:
            morton_code |= 2
        if point_db[2] > center[2]:
            morton_code |= 4
        children_point_indices[morton_code].append(point_idx)

    # create children: each child center is offset by +-0.5 * extent per axis
    factor = [-0.5, 0.5]
    child_extent = 0.5 * extent
    for i in range(8):
        child_center = np.asarray([
            center[0] + factor[(i & 1) > 0] * extent,
            center[1] + factor[(i & 2) > 0] * extent,
            center[2] + factor[(i & 4) > 0] * extent,
        ])
        root.children[i] = octree_recursive_build(root.children[i],
                                                  db,
                                                  child_center,
                                                  child_extent,
                                                  children_point_indices[i],
                                                  leaf_size,
                                                  min_extent)
    return root
def octree_knn_search(root: Octant, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """
    k-nearest-neighbour search of ``query`` in an octree.

    :param root: octant to search (may be None)
    :param db: array of points; rows are indexed by the octants' point_indices
    :param result_set: collects (distance, point index) candidates
    :param query: 3-D query point
    :return: True once the current result ball is known to lie inside the
             searched octant (per ``inside``), telling callers to stop
    """
    if root is None:
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        # compare the contents of a leaf
        leaf_points = db[root.point_indices, :]
        dists = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for k in range(dists.shape[0]):
            result_set.add_point(dists[k], root.point_indices[k])
        # check whether we can stop search now
        return inside(query, result_set.worstDist(), root)

    # Morton code of the child that contains the query; visit it first.
    home = 0
    for bit, dim in ((1, 0), (2, 1), (4, 2)):
        if query[dim] > root.center[dim]:
            home |= bit

    if octree_knn_search(root.children[home], db, result_set, query):
        return True

    # Then visit the remaining children whose cube touches the result ball.
    for c, child in enumerate(root.children):
        if c == home or child is None:
            continue
        if not overlaps(query, result_set.worstDist(), child):
            continue
        if octree_knn_search(child, db, result_set, query):
            return True

    # final check of if we can stop search
    return inside(query, result_set.worstDist(), root)
def inside(query: np.ndarray, radius: float, octant: Octant):
    """
    Determine if the query ball lies strictly inside the octant.

    True when, on every axis, ``|query - center| + radius < extent``.
    :param query: center of the query ball
    :param radius: radius of the query ball
    :param octant: cube described by ``center`` and ``extent``
    :return: whether the ball fits inside the cube
    """
    ball_reach = np.fabs(query - octant.center) + radius
    return np.all(ball_reach < octant.extent)
def overlaps(query: np.ndarray, radius: float, octant: Octant):
    """
    Determine whether the query ball overlaps the octant's cube (3-D).

    :param query: center of the query ball
    :param radius: radius of the query ball
    :param octant: cube described by ``center`` and ``extent``
    :return: True if the ball and the cube intersect
    """
    query_offset_abs = np.fabs(query - octant.center)

    # completely outside: on some axis the ball cannot reach the cube
    max_dist = radius + octant.extent
    if np.any(query_offset_abs > max_dist):
        return False

    # Face contact: the query projects inside the cube on at least two axes,
    # and the remaining axis is within reach (guaranteed by the check above).
    # np.count_nonzero replaces .astype(np.int), which NumPy 1.24 removed.
    if np.count_nonzero(query_offset_abs < octant.extent) >= 2:
        return True

    # Edge or corner contact: compare the squared distance from the query to
    # the nearest point of the cube against the squared radius.
    excess = np.maximum(query_offset_abs - octant.extent, 0.0)
    return float(np.dot(excess, excess)) < radius * radius
def contains(query: np.ndarray, radius: float, octant:Octant):
    """
    Determine if the query ball contains the whole octant.

    The cube is contained when even its corner farthest from ``query`` is
    strictly closer than ``radius``.
    :param query: center of the query ball
    :param radius: radius of the query ball
    :param octant: cube described by ``center`` and ``extent``
    :return: whether the cube lies inside the ball
    """
    farthest_corner_offset = np.fabs(query - octant.center) + octant.extent
    farthest_corner_dist = np.linalg.norm(farthest_corner_offset)
    return farthest_corner_dist < radius
# Russell A. Brown, Journal of Computer Graphics Techniques, 2015.