kd树的实现(python)

当前的实现针对的是二维空间, 代码可运行

T = np.array([(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)])
#T = np.array([(3,3),(2,2),(1,2)])
#plt.scatter(T[:,0], T[:,1])

def crt_tree(T, i):
    if len(T)==1:
        return Node(T[0])
    if len(T)==0:
        return None
    dim = T.shape[1]
    axis = i % dim
    #T空间的中间个数
    midPos = T.shape[0]//2
    #T的从大到小的项目的排列的索引
    indices = T[:,axis].argsort(0)
    midItemPos = indices[T.shape[0]//2]
    node = Node(T[midItemPos])
    leftT = T[indices[:midPos]]
    rightT = T[indices[midPos+1:]]
    node.left = crt_tree(leftT, i+1)
    node.right = crt_tree(rightT, i+1)
    return node

t = crt_tree(T, 0)
pre_traversal(t)
print(T)
print(T.argsort(0))

你可能感兴趣的:(MachineLearning)