KD-Tree Python Implementation

KD-Tree

This post is dedicated to the years gone by.

Introduction

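KD-Tree stands for K-Dimensional Tree. It is essentially a balanced binary tree (balanced because it is built by splitting at the median, as in the construction code below) that partitions k-dimensional space, and it is typically used for nearest-neighbor search, for example in KNN classification. The key operation is the query: in KNN applications the tree is usually built once, no nodes are inserted or deleted afterwards, and the tree is then used to find the points nearest to a given query point.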
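For comparison, here is a minimal sketch of the brute-force query that a KD-Tree is meant to speed up: every query scans all n points. The function name `brute_force_knn` is hypothetical, and `math.dist` assumes Python 3.8 or later.

```python
import heapq
import math

def brute_force_knn(points, query, k=1):
    """Return the k points closest to `query` by scanning every point (O(n) per query)."""
    # heapq.nsmallest keeps only the k best candidates while scanning the list once
    return heapq.nsmallest(k, points, key=lambda p: math.dist(p, query))

points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
print(brute_force_knn(points, (9, 2), k=2))  # [(8, 1), (7, 2)]
```

Results of a KD-Tree search can be cross-checked against this baseline.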

KD-Tree Construction

Python code

from operator import itemgetter

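class Node:
    """A KD-Tree node: split axis, stored point, two children and a parent link."""
    def __init__(self, _sa, _loc, _lc, _rc, _p):
        self.splitAttribute = _sa   # index of the coordinate this node splits on
        self.location = _loc        # the point stored at this node
        self.left_child = _lc       # subtree with location[_sa] <= this node's value
        self.right_child = _rc      # subtree with location[_sa] >= this node's value
        self.parent = _p            # None for the root

    def isRoot(self):
        return self.parent is None

    def isLeft(self):
        # The root has no parent, so it is neither a left nor a right child
        return not self.isRoot() and self.parent.left_child is self

    def isRight(self):
        return not self.isRoot() and self.parent.right_child is self

    def neighbor(self):
        """Return the sibling node (the parent's other child), or None."""
        if self.isRight():
            return self.parent.left_child
        if self.isLeft():
            return self.parent.right_child
        return None

    neghbor = neighbor  # original (misspelled) method name, kept as an alias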

def KDTree(point_list, depth=0):
    """Recursively build a KD-Tree, splitting at the median point on axis depth % k."""
    if not point_list:
        return None
    k = len(point_list[0])   # dimensionality of the points
    axis = depth % k         # cycle through the axes level by level

    point_list = sorted(point_list, key=itemgetter(axis))  # sort a copy, not the caller's list
    median = len(point_list) // 2
    _left_child = KDTree(point_list[:median], depth + 1)
    _right_child = KDTree(point_list[median + 1:], depth + 1)
    node = Node(axis, point_list[median], _left_child, _right_child, None)
    if node.left_child is not None:
        node.left_child.parent = node
    if node.right_child is not None:
        node.right_child.parent = node
    return node


def main():
    """Example usage"""
    point_list = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    tree = KDTree(point_list)
    print(tree.location)                                         # root: (7, 2)
    print(tree.left_child.location, tree.right_child.location)  # (5, 4) (9, 6)


if __name__ == '__main__':
    main()
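A lighter-weight variant is sketched below, assuming nodes are stored as immutable `namedtuple`s: since a namedtuple cannot be given a parent link after construction, there is none, and the split axis is recomputed from the depth instead of being stored. The names `KDNode`, `build_kdtree`, `iter_points` and `check_invariant` are hypothetical. `check_invariant` verifies the property the search relies on: on each node's split axis, points in the left subtree are <= the node's point and points in the right subtree are >=.

```python
from collections import namedtuple
from operator import itemgetter

# Hypothetical lightweight node; the split axis is implied by the depth
KDNode = namedtuple("KDNode", ["location", "left_child", "right_child"])

def build_kdtree(points, depth=0):
    """Build a KD-Tree by splitting at the median along axis depth % k."""
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=itemgetter(axis))  # sorted() leaves the caller's list untouched
    median = len(points) // 2
    return KDNode(
        location=points[median],
        left_child=build_kdtree(points[:median], depth + 1),
        right_child=build_kdtree(points[median + 1:], depth + 1),
    )

def iter_points(node):
    """Yield every point stored in a subtree (pre-order)."""
    if node is not None:
        yield node.location
        yield from iter_points(node.left_child)
        yield from iter_points(node.right_child)

def check_invariant(node, depth=0):
    """Check left subtree <= node <= right subtree on the split axis, at every node."""
    if node is None:
        return True
    axis = depth % len(node.location)
    split = node.location[axis]
    ok_left = all(p[axis] <= split for p in iter_points(node.left_child))
    ok_right = all(p[axis] >= split for p in iter_points(node.right_child))
    return (ok_left and ok_right
            and check_invariant(node.left_child, depth + 1)
            and check_invariant(node.right_child, depth + 1))

tree = build_kdtree([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
print(tree.location)          # (7, 2)
print(check_invariant(tree))  # True
```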

KD-Tree Query

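For KNN we need the K nearest neighbors of a query point; when K = 1 this reduces to nearest-neighbor search. The query has two phases:
1. Binary tree search
Starting at the root, compare the query point with each node on that node's splitAttribute and move to the left or right child accordingly, until a leaf is reached. The leaf is close to the query point, but it is not necessarily the nearest neighbor.
2. Backtracking
Having found this approximate nearest point, walk back along the search path toward the root and check whether any data point lies closer to the query point. The true nearest neighbor must lie inside the circle (a hypersphere in higher dimensions) centered at the query point and passing through the current approximate nearest point, so the other side of a node's splitting plane needs to be searched only when that circle crosses the plane. Continue until the root is reached. K-nearest-neighbor search works the same way, except that it keeps the K closest points found so far instead of one and uses the distance to the K-th of them as the circle's radius.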
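The two phases can be sketched as a non-recursive nearest-neighbor (K = 1) search that records the descent path and then walks it backwards. This is a minimal sketch, not the author's code: so that it runs on its own, it assumes a hypothetical tuple-based tree `(location, axis, left, right)` built by a hypothetical `build` helper, and it uses `math.dist` (Python 3.8+).

```python
import math
from operator import itemgetter

def build(points, depth=0):
    """Minimal KD-Tree as nested tuples: (location, axis, left, right)."""
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=itemgetter(axis))
    m = len(points) // 2
    return (points[m], axis, build(points[:m], depth + 1), build(points[m + 1:], depth + 1))

def nearest(tree, query):
    """Nearest-neighbor search: descend to a leaf, then backtrack along the recorded path."""
    best, best_dist = None, math.inf
    stack = [tree]  # subtrees that still have to be searched
    while stack:
        node = stack.pop()
        # Step 1: binary tree search, recording the path from this subtree's root to a leaf
        path = []
        while node is not None:
            path.append(node)
            location, axis, left, right = node
            node = left if query[axis] <= location[axis] else right
        # Step 2: backtrack from the leaf toward the subtree's root
        for location, axis, left, right in reversed(path):
            d = math.dist(query, location)
            if d < best_dist:
                best, best_dist = location, d
            # The far side can only hold a closer point if the circle of radius
            # best_dist around the query crosses this node's splitting line
            far = right if query[axis] <= location[axis] else left
            if far is not None and abs(query[axis] - location[axis]) < best_dist:
                stack.append(far)
    return best, best_dist

tree = build([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
print(nearest(tree, (2, 4.5)))  # ((2, 3), 1.5)
```

Here the descent ends at the leaf `(4, 7)`, but backtracking finds `(2, 3)` on the other side of the split at `(5, 4)`. A subtree that passes the circle test is pushed on the stack and later searched with the same descend-then-backtrack procedure; the test uses the best distance known at push time, which may prune less than possible but never discards the true nearest neighbor.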

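def distancePoint(pointA, pointB):
    """Euclidean distance between two points of the same dimension."""
    if len(pointA) != len(pointB):
        raise ValueError("points must have the same dimension")
    count = 0
    for i in range(len(pointA)):   # range, not Python 2's xrange
        count += (pointA[i] - pointB[i]) ** 2
    return count ** 0.5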
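On Python 3.8+, the standard library's `math.dist` computes the same Euclidean distance. When distances are only compared with each other, the square root can be skipped, because it preserves ordering. A minimal sketch (the name `squared_distance` is hypothetical); note that any quantity compared with a squared distance, such as the distance to a splitting plane during backtracking, must then be squared as well.

```python
import math

def squared_distance(a, b):
    """Squared Euclidean distance; enough for deciding which point is closer."""
    if len(a) != len(b):
        raise ValueError("points must have the same dimension")
    return sum((x - y) ** 2 for x, y in zip(a, b))

a, b = (2, 3), (5, 7)
print(squared_distance(a, b))  # 25
print(math.dist(a, b))         # 5.0
```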

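import heapq

def searchKDTree(node, point, k=1):
    """Return the k nearest points to `point` as a list of (distance, location), closest first."""
    if node is None or len(node.location) != len(point):
        return None
    best = []  # max-heap of the k closest points so far, stored as (-distance, location)

    def visit(nodeT):
        if nodeT is None:
            return
        axis = nodeT.splitAttribute
        diff = point[axis] - nodeT.location[axis]
        # 1. binary tree search: first go into the half that contains the query point
        if diff <= 0:
            near, far = nodeT.left_child, nodeT.right_child
        else:
            near, far = nodeT.right_child, nodeT.left_child
        visit(near)
        # 2. backtracking: compare this node with the current k best
        dis = distancePoint(point, nodeT.location)
        if len(best) < k:
            heapq.heappush(best, (-dis, nodeT.location))
        elif dis < -best[0][0]:
            heapq.heapreplace(best, (-dis, nodeT.location))
        # The other half can only hold a closer point if the circle around the query
        # (radius = current k-th best distance) crosses the splitting plane
        if len(best) < k or abs(diff) < -best[0][0]:
            visit(far)

    visit(node)
    return sorted((-d, loc) for d, loc in best)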
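Going from one nearest point to K nearest points only changes how the current best candidates are kept. A minimal sketch of a bounded max-heap built on `heapq`; the class name `KBest` and its methods are hypothetical. `radius()` is the circle radius for the backtracking test: it stays infinite until K candidates have been collected, so no branch is pruned before then.

```python
import heapq
import math

class KBest:
    """Keep the k candidates with the smallest distances seen so far (hypothetical helper)."""

    def __init__(self, k):
        self.k = k
        self._heap = []  # max-heap via negated distances: (-distance, point)

    def push(self, distance, point):
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (-distance, point))
        elif distance < -self._heap[0][0]:
            # Evict the current worst of the k candidates
            heapq.heapreplace(self._heap, (-distance, point))

    def radius(self):
        """Distance of the k-th best candidate, or infinity while fewer than k are held."""
        return -self._heap[0][0] if len(self._heap) == self.k else math.inf

    def result(self):
        """Candidates as (distance, point), closest first."""
        return sorted((-d, p) for d, p in self._heap)

best = KBest(2)
for p in [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]:
    best.push(math.dist((9, 2), p), p)
print([p for _, p in best.result()])  # [(8, 1), (7, 2)]
print(best.radius())                  # 2.0
```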

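When the number of dimensions is large, the query above suffers from the curse of dimensionality: backtracking ends up examining a large fraction of the tree. This motivated an improved, approximate nearest-neighbor query: the BBF (Best Bin First) algorithm.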
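As a rough preview of the idea: BBF replaces the fixed backtracking order with a priority queue, so unexplored branches are visited in order of their distance to the query point, and the search stops after a fixed budget of checks, trading exactness for speed. The following is a minimal sketch of that idea, not a reference implementation; the tuple-based tree, the names `build` and `bbf_nearest`, and the `max_checks` budget (counted here as points examined) are assumptions for illustration.

```python
import heapq
import itertools
import math
from operator import itemgetter

def build(points, depth=0):
    """Minimal KD-Tree as nested tuples: (location, axis, left, right)."""
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=itemgetter(axis))
    m = len(points) // 2
    return (points[m], axis, build(points[:m], depth + 1), build(points[m + 1:], depth + 1))

def bbf_nearest(tree, query, max_checks=50):
    """Approximate nearest neighbor, exploring the closest unexplored branch first."""
    best, best_dist = None, math.inf
    tiebreak = itertools.count()            # keeps heap entries comparable on equal bounds
    queue = [(0.0, next(tiebreak), tree)]   # (lower bound on distance, tiebreak, subtree)
    checks = 0
    while queue and checks < max_checks:
        bound, _, node = heapq.heappop(queue)
        if bound >= best_dist:
            break  # every remaining branch is at least this far from the query
        # Descend to a leaf, queueing each skipped branch keyed by its distance to the split
        while node is not None:
            location, axis, left, right = node
            d = math.dist(query, location)
            checks += 1
            if d < best_dist:
                best, best_dist = location, d
            diff = query[axis] - location[axis]
            near, far = (left, right) if diff <= 0 else (right, left)
            if far is not None:
                heapq.heappush(queue, (abs(diff), next(tiebreak), far))
            node = near
    return best, best_dist

tree = build([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
print(bbf_nearest(tree, (2, 4.5)))  # ((2, 3), 1.5)
```

With a large enough `max_checks` this returns the exact nearest neighbor, because the loop also stops once the closest remaining branch cannot beat the current best. With a small budget the answer is approximate: for the query `(2, 4.5)` and `max_checks=1`, it returns `(5, 4)` instead of `(2, 3)`.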

BBF Query Algorithm

To be continued.
