仅以此文纪念过往的岁月
KD-Tree 全称 K-Dimensional Tree(K 维树),本质上是一种对 K 维空间进行划分的二叉树;按中位数构造时它是一棵平衡二叉树。KD-Tree 一般应用于 KNN(K 近邻)分类,其关键操作是查询:对于 KNN 应用而言,一般是一次构造完成后,不再增加和删除节点,之后对给定的查询点计算其在树中的近邻点。
from collections import namedtuple
from operator import itemgetter
from pprint import pformat
class Node:
    """A node of a KD-tree.

    Attributes (set from the constructor arguments, in order):
        splitAttribute: index of the coordinate this node splits on.
        location: the point stored at this node.
        left_child: left subtree (a Node) or None.
        right_child: right subtree (a Node) or None.
        parent: parent Node, or None for the root.
    """

    def __init__(self, _sa, _loc, _lc, _rc, _p):
        self.splitAttribute = _sa
        self.location = _loc
        self.left_child = _lc
        self.right_child = _rc
        self.parent = _p

    def isLeft(self):
        """Return True if this node is its parent's left child.

        The root has no parent and is therefore never a left child.
        """
        return self.parent is not None and self.parent.left_child is self

    def isRight(self):
        """Return True if this node is its parent's right child.

        The root has no parent and is therefore never a right child.
        """
        return self.parent is not None and self.parent.right_child is self

    def isRoot(self):
        """Return True if this node has no parent."""
        return self.parent is None

    def neghbor(self):
        """Return this node's sibling (the parent's other child).

        Returns None for the root, or when this node is not linked as either
        child of its parent. The sibling itself may also be None.
        """
        if self.isRight():
            return self.parent.left_child
        if self.isLeft():
            return self.parent.right_child
        return None

    # Correctly spelled alias; ``neghbor`` is kept for existing callers.
    neighbor = neghbor
def KDTree(point_list, depth=0):
    """Build a balanced KD-tree from ``point_list`` and return its root Node.

    Splitting rule at each recursion level:
    - Points are ordered by coordinate ``depth % k``, where k is the dimension
      of the first point.
    - The median point becomes the node.
    - Points before the median form the left subtree; points after it form
      the right subtree.
    - Each child's ``parent`` link is set to the new node.

    Args:
        point_list: sequence of equal-length points (e.g. tuples). It is not
            modified.
        depth: current recursion depth; callers normally leave it at 0.

    Returns:
        The root ``Node``, or ``None`` if ``point_list`` is empty.
    """
    if not point_list:
        return None
    k = len(point_list[0])
    axis = depth % k
    # sorted() rather than list.sort(): the caller's list must not be reordered.
    points = sorted(point_list, key=itemgetter(axis))
    median = len(points) // 2
    left = KDTree(points[:median], depth + 1)
    right = KDTree(points[median + 1:], depth + 1)
    node = Node(axis, points[median], left, right, None)
    for child in (left, right):
        if child is not None:
            child.parent = node
    return node
def main():
    """Example usage: build a KD-tree and look up the neighbour of one point.

    Prints the location of the node returned by ``searchKDTree`` (if any)
    and returns that node.
    """
    point_list = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    tree = KDTree(point_list)
    test_point = (4, 7)
    node = searchKDTree(tree, test_point)
    if node is not None:
        print(pformat(node.location))
    return node
# Script entry point: run the example when executed directly.
# NOTE(review): this guard appears before distancePoint and searchKDTree are
# defined in this file. Executed in place, main() would run before
# searchKDTree exists and fail with NameError. Move the guard to the end of
# the module when these snippets are assembled into one runnable script.
if __name__ == '__main__':
    main()
对于 KNN,需要查询 K 个近邻点;当 K=1 时,即变为最近邻查询。
1. 二叉树搜索
从根节点开始,根据每个节点的 splitAttribute(划分维度)比较查询点在该维上的坐标,向左或向右下行,直至叶子节点。这样得到的是与查询点"最接近"的节点,但该节点不一定是真正的最近邻节点。
2. 回溯
得到近似最近邻点后(找到的叶子节点不一定是最近邻),沿上述搜索路径反向回溯,考察是否有距离查询点更近的数据点。真正的最近邻点必定位于以查询点为圆心、经过当前近似最近邻点的圆(高维时为超球)内。因此只有当该圆与某节点的划分超平面相交时,才需要进入该节点另一侧的子树搜索。如此一直回溯到根节点。K 近邻查询与最近邻查询类似,只是把"最近的一个点"变为"最近的 K 个点",并以当前第 K 近的距离作为圆的半径。
def distancePoint(pointA, pointB):
    """Return the Euclidean distance between two points of equal dimension.

    Returns -1 when the points have different dimensions. The sentinel is kept
    for compatibility. Because -1 is smaller than any real distance, callers
    that compare distances must check for it first.
    """
    if len(pointA) != len(pointB):
        return -1
    # zip-based sum works on both Python 2 and 3 (xrange is Python 2 only).
    return sum((a - b) ** 2 for a, b in zip(pointA, pointB)) ** 0.5
def searchKDTree(node, point, k=1):
    """Find the nearest neighbour(s) of ``point`` in the KD-tree rooted at ``node``.

    1. Descend: at each node, compare ``point`` with the node's location on
       that node's own ``splitAttribute``. Search the near subtree first
       (the left one when ``point[axis] <= location[axis]``).
    2. Backtrack: on the way back up, search the far subtree only in two
       cases. Either fewer than ``k`` candidates have been found, or the
       splitting hyperplane is closer to ``point`` than the current k-th best
       candidate.

    Args:
        node: root of the tree (``None`` for an empty tree). Each node must
            expose ``splitAttribute``, ``location``, ``left_child`` and
            ``right_child``.
        point: query point; must have the same dimension as the tree's points.
        k: number of neighbours wanted (>= 1).

    Returns:
        - ``k == 1``: the nearest node.
        - ``k > 1``: a list of at most ``k`` nodes, ordered by increasing
          distance.
        - ``None`` if the tree is empty, or if ``point`` has a different
          dimension than the tree's points.

    Raises:
        ValueError: if ``k < 1``.
    """
    if node is None or len(node.location) != len(point):
        return None
    if k < 1:
        raise ValueError("k must be >= 1")

    # (squared distance, node) pairs, sorted ascending, never longer than k.
    # Squared Euclidean distance preserves ordering, so no sqrt is needed.
    best = []

    def sq_dist(location):
        return sum((a - b) ** 2 for a, b in zip(point, location))

    def worst():
        # Squared radius of the current search ball; infinite until k
        # candidates have been collected.
        return best[-1][0] if len(best) == k else float("inf")

    def visit(current):
        if current is None:
            return
        d = sq_dist(current.location)
        if d < worst():
            best.append((d, current))
            best.sort(key=lambda pair: pair[0])
            del best[k:]
        axis = current.splitAttribute
        diff = point[axis] - current.location[axis]
        if diff <= 0:
            near, far = current.left_child, current.right_child
        else:
            near, far = current.right_child, current.left_child
        visit(near)
        # The far side can only contain a closer point if the search ball
        # crosses the splitting hyperplane, which lies |diff| away.
        if diff * diff < worst():
            visit(far)

    visit(node)
    if k == 1:
        return best[0][1]
    return [candidate for _, candidate in best]
当维数很高时,上述查询在回溯阶段需要访问的节点数急剧增加,效率大幅下降,即出现"维数灾难"。因此有人提出了改进的(近似)最近邻查询算法——BBF(Best Bin First)查询算法。
待续