kd-tree的python实现

本文主要内容

[ kD-tree的C语言实现 ]是多年前写过的一篇kd-tree的博客。当时正在看李航老师的《统计学习方法》一书,看到kNN算法和kd-tree之间的关系,非常有兴趣进行深入了解,所以汇总了一些资料,后面由于实际工作中用不到,就放下了。最近重新翻了翻李老师的这本书,发现现在的理解比以前深了很多,而且这种经典是常看常新的,每多翻一次,就多一分收获。

本文主要内容:
- kNN算法与kd-tree的关系
- kd-tree的构建
- kd-tree的近邻查找


kNN算法与kd-tree的关系

K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。
但是KNN算法使用的是线性扫描,计算复杂度和空间复杂度都很高。
事实上,实际数据集中的点一般时呈簇状分布的,所以,很多点我们是完全没有必要遍历的,索引树的方法就是对将要搜索的点进行空间划分,空间划分可能会有重叠,也可能没有重叠,kd-tree就是划分空间没有重叠的索引树
kd-tree是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。主要应用于多维空间关键数据的搜索(如:范围搜索和最近邻搜索)。

构造kd-tree

kd-tree每个节点中主要包含的数据结构如下

域名 类型 描述
dom_elt kd维的向量 kd维空间中的一个样本点
split 整数 分裂维的序号,也是垂直于分割超面的方向轴序号
left kd-tree 由位于该结点分割超面左子空间内所有数据点构成的kd-tree
right kd-tree 由位于该结点分割超面右子空间内所有数据点构成的kd-tree

对应的python定义为:

class  KD_node:
    def __init__(self, elt=None, split=None, LL=None, RR=None):
        '''  
        elt:数据点  
        split:划分域  
        LL, RR:节点的左儿子跟右儿子  
        '''
        self.elt = elt
        self.split = split
        self.left = LL
        self.right = RR
    def createKDTree(root, data_list):
        '''
        root:当前树的根节点  
        data_list:数据点的集合(无序)  
        return:构造的KDTree的树根  
        '''
        LEN = len(data_list)
        if LEN == 0:
            return
        # 数据点的维度
        dimension = len(data_list[0])
        # 方差
        max_var = 0
        # 最后选择的划分域
        split = 0
        for i in range(dimension):
            items = []
            for t in data_list:
                items.append(t[i])
            var = computeVariance(items)
            if var > max_var:
                max_var = var
                split = i
        #根据划分域的数据对数据点进行排序 
        data_list.sort(key=lambda x: x[split])
        #选择下标为len / 2的点作为分割点
        elt = data_list[LEN/2]
        root = KD_node(elt,split)
        root.left = createKDTree(root.left, data_list[0:LEN/2])
        root.right = createKDTree(root.right, data_list[(LEN/2+1):LEN])
        return root

def computeVariance(arrayList):
    '''
    arrayList:存放的数据点  
    return:返回数据点的方差  
    '''
    for ele in arrayList:
        ele = float(ele)
    LEN = float(len(arrayList))
    array = numpy.array(arrayList)
    sum1 =  array.sum()
    array2 = array * array
    sum2 = array2.sum()
    mean = sum1 / LEN
    #    D[X] = E[x^2] - (E[x])^2  
    variance = sum2 / LEN - mean**2
    return variance      

每个结点表示一个样本点,dom_elt就是表示该样本点的向量。该样本点根据结点的分割超平面将样本空间分为两个子空间。左子空间中的样本点集合由左子树left表示,右子空间中的样本点集合由右子树right表示。分割超平面是一个通过点dom_elt并且垂直于split所指示的方向轴的平面。
关于分裂维度split的选择,其实有很多方法,相较统计机器学习上提到的方法应用更广泛的是基于方差的选择方法:即尽可能将相似的点放在一颗子树里面,所以kd-tree采取的思想就是计算所有数据点在每个维度上的数值的方差,然后方差最大的维度就作为当前节点的划分维度。这样做的原理其实就是:方差越大,说明这个维度上的数据波动越大,也就说明了他们就越不可能属于同一个空间,需要在这个维度上对点进行划分,这就是kd-tree节点选择划分维度的原理。

kd-tree的最近邻查找

基本的查找思路如下:

二叉查找

从根节点开始进行查找,直到叶子节点;在这个过程中,记录最短的距离,和对应的数据点;同时维护一个栈,用来存储经过的节点

回溯查找

通过计算查找点到分割平面的距离(这个距离比较的是分割维度上的值的差,并不是分割节点到分割平面上的距离,虽然两者的值是相等的)与当前最短距离进行比较,决定是否需要进入节点的相邻子空间进行查找。

def findNN(root, query):
    '''
    root:KDTree的树根  
    query:查询点  
    return:返回距离data最近的点NN,同时返回最短距离min_dist  
    '''

    NN = root.elt
    min_dist = computeDist(query,NN)
    nodeList = []
    temp_root = root

    ## 二分查找建立路径
    while temp_root:
        nodeList.append(temp_root)
        dist = computeDist(query,temp_root.elt)
        if min_dist > dist:
            NN = temp_root.elt
            min_dist = dist

        # 当前节点的划分域
        splt = temp_root.split
        if query[splt] <= temp_root.elt[splt]:
            temp_root = temp_root.left
        else:
            temp_root = temp_root.right

        # 回溯查找    
        while nodeList:
            #使用list模拟栈,后进先出   
            back_elt = nodeList.pop()
            splt = back_elt.split
            print("back.elt = ", back_elt.elt)
            ## 判断是否需要进入父亲节点的子空间进行搜索 
            if abs(query[splt] - back_elt.elt[splt]) < min_dist:
                if(query[splt] <= back_elt.elt[splt]):
                    temp_root = back_elt.right
                else:
                    temp_root = back_elt.left

            if temp_root:
                nodeList.append(temp_root)
                curDist = computeDist(query,temp_root.elt)
                if min_dist > curDist:
                    min_dist = curDist
                    NN = temp_root.elt
    return NN, min_dist

def computeDist(pt1, pt2):
    '''
    计算两个数据点的距离
    return:pt1和pt2之间的距离
    '''
    sum = 0.0
    for i in range(len(pt1)):
        sum = sum + (pt1[i]-pt2[i]) * (pt1[i]-pt2[i])
    return math.sqrt(sum)       

你可能感兴趣的:(机器学习,算法)