《统计学习方法》学习之二:kd树

一、kd树模型
  在使用k-means等算法时,经常需要查找最近邻节点,kd树就是一种二叉树,将特征空间进行分割,以便减小搜索时间。(具体内容可以参考李航《统计学习方法》一书)。

二、代码实现
  这里实现二维平面上的kd树,可以类推到n维特征空间。
(本人代码水平有限,如有错误,还请各位大牛不吝指出)

import math


class kdTreeNode(object):
    '''
    kd树的节点数据结构。
    '''
    def __init__(self, value, x_range, y_range):
        '''
        :param value:kd树的节点数据,为一个长度为2的坐标信息,分别表示x坐标和y坐标。
        :param x_range: 当前节点数据制约的x坐标的范围。
        :param y_range: 当前节点数据制约的y坐标的范围。

        self.left:左子节点。
        self.right:右子节点。
        self.father:父节点。
        '''
        self.value = value
        self.x_range = x_range
        self.y_range = y_range
        self.left = None
        self.right = None
        self.father = None


class kdTree(object):
    '''
    kd树。
    '''
    def __init__(self, data, axis=0, xrange=[-100, 100], yrange=[-100, 100]):
        '''
        :param data:需要建树的坐标信息,是一个列表 。
        :param axis: 指定开始建树时的坐标轴,0或者1。
        :param xrange: x的范围。
        :param yrange: y的范围。
        '''
        assert axis in [0, 1]
        self.__data = data
        self.__axis = axis
        self.__xrange = xrange
        self.__yrange = yrange
        self.root = self.__build(self.__data, self.__axis, self.__xrange, self.__yrange)

    def __build(self, data, axis, xrange, yrange):
        '''
        递归建立kd平衡树。
        :param data: 需要建树的坐标信息,是一个列表 。
        :param axis: 坐标轴,0或者1。
        :param xrange: 节点制约的x坐标范围。
        :param yrange: 节点制约的y坐标范围。
        :return: kd树的节点,kdTreeNode
        '''
        assert axis in [0, 1]
        if len(data) == 0:
            return kdTreeNode(None, xrange, yrange)

        data.sort(key=lambda s: s[axis])
        pos = len(data) // 2

        assert xrange[0] <= data[pos][0] <= xrange[1] and yrange[0] <= data[pos][1] <= yrange[1]

        myroot = kdTreeNode(data[pos], xrange, yrange)

        if axis == 0:
            myroot.left = self.__build(data[0:pos:], 1-axis, [xrange[0], data[pos][0]], yrange)
            myroot.right = self.__build(data[pos+1::], 1-axis, [data[pos][0], xrange[1]], yrange)
        else:
            myroot.left = self.__build(data[0:pos:], 1 - axis, xrange, [yrange[0], data[pos][1]])
            myroot.right = self.__build(data[pos + 1::], 1 - axis, xrange, [data[pos][1], yrange[1]])

        myroot.left.father = myroot
        myroot.right.father = myroot

        return myroot

    def insert(self, data):
        '''
        由于主要实现的是kd树,因此为了省去不必要的麻烦,每次插入一个新节点,kd树都需要重新建立。
        可以采用插入节点并调整的方法来重新恢复平衡树。
        :param data: 需要插入的节点坐标。
        :return: None
        '''
        self.__data.append(data)
        self.root = self.__build(self.__data, self.__axis, self.__xrange, self.__yrange)

    def search_nearest_neighbor(self, point):
        '''
        最近邻搜索
        :param point:需要被搜索最近邻的坐标,是一个长度为2的list。
        :return: 最近距离和最近邻节点的坐标。

        算法采用队列保存需要搜索的节点。
        '''

        queue = []
        node = self.root
        '''搜索包含该节点的叶子节点'''
        while node.value is not None:
            queue.append(node)
            if self.__is_containing(point, xrange=node.left.x_range, yrange=node.left.y_range):
                node = node.left
            else:
                node = node.right
        '''
        由于叶子节点value为None,因此mindis(最小距离)需要计算point和该叶子节点的父节点之间的距离,
        同时,最近节点为该叶子节点的父节点。
        '''
        mindis = math.sqrt((point[0] - queue[-1].value[0])**2 + (point[1] - queue[-1].value[1])**2)
        nearest = queue[-1]
        queue.append(node)

        while len(queue) != 0:
            node = queue.pop(0)
            if node.value is not None:
                '''分别对node的左右两个子节点进行搜索'''
                for n in [node.left, node.right]:
                    '''
                    如果point到n的距离小于或者等于mindis,则说明n可能包含比当前最近点更近的点,因此加入队列,反之则什么都不做。
                    '''
                    dis_block = self.__calculate_distence(point, n.x_range, n.y_range)
                    if dis_block > mindis:
                        pass
                    else:
                        if dis_block not in queue:
                            queue.append(n)
                '''
                计算当前节点node和point之间的距离,如果距离小于mindis,则更新最小距离和最近节点。
                '''
                dis = math.sqrt((point[0] - node.value[0])**2 + (point[1] - node.value[1])**2)
                if dis < mindis:
                    mindis = dis
                    nearest = node

        return mindis, nearest.value
        pass

    def __is_containing(self, point, xrange, yrange):
        '''
        判断point是否包含在xrange和yrange组成的区域中。
        如果point在xrange和yrange做成的区域中,包括边界,则返回True,反之则返回False。
        :param point:点坐标,长度为2的list。
        :param xrange: x的范围,长度为2。
        :param yrange: y的范围,长度为2。
        :return: bool,如果point在xrange和yrange做成的区域中,包括边界,则返回True,反之则返回False。
        '''
        return xrange[0] <= point[0] <= xrange[1] and yrange[0] <= point[1] <= yrange[1]

    def __calculate_distence(self, point, xrange, yrange):
        '''
        计算point到xrange和yrange组成区域的距离。
        :param point:点坐标,长度为2的list。
        :param xrange: x的范围,长度为2。
        :param yrange: y的范围,长度为2。
        :return: float,point到xrange和yrange组成区域的距离。
        '''
        if xrange[0] <= point[0] <= xrange[1] and yrange[0] <= point[1] <= yrange[1]:
            return 0
        if point[0] < xrange[0]:
            if point[1] < yrange[0]:
                return math.sqrt((point[0] - xrange[0])**2 + (point[1] - yrange[0])**2)
            elif yrange[0] <= point[1] <= yrange[1]:
                return abs(point[0] - xrange[0])
            else:
                return math.sqrt((point[0] - xrange[0])**2 + (point[1] - yrange[1])**2)
        elif xrange[0] <= point[0] <= xrange[1]:
            if point[1] < yrange[0]:
                return abs(point[1] - yrange[0])
            elif yrange[0] <= point[1] <= yrange[1]:
                return 0
            else:
                return abs(point[1] - yrange[1])
        else:
            if point[1] < yrange[0]:
                return math.sqrt((point[0] - xrange[1])**2 + (point[1] - yrange[0])**2)
            elif yrange[0] <= point[1] <= yrange[1]:
                return abs(point[0] - xrange[1])
            else:
                return math.sqrt((point[0] - xrange[1])**2 + (point[1] - yrange[1])**2)
        pass


if __name__ == '__main__':

    point = []
    import random
    '''随机生成坐标点'''
    for i in range(1000):
        p = [random.uniform(0, 100), random.uniform(0, 100)]
        point.append(p)

    root = kdTree(point, xrange=[0, 100], yrange=[0, 100])

    cnt = 0
    correct = 0
    x = 0.01
    y = 0.01
    while x < 100.0:
        while y < 100.0:
            mindis = 2**32-1
            position = None
            for i in point:
                dis = math.sqrt((x - i[0])**2 + (y - i[1])**2)
                if dis < mindis:
                    mindis = dis
                    position = i

            m, p = root.search_nearest_neighbor([x, y])
            cnt += 1
            if m == mindis and position == p:
                correct += 1
            y += 0.01
        x += 0.01

    print('一共检测节点数目:', cnt)
    print('正确找到最近邻节点数目:', correct)

    import time

    begin = time.clock()
    while x < 100.0:
        while y < 100.0:
            mindis = 2**32-1
            position = None
            for i in point:
                dis = math.sqrt((x - i[0])**2 + (y - i[1])**2)
                if dis < mindis:
                    mindis = dis
                    position = i

            y += 0.01
        x += 0.01

    end = time.clock()
    print('线性查找最近邻所用时间:', end - begin)

    begin = time.clock()
    while x < 100.0:
        while y < 100.0:

            m, p = root.search_nearest_neighbor([x, y])

            y += 0.01
        x += 0.01

    end = time.clock()
    print('kd树查找最近邻所用时间:', end - begin)


结果如下:

一共检测节点数目: 9999
正确找到最近邻节点数目: 9999
线性查找最近邻所用时间: 7.894781509948016e-07
kd树查找最近邻所用时间: 3.9473907549739996e-07

  此处样本数目不大,因此在查找所用的时间上相差不大,但是依然可以看出kd树的时间小于线性查找。

三、结论
  kd树是一种二叉树的应用,可以减小k-means等算法中查找最近邻所需要的时间,这对于大容量的样本数据有十分积极的意义。

你可能感兴趣的:(机器学习,统计学习方法,kd树)