平衡kd树构建算法流程:
输入:K维空间数据集T={x1, x2, …. xn},其中 xi={xi(1), xi(2), … xi(k)}, i=1,….N
输出:kd树
1:开始:构造根结点,根结点对应于包含T的k维空间的超矩形区域。选择x(1)为坐标轴,以T中所有实例的x(1)坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴x(1)垂直的超平面实现。由根结点生成深度为1的左、右子结点:左子结点对应坐标x(1)小于切分点的子区域,右子结点对应于坐标x(1)大于切分点的子区域。将落在切分超平面上的实例点保存在根结点。
2:重复。对深度为j的结点选择x(l)为切分的坐标轴,l=j%k+1,以该结点的区域中所有实例的x(l)坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴x(l)垂直的超平面实现。由该结点生成深度为j+1的左、右子结点:左子结点对应坐标x(l)小于切分点的子区域,右子结点对应坐标x(l)大于切分点的子区域。将落在切分超平面上的实例点保存在该结点。
其中k的选择有特征xi的维度或knn中k的值。
kd树的构建!例题3.2
给定一个二维空间的数据集:
T = {(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}, 构造一个平衡kd树。
为了方便,我这里进行编号A(2,3)、B(5,4)、C(9,6)、D(4,7)、E(8,1)、F(7,2)
初始值r=0,对应x轴。
首先先沿 x 坐标进行切分,我们选出 x 坐标的中位点,获取最根部节点的坐标,对数据点x坐标进行排序得:
A(2,3)、D(4,7)、B(5,4)、F(7,2)、E(8,1)、C(9,6)
则我们得到中位点为B或者F,我这里选择F作为我们的根结点,并作出切分(并得到左右子树),如图:
python代码:
R = 3 # Feature dimension
class Node(object):
def __init__(self, elem, lchild=None, rchild=None):
self.elem = elem
self.lchild = lchild
self.rchild = rchild
def KDT(root, LR, dataSource, r):
if dataSource==[]:
return
data = sorted(dataSource, key=lambda x:x[r])
r = (r+1) % R
length = len(data)
node = Node(data[length/2], None, None)
if LR==0:
root.lchild = node
KDT(root.lchild, 0, data[:length/2], r)
KDT(root.lchild, 1, data[length/2 + 1:], r)
if LR==1:
root.rchild = node
KDT(root.rchild, 0, data[:length/2], r)
KDT(root.rchild, 1, data[length/2 + 1:], r)
def InitTree(dataSource, length):
r = 0
if dataSource==[]:
print "Please init dataSource."
return None
data = sorted(dataSource, key=lambda x:x[r])
r=(r+1) % R
root = Node(data[length/2], None, None)
KDT(root, 0, data[:length/2], r)
KDT(root, 1, data[length/2 + 1:], r)
print "InitTree Done."
return root
def PreOrderTraversalTree(root):
if root:
print root.elem,' | ',
PreOrderTraversalTree(root.lchild)
PreOrderTraversalTree(root.rchild)
def InOrderTraversalTree(root):
if root:
InOrderTraversalTree(root.lchild)
print root.elem,' | ',
InOrderTraversalTree(root.rchild)
def PostOrderTraversalTree(root):
if root:
PostOrderTraversalTree(root.lchild)
PostOrderTraversalTree(root.rchild)
print root.elem,' | ',
if __name__ == "__main__":
dataSource = [(2,3, 100), (5,4,70), (9,6,55), (4,7,200), (8,1,44), (7,2,0)]
length = len(dataSource)
root = InitTree(dataSource, length)
print "PreOrder:"
PreOrderTraversalTree(root)
print "\nInOrder:"
InOrderTraversalTree(root)
print "\nPostOrder:"
PostOrderTraversalTree(root)
参考资料:
https://mshj.blog.ustc.edu.cn/?p=292
https://www.suilengea.com/show/xazcegzcv.html
李航 统计学习方法