在寻找输入样本的k个近邻的时候,若进行线性扫描,对于大数据集来说耗时太久,为了加快搜索速度,提出了用kd树实现k个近邻的搜索,此时复杂度为O(logN)。
首先是建树
这里假设输入数据一个N×K的矩阵,N代表实例点的个数,K代表样本空间的维度。每一行代表一个实例点。
每个节点包含六个属性:
- SamplePoints:实例点的行号,表示该节点对应区域包含的所有实例点
- SplitDim:切割对应的区域时选择的特征(维度)
- MidPoint:是一个元组,(切分点的行号,切分特征的中位数)
- left:指向左子节点
- right:指向右子节点
- father:指向父节点
- visited:该节点是否已被访问的标志
包含两个方法:
- get_median():获取切割特征的中位数
- get_dim():获取方差最大的特征作为切割特征
过程如下:
- 构造根节点,使根节点对应于k维空间中包含所有实例点的超矩形区域;
- 在超矩形区域上选择一个坐标轴和在一个切分点,确定一个超平面,这个超平面通过选定的切分点并垂直于选定的坐标轴,将当前超矩形区域切分为左右两个子区域(子节点);这时,实例被分到两个子区域。
- 将切分点保存在根节点上。
- 重复步骤2、3,直到子区域内只含有不含实例时终止。
import numpy as np
from collections import Counter
class KdTreeNode(object):
def __init__(self, SamplePoints):
self.SamplePoints = SamplePoints
self.SplitDim = self.get_dim()
self.MidPoint = self.get_MidPoint()
self.left = None
self.right = None
self.father = None
self.visited = False
def get_dim(self):
variance = np.var(X[self.SamplePoints, :], axis = 0) #计算该节点包含的实例点每个特征的方差
#print(variance)
return np.argmax(variance) #选择方差最大的特征
def get_MidPoint(self):
tmp = X[self.SamplePoints, self.SplitDim]
length = len(tmp)
index = np.argsort(tmp) #该函数返回的是数组值从小到大的索引值
return (self.SamplePoints[index[int(length/2)]], tmp[index[int(length/2)]]) #(中位数所在的行号,中位数的值)
def build_tree(SamplePoints, father = None): #构建kd树
if len(SamplePoints) == 0: #子区域不含实例点时停止
return None
root = KdTreeNode(SamplePoints)
LeftPoints = [] #分割区域依据的特征小于或等于median的实例点
RightPoints = [] #分割区域依据的特征大于median的实例点
for x in SamplePoints:
if x == root.MidPoint[0]:
continue
if X[x, root.SplitDim] <= root.MidPoint[1]:
LeftPoints.append(x)
else:
RightPoints.append(x)
root.father = father
if len(SamplePoints) > 1: #子区域只含一个点时停止
root.left = build_tree(LeftPoints, root) #构建左子树
root.right = build_tree(RightPoints, root) #构建右子树
return root
最近邻搜索
- 从根节点出发,递归地向下访问kd树。若目标点x当前维(即切割根节点对应区域时选择的维度)的坐标小于或等于切分点的坐标,则移动到左子节点,否则移动到右子节点。直到子节点为叶节点为止,记此叶节点为L。
- 以此叶节点L上的切分点为“当前最近点Ncur”,记录Ncur与目标点的距离为Dcur。
- 判断L的父节点是否已被访问。
3.1. 若未被访问,检查L的父节点的另一子节点(即L的兄弟节点)对应的区域是否与以目标点为球心以Dcur为半径的超球体相交。具体做法是在分割L的父节点区域时选择的维度上计算目标点与切分点的坐标差值的绝对值,然后将其与Dcur比较。
a) 若大于Dur,说明不相交。则标记L的父节点已被访问,回到此步骤的开头。
b) 若小于或等于Dcur,说明相交。先计算L的父节点上的切分点与目标点的距离,检查是否要更新Pcur与Dcur,完成后标记L的父节点已被访问。从L的兄弟节点出发,按照步骤1找到一个新的叶节点L。计算L上的切分点与目标点的距离,检查是否要更新Pcur与Dcur,完成后回到此步骤的开头。
3.2 若已被访问,判断L的父节点是否为根节点。
a) 若是,则停止整个程序。Pcur即为目标点的最近邻。
b) 若不是,则回退到L的父节点,具做法为令L=L的父节点,然后回到此步骤的开头。
def approx_nearest_neighbor(root, TargetPoint): #寻找树中与目标点的近似最近邻点,该最似最近邻仅仅是与目标点在同一分区中,不一定是最近邻
if root.left == None and root.right == None:
return root
if TargetPoint[root.SplitDim] <= root.MidPoint[1]:
if root.left == None: #若应往左子树走时发现左子树为空,转向右子树搜寻,保证最后返回的是一个叶节点
return approx_nearest_neighbor(root.right, TargetPoint)
return approx_nearest_neighbor(root.left, TargetPoint)
else:
if root.right == None: #若应往右子树走时发现左子树为空,转向左子树搜寻
return approx_nearest_neighbor(root.left, TargetPoint)
return approx_nearest_neighbor(root.right, TargetPoint)
def nearest_neighbor_search(root, TargetPoint): #搜索与目标点的欧氏距离最小的样本点
Vis = approx_nearest_neighbor(root, TargetPoint) #表示以该节点为根节点的子树已被搜索完成
Ncur = X[Vis.MidPoint[0], :]#开始时直接用近似最近邻点作为当前最近邻点
Dcur = np.sqrt(np.sum(np.square(Ncur - TargetPoint))) #目标点与当前最近邻的欧式距离
if Vis == root: #当样本空间中只有一个点则直接输出该点,注意Vis是一个节点,Ncur是一个点向量
return (Ncur, Dcur)
while True:
if not Vis.father.visited: #若Vis的父节点未被访问
VerticalDis = abs(TargetPoint[Vis.father.SplitDim] - Vis.father.MidPoint[1]) #目标点到以Vis父节点为切分点的分割超平面的垂直距离
#若Vis的兄弟节点代表的区域与以目标点为圆心Dcur为半径的圆相交
if VerticalDis <= Dcur:
EuclideanDis = np.sqrt(np.sum(np.square(X[Vis.father.MidPoint[0], :] - TargetPoint))) #Vis的父节点与目标点的距离
if EuclideanDis < Dcur: #若比Dcur小,则将其作为当前最近邻
Dcur = EuclideanDis
Ncur = X[Vis.father.MidPoint[0], :]
Vis.father.visited = True #此节点已被访问
#寻找Vis的兄弟节点
if Vis.father.left == Vis:
brother = Vis.father.right
else:
brother = Vis.father.left
#若无兄弟节点,直接爬升到Vis的父节点
if brother == None:
continue
#若有兄弟节点
Vis = approx_nearest_neighbor(brother, TargetPoint)
EuclideanDis = np.sqrt(np.sum(np.square(X[Vis.MidPoint[0], :] - TargetPoint)))
if EuclideanDis < Dcur:
Dcur = EuclideanDis
Ncur = X[Vis.MidPoint[0], :]
continue
#若不相交
else:
Vis.father.visited = True
else: #若Vis的父节点已被访问
if Vis.father == root: #若根节点已被访问,则结束搜索
break
else:
Vis = Vis.father #向上爬升到Vis的父节点
return (Ncur, Dcur)
K近邻搜索
k近邻的搜索与最近邻搜索类似,不过程序中的“当前最近邻Ncur”要改为“当前K近邻Kcur”,它是一个二维列表,里面的每一行代表了K个近邻点中的一个。在每次比较一个新的节点时,都需判断是否要对它进行更新,用离目标点更近的点代替更远的点。
def compare_dis(CurrentPoint, TargetPoint, Ncur, K): #计算样本点与目标点的距离,若有必要的话对Ncur进行更新
EuclideanDis = np.sqrt(np.sum(np.square(CurrentPoint - TargetPoint))) #计算欧式距离
Ncur = sorted(Ncur, key = lambda x : -x[1]) #对Ncur中的K个点按照到目标点的距离从远到近排序
if EuclideanDis < Ncur[0][1]: #如果当前目标点到目标点的距离比Ncur中最远的点要近,则对Ncur进行更新
Ncur = Ncur[1:K]
Ncur.append((CurrentPoint, EuclideanDis))
return Ncur
def k_neighbor_search(root, TargetPoint, K): #搜索与目标点的欧氏距离最小的K个样本点
Vis = approx_nearest_neighbor(root, TargetPoint) #Vis表示以该节点为根节点的子树已被搜索完成
Ncur = [] #存储当前K个近邻点
for i in range(K): #用K个离目标点无穷远的点作为Ncur的初始值
Ncur.append((X[i,:], float('inf')))
Ncur = compare_dis(X[Vis.MidPoint[0], :], TargetPoint, Ncur, K)
if Vis == root: #当K=1时, 若样本空间中只有一个点,则直接输出该点
return Ncur
while True:
if not Vis.father.visited: #若Vis的父节点未被访问
VerticalDis = abs(TargetPoint[Vis.father.SplitDim] - Vis.father.MidPoint[1]) #目标点到以Vis父节点为切分点的分割超平面的垂直距离
#若Vis的兄弟节点代表的区域与以目标点为圆心Dcur为半径的圆相交
if VerticalDis <= sorted(Ncur, key = lambda x : -x[1])[0][1]:
Ncur = compare_dis(X[Vis.father.MidPoint[0], :], TargetPoint, Ncur, K) #判断Vis的父节点是否要加入到Ncur中
Vis.father.visited = True #此节点已被访问
brother = Vis.father.right if Vis.father.left == Vis else Vis.father.left #寻找Vis的兄弟节点
#若无兄弟节点,直接爬升到Vis的父节点
if brother == None:
continue
#若有兄弟节点
Vis = approx_nearest_neighbor(brother, TargetPoint)
Ncur = compare_dis(X[Vis.MidPoint[0], :], TargetPoint, Ncur, K)
continue
#若不相交
else:
Vis.father.visited = True
else: #若Vis的父节点已被访问
if Vis.father == root: #若根节点已被访问,则结束搜索
break
else:
Vis = Vis.father #向上爬升到Vis的父节点
return Ncur
测试程序
下图中的红色叉叉代表目标点。
#主程序
X = np.array([[2,3],
[5,4],
[9,6],
[4,7],
[8,1],
[7,2]]) #存储样本向量
TargetPoint = np.array([8, 0]) #输入目标点
root = build_tree(range(len(X))) #建树
while True:
K = int(input('Input K:').strip()) #若样本点的个数没有K个,需重新设定K
if len(X) < K:
print('Retry')
continue
break
Ncur = k_neighbor_search(root, TargetPoint, K)
for point in Ncur:
print(point[0]) #输出K个近邻点的坐标