本篇博客参考李航《统计学习方法》第三章 3.3 节，使用 Python 实现 kd 树的构建。kd 树本质上是一棵平衡二叉树，构建时只需注意：在每一层按当前切分维度排序后，选择中位数对应的点作为节点即可。
class kdNode:
    """A single node of a kd-tree.

    Attributes:
        data: The point stored at this node, or None for an empty
            placeholder node.
        depth: Depth recorded for this node; printed by ``travel``.
        left: Left child (``kdNode`` or None).
        right: Right child (``kdNode`` or None).
    """

    def __init__(self, data=None, depth=0, left=None, right=None):
        self.data = data
        self.depth = depth
        self.left = left
        self.right = right

    def travel(self):
        """Print ``data`` and ``depth`` of every non-empty node in pre-order."""
        # Compare with None explicitly so a falsy but valid payload is not
        # mistaken for an empty placeholder node.
        if self.data is None:
            return
        # One pre-formatted string prints "<data> <depth>" identically under
        # both Python 2 and Python 3.
        print("%s %s" % (self.data, self.depth))
        # Children may be None when a node was constructed by hand rather
        # than through build_tree, so guard before recursing.
        for child in (self.left, self.right):
            if child is not None:
                child.travel()

    def build_tree(self, points, depth):
        """Recursively build the subtree rooted at this node.

        Args:
            points: Sequence of equal-length points (each indexable by axis).
                The sequence itself is not reordered.
            depth: Depth of this node; the splitting axis is
                ``depth % len(points[0])``.

        If ``points`` is empty the node is left unchanged.
        """
        if not points:
            return
        dims = len(points[0])
        axis = depth % dims
        # sorted() returns a new list, leaving the caller's sequence intact.
        ordered = sorted(points, key=lambda point: point[axis])
        # For an even count this picks the upper of the two middle points.
        median_index = len(ordered) // 2
        self.data = ordered[median_index]
        # Children always exist after a build; an empty side stays a
        # placeholder node whose data is None.
        self.left = kdNode(None, depth + 1)
        self.right = kdNode(None, depth + 1)
        self.left.build_tree(ordered[:median_index], depth + 1)
        self.right.build_tree(ordered[median_index + 1:], depth + 1)
class kdTree:
    """Thin wrapper that owns the root ``kdNode`` of a kd-tree."""

    def __init__(self, data=None, depth=0, left=None, right=None):
        """Create the root node from the given arguments.

        All four arguments are forwarded unchanged to ``kdNode``.
        """
        self.root = kdNode(data, depth, left, right)

    # Build the tree from the input points.
    def build_tree(self, points, depth):
        """Build the tree by delegating to the root node's ``build_tree``."""
        self.root.build_tree(points, depth)

    # Pre-order traversal.
    def travel_tree(self):
        """Traverse the tree by delegating to the root node's ``travel``."""
        self.root.travel()
# Demo: build a kd-tree from six sample 2-D points, starting at depth 0,
# then print every node in pre-order as "<point> <depth>".
points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
kd_tree = kdTree()
kd_tree.build_tree(points, 0)
kd_tree.travel_tree()
[7, 2] 0
[5, 4] 1
[2, 3] 2
[4, 7] 2
[9, 6] 1
[8, 1] 2