一、kd树模型
在使用k-means等算法时,经常需要查找最近邻节点,kd树就是一种二叉树,将特征空间进行分割,以便减小搜索时间。(具体内容可以参考李航《统计学习方法》一书)。
二、代码实现
这里实现二维平面上的kd树,可以类推到n维特征空间。
(本人代码水平有限,如有错误,还请各位大牛不吝指出)
import math
class kdTreeNode(object):
'''
kd树的节点数据结构。
'''
def __init__(self, value, x_range, y_range):
'''
:param value:kd树的节点数据,为一个长度为2的坐标信息,分别表示x坐标和y坐标。
:param x_range: 当前节点数据制约的x坐标的范围。
:param y_range: 当前节点数据制约的y坐标的范围。
self.left:左子节点。
self.right:右子节点。
self.father:父节点。
'''
self.value = value
self.x_range = x_range
self.y_range = y_range
self.left = None
self.right = None
self.father = None
class kdTree(object):
'''
kd树。
'''
def __init__(self, data, axis=0, xrange=[-100, 100], yrange=[-100, 100]):
'''
:param data:需要建树的坐标信息,是一个列表 。
:param axis: 指定开始建树时的坐标轴,0或者1。
:param xrange: x的范围。
:param yrange: y的范围。
'''
assert axis in [0, 1]
self.__data = data
self.__axis = axis
self.__xrange = xrange
self.__yrange = yrange
self.root = self.__build(self.__data, self.__axis, self.__xrange, self.__yrange)
def __build(self, data, axis, xrange, yrange):
'''
递归建立kd平衡树。
:param data: 需要建树的坐标信息,是一个列表 。
:param axis: 坐标轴,0或者1。
:param xrange: 节点制约的x坐标范围。
:param yrange: 节点制约的y坐标范围。
:return: kd树的节点,kdTreeNode
'''
assert axis in [0, 1]
if len(data) == 0:
return kdTreeNode(None, xrange, yrange)
data.sort(key=lambda s: s[axis])
pos = len(data) // 2
assert xrange[0] <= data[pos][0] <= xrange[1] and yrange[0] <= data[pos][1] <= yrange[1]
myroot = kdTreeNode(data[pos], xrange, yrange)
if axis == 0:
myroot.left = self.__build(data[0:pos:], 1-axis, [xrange[0], data[pos][0]], yrange)
myroot.right = self.__build(data[pos+1::], 1-axis, [data[pos][0], xrange[1]], yrange)
else:
myroot.left = self.__build(data[0:pos:], 1 - axis, xrange, [yrange[0], data[pos][1]])
myroot.right = self.__build(data[pos + 1::], 1 - axis, xrange, [data[pos][1], yrange[1]])
myroot.left.father = myroot
myroot.right.father = myroot
return myroot
def insert(self, data):
'''
由于主要实现的是kd树,因此为了省去不必要的麻烦,每次插入一个新节点,kd树都需要重新建立。
可以采用插入节点并调整的方法来重新恢复平衡树。
:param data: 需要插入的节点坐标。
:return: None
'''
self.__data.append(data)
self.root = self.__build(self.__data, self.__axis, self.__xrange, self.__yrange)
def search_nearest_neighbor(self, point):
'''
最近邻搜索
:param point:需要被搜索最近邻的坐标,是一个长度为2的list。
:return: 最近距离和最近邻节点的坐标。
算法采用队列保存需要搜索的节点。
'''
queue = []
node = self.root
'''搜索包含该节点的叶子节点'''
while node.value is not None:
queue.append(node)
if self.__is_containing(point, xrange=node.left.x_range, yrange=node.left.y_range):
node = node.left
else:
node = node.right
'''
由于叶子节点value为None,因此mindis(最小距离)需要计算point和该叶子节点的父节点之间的距离,
同时,最近节点为该叶子节点的父节点。
'''
mindis = math.sqrt((point[0] - queue[-1].value[0])**2 + (point[1] - queue[-1].value[1])**2)
nearest = queue[-1]
queue.append(node)
while len(queue) != 0:
node = queue.pop(0)
if node.value is not None:
'''分别对node的左右两个子节点进行搜索'''
for n in [node.left, node.right]:
'''
如果point到n的距离小于或者等于mindis,则说明n可能包含比当前最近点更近的点,因此加入队列,反之则什么都不做。
'''
dis_block = self.__calculate_distence(point, n.x_range, n.y_range)
if dis_block > mindis:
pass
else:
if dis_block not in queue:
queue.append(n)
'''
计算当前节点node和point之间的距离,如果距离小于mindis,则更新最小距离和最近节点。
'''
dis = math.sqrt((point[0] - node.value[0])**2 + (point[1] - node.value[1])**2)
if dis < mindis:
mindis = dis
nearest = node
return mindis, nearest.value
pass
def __is_containing(self, point, xrange, yrange):
'''
判断point是否包含在xrange和yrange组成的区域中。
如果point在xrange和yrange做成的区域中,包括边界,则返回True,反之则返回False。
:param point:点坐标,长度为2的list。
:param xrange: x的范围,长度为2。
:param yrange: y的范围,长度为2。
:return: bool,如果point在xrange和yrange做成的区域中,包括边界,则返回True,反之则返回False。
'''
return xrange[0] <= point[0] <= xrange[1] and yrange[0] <= point[1] <= yrange[1]
def __calculate_distence(self, point, xrange, yrange):
'''
计算point到xrange和yrange组成区域的距离。
:param point:点坐标,长度为2的list。
:param xrange: x的范围,长度为2。
:param yrange: y的范围,长度为2。
:return: float,point到xrange和yrange组成区域的距离。
'''
if xrange[0] <= point[0] <= xrange[1] and yrange[0] <= point[1] <= yrange[1]:
return 0
if point[0] < xrange[0]:
if point[1] < yrange[0]:
return math.sqrt((point[0] - xrange[0])**2 + (point[1] - yrange[0])**2)
elif yrange[0] <= point[1] <= yrange[1]:
return abs(point[0] - xrange[0])
else:
return math.sqrt((point[0] - xrange[0])**2 + (point[1] - yrange[1])**2)
elif xrange[0] <= point[0] <= xrange[1]:
if point[1] < yrange[0]:
return abs(point[1] - yrange[0])
elif yrange[0] <= point[1] <= yrange[1]:
return 0
else:
return abs(point[1] - yrange[1])
else:
if point[1] < yrange[0]:
return math.sqrt((point[0] - xrange[1])**2 + (point[1] - yrange[0])**2)
elif yrange[0] <= point[1] <= yrange[1]:
return abs(point[0] - xrange[1])
else:
return math.sqrt((point[0] - xrange[1])**2 + (point[1] - yrange[1])**2)
pass
if __name__ == '__main__':
point = []
import random
'''随机生成坐标点'''
for i in range(1000):
p = [random.uniform(0, 100), random.uniform(0, 100)]
point.append(p)
root = kdTree(point, xrange=[0, 100], yrange=[0, 100])
cnt = 0
correct = 0
x = 0.01
y = 0.01
while x < 100.0:
while y < 100.0:
mindis = 2**32-1
position = None
for i in point:
dis = math.sqrt((x - i[0])**2 + (y - i[1])**2)
if dis < mindis:
mindis = dis
position = i
m, p = root.search_nearest_neighbor([x, y])
cnt += 1
if m == mindis and position == p:
correct += 1
y += 0.01
x += 0.01
print('一共检测节点数目:', cnt)
print('正确找到最近邻节点数目:', correct)
import time
begin = time.clock()
while x < 100.0:
while y < 100.0:
mindis = 2**32-1
position = None
for i in point:
dis = math.sqrt((x - i[0])**2 + (y - i[1])**2)
if dis < mindis:
mindis = dis
position = i
y += 0.01
x += 0.01
end = time.clock()
print('线性查找最近邻所用时间:', end - begin)
begin = time.clock()
while x < 100.0:
while y < 100.0:
m, p = root.search_nearest_neighbor([x, y])
y += 0.01
x += 0.01
end = time.clock()
print('kd树查找最近邻所用时间:', end - begin)
结果如下:
一共检测节点数目: 9999
正确找到最近邻节点数目: 9999
线性查找最近邻所用时间: 7.894781509948016e-07
kd树查找最近邻所用时间: 3.9473907549739996e-07
此处样本数目不大,因此在查找所用的时间上相差不大,但是依然可以看出kd树的时间小于线性查找。
三、结论
kd树是一种二叉树的应用,可以减小k-means等算法中查找最近邻所需要的时间,这对于大容量的样本数据有十分积极的意义。