KNN's Core Algorithms: kd-tree and ball-tree


Table of Contents
[toc]


1. K-Nearest Neighbors Basics

1.1 Model and Algorithm

K-nearest neighbors (KNN) is one of the most fundamental machine learning models. Along the usual taxonomy it falls under:

  • Classification (√), regression (√), tagging
  • Probabilistic soft classification, non-probabilistic hard classification (√)
  • Supervised (√), unsupervised, reinforcement
  • Linear, nonlinear (√)
  • Discriminative (√), generative

KNN can be used for both classification and regression. The two are essentially the same model: a classifier discretizes the output of a regressor. In general, regression is a quantitative approximation of a real-valued target, and its output is typically continuous; classification qualitatively assigns a label to an object, and its output is typically discrete.
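As a quick illustration, here is a minimal sketch using scikit-learn's off-the-shelf KNN estimators (the tiny dataset and k=3 below are arbitrary choices for demonstration, not from the original experiments):

import numpy as np
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor

X = np.array([[0.0], [1.0], [2.0], [3.0]])  # 1-D training inputs
y_cls = np.array([0, 0, 1, 1])              # discrete labels    -> classification
y_reg = np.array([0.1, 0.9, 2.1, 2.9])      # continuous targets -> regression

clf = KNeighborsClassifier(n_neighbors=3).fit(X, y_cls)
reg = KNeighborsRegressor(n_neighbors=3).fit(X, y_reg)

print(clf.predict([[1.6]]))  # majority vote over 3 neighbors -> a discrete label
print(reg.predict([[1.6]]))  # mean of 3 neighbors' targets   -> a continuous value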

Classification model:
Input:
1. Training set $T=\{(x_1,y_1),(x_2,y_2),\dots,(x_N,y_N)\}$, where $x_i\in\mathcal{X}\subseteq\mathbb{R}^n$ is a training sample and $y_i\in\mathcal{Y}=\{c_1,c_2,\dots,c_K\}$ is its class label.
2. A test point $x$
Output:
The class $y$ of the test point $x$
Algorithm:
1. Using the given distance metric, find the k points in $T$ nearest to $x$; denote the neighborhood covering these k points as $N_k(x)$;
2. Determine the class $y$ of $x$ by majority vote:

$$y=\arg\max_{c_j}\sum_{x_i\in N_k(x)} I(y_i=c_j),\quad j=1,2,\dots,K$$

where $I$ is the indicator function, i.e., $I(y_i=c_j)=1$ when $y_i=c_j$, and 0 otherwise.
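The majority-vote rule maps directly to a linear-scan implementation. The sketch below is illustrative only (the helper name knn_classify and the toy data are mine, not from the original text):

import numpy as np
from collections import Counter

def knn_classify(X, y, x, k):
    """Classify x by majority vote among its k nearest training points."""
    d = np.linalg.norm(X - x, axis=1)   # distance from x to every training point
    nn = np.argsort(d)[:k]              # indices of the k nearest points, N_k(x)
    return Counter(y[nn]).most_common(1)[0][0]  # class maximizing sum of I(y_i = c_j)

X = np.array([[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]])
y = np.array([0, 0, 0, 1, 1, 1])
print(knn_classify(X, y, np.array([4.5, 5.0]), k=3))  # -> 1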

Regression model:
Input:
1. Training set $T=\{(x_1,y_1),(x_2,y_2),\dots,(x_N,y_N)\}$, where $x_i\in\mathcal{X}\subseteq\mathbb{R}^n$ is a training sample and $y_i\in\mathbb{R}$ is its target value.
2. A query point $x$
Output:
The predicted value $\hat{y}$ for $x$
Algorithm:
1. Using the given distance metric, find the k points in $T$ nearest to $x$; denote the neighborhood covering these k points as $N_k(x)$;
2. Average the target values of these k neighbors to obtain the prediction for $x$:

$$\hat{y}=\frac{1}{k}\sum_{x_i\in N_k(x)} y_i$$
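The regression rule differs from classification only in the final aggregation step; a matching sketch (again with illustrative names and data):

import numpy as np

def knn_regress(X, y, x, k):
    """Predict for x the mean target of its k nearest training points."""
    d = np.linalg.norm(X - x, axis=1)
    nn = np.argsort(d)[:k]   # N_k(x)
    return y[nn].mean()      # (1/k) * sum of y_i over N_k(x)

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0.1, 0.9, 2.1, 2.9])
print(knn_regress(X, y, np.array([1.4]), k=2))  # mean of targets at x=1, x=2 -> 1.5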

1.2 Distance Metrics

The algorithm in the previous section relies on a distance metric. The most common choice is the Euclidean distance, i.e., the L2-norm distance:

$$L_2(x_i,x_j)=\left(\sum_{l=1}^{n}\left|x_i^{(l)}-x_j^{(l)}\right|^2\right)^{1/2}$$

It can also be the L1-norm distance, also called the Manhattan distance:

$$L_1(x_i,x_j)=\sum_{l=1}^{n}\left|x_i^{(l)}-x_j^{(l)}\right|$$

The Manhattan distance suits scenarios like this: in a city divided by vertical and horizontal streets, the travel distance from one intersection to another is a Manhattan distance. In the figure below, the length of the green line is the Euclidean distance, while the lines in the other three colors all have the Manhattan distance as their length.

[Figure: Manhattan distance vs. Euclidean distance]

The L∞-norm distance is another option; it equals the maximum coordinate difference over all dimensions:

$$L_\infty(x_i,x_j)=\max_l\left|x_i^{(l)}-x_j^{(l)}\right|$$

The negative-infinity norm is the exact opposite: it equals the minimum coordinate difference over all dimensions.
The relationship among the Lp-norm distances is illustrated below:
[Figure: Lp-norm distances]
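All of these norms are available through numpy's np.linalg.norm via its ord parameter; a quick sanity check (the expected values in the comments follow from the definitions above):

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 0.0, 3.0])   # a - b = [-3, 2, 0]

print(np.linalg.norm(a - b, ord=2))        # Euclidean (L2): sqrt(9+4+0) ~= 3.606
print(np.linalg.norm(a - b, ord=1))        # Manhattan (L1): 3+2+0 = 5
print(np.linalg.norm(a - b, ord=np.inf))   # L-infinity: max(3,2,0) = 3
print(np.linalg.norm(a - b, ord=-np.inf))  # negative-infinity norm: min(3,2,0) = 0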

1.3 Choosing K

The choice of K affects the results of the algorithm.
A small K means predicting with training samples from a small neighborhood. This yields a small training (empirical) error, but the model overfits easily: the generalization error becomes large and generalization ability is weak.
A large K has a smoothing effect. As K grows, the generalization error first decreases and then increases, while the training error increases monotonically with K.
If K = N, then whatever the input, the prediction is simply the majority class among all training samples (classification) or the mean of all training targets (regression).
In practice, K is usually set to a fairly small value, and cross-validation is typically used to pick the optimal K; a sketch follows below.
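A sketch of this selection procedure with scikit-learn's cross_val_score (the dataset and the candidate K values are arbitrary illustrative choices):

import numpy as np
from sklearn.datasets import make_moons
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = make_moons(200, noise=0.3, random_state=1)
# mean 5-fold cross-validation accuracy for each candidate K
scores = {k: cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5).mean()
          for k in [1, 3, 5, 10, 20, 50]}
best_k = max(scores, key=scores.get)
print(scores)
print('best K:', best_k)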
The figure below tests how different choices of K affect performance on a regression problem; see Appendix 4.1 for the code:


[Figure: effect of the choice of K on regression performance]

1.4 Nearest-Neighbor Search Algorithms

The KNN algorithm must search the training set $T$ for the k points nearest to $x$. The most direct approach, linear scan, computes the distance from $x$ to every point in $T$, sorts, and keeps the k smallest. When the training set is large, this is far too slow to be practical.
In practice the two standard methods are kd-tree (k-dimension tree) and ball-tree. ball-tree improves on kd-tree: once the data dimension exceeds about 20, kd-tree performance drops sharply, while ball-tree holds up much better on high-dimensional data; an off-the-shelf usage sketch follows below.
kd-tree and ball-tree are covered in Sections 2 and 3 of this article.
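Both structures are available off the shelf; scikit-learn's NearestNeighbors, for instance, can switch between a linear scan and the two trees (a minimal sketch on synthetic data):

import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X = rng.normal(size=(10000, 3))  # 10000 points in 3 dimensions

for algo in ['brute', 'kd_tree', 'ball_tree']:
    nn = NearestNeighbors(n_neighbors=3, algorithm=algo).fit(X)
    dist, idx = nn.kneighbors(X[:1])   # the 3 nearest neighbors of the first point
    print(algo, idx[0])                # all three algorithms agree on the result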

2. The kd-tree Algorithm

The core of KNN is finding a query sample's k nearest neighbors in the training set; if the training set is large, the traditional full-scan search degrades performance drastically.
kd-tree trades space for time: it uses the training samples to partition the k-dimensional space dimension by dimension, building a binary tree, and exploits divide and conquer to greatly speed up the search. Binary search runs in $O(\log N)$, and kd-tree search efficiency comes close to that (depending on how balanced the constructed kd-tree is). The figure below shows the partition of the space induced by the training samples together with the corresponding kd-tree. The solid green star is a test sample; the kd-tree search quickly finds its 3 nearest training samples (the points marked with hollow stars).

[Figure: k-nearest-neighbor search with a kd-tree]

2.1 Building a kd-tree

A kd-tree is constructed as follows. Create a root node corresponding to the k-dimensional hyperrectangle containing all training samples. Then recursively build left and right children, splitting the samples belonging to the current node: take the median point along dimension i, store it in the current node as the split point for dimension i, send the points below it along dimension i to the left child, and those above it to the right child. The root splits on dimension 0, and each level of depth increments the split dimension by 1, i.e., $i = \text{depth} \bmod k$.
Cycling through the dimensions in this fixed order does not give the best search efficiency. Choosing, among the remaining dimensions, the one with the largest variance produces the most discriminative splits and a faster search (a small sketch of that choice follows below). In typical implementations, though, cycling through the dimensions is good enough.
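A minimal sketch of the variance-based choice of split dimension (the helper name is illustrative; the pseudocode below and the appendix code stick to cycling through dimensions):

import numpy as np

def split_dim_by_variance(x):
    """Choose the split dimension with the largest variance,
    rather than cycling dimensions by tree depth."""
    return int(np.argmax(np.var(x, axis=0)))

x = np.array([[0, 10], [1, -10], [2, 30], [3, -30]])
print(split_dim_by_variance(x))  # -> 1: the second dimension spreads far wider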
Pseudocode for building the kd-tree follows; see Appendix 4.2 for the full code:

function fit_kd_tree is
    input: 
        x,y: training samples and their labels
        dim: split dimension of the current node (children split on (dim+1) % data dimension)
    output: 
        node: root node of the constructed kd-tree
    if there is only one data point then
        create a leaf node containing that single point:
        node.point := x[0]
        node.label := y[0]
        node.son1 := None,
        node.son2 := None
        return node
    else:
        let p be the median point along dimension dim (sort x by dimension dim and take the median; with an even count, take the upper of the two middle points)
        let xl be the left subset (all points below p along dimension dim)
        let xr be the right subset (all points above p along dimension dim)
        split the labels into yl, yr accordingly
        create node with two children:
            node.point := p
            node.label := label of p
            node.son1 := fit_kd_tree(xl,yl),
            node.son2 := fit_kd_tree(xr,yr)
        return node
    end if
end function

2.2 kd-tree K-Nearest-Neighbor Search

Pseudocode for the search follows; see Appendix 4.2 for the full code:

function kd_tree_search is
    global:
        Q, buffer of the k nearest points found so far (initially holds one point at infinity)
        q, aligned with Q, the distances from the test point to each point in Q
    input: 
        k, the number of nearest neighbors to find
        t, the test point
        node, the current node
        dim, split dimension of the current node (children split on (dim+1) % data dimension)
    output: 
        none
    if distance(t, node.point) < max(q) then
        add node.point to Q and update q accordingly
        if Q now holds more than k points, evict the one farthest from the test point and update q accordingly
    end if
    max(q) is the distance from the test point to the farthest point in Q;
    check whether the interval t[dim] ± max(q) overlaps each of the two subregions split by the current node:
    if the - side overlaps, recursively search the left child
    if the + side overlaps, recursively search the right child
    if t[dim]-max(q) < node.point[dim]:
      kd_tree_search(k,t,node.son1)
    end if
    if t[dim]+max(q) > node.point[dim]:
      kd_tree_search(k,t,node.son2)
    end if
end function

3. The ball-tree Algorithm

With kd-trees, the core cause of the performance drop is that the subspaces carved out by the tree are hyperrectangles, while the nearest-neighbor query uses Euclidean distance, i.e., a hypersphere. A hyperrectangle intersects a hypersphere with very high probability, as shown below, and every intersecting subspace must be inspected, which greatly hurts efficiency.


[Figure: hyperrectangles are very likely to intersect the query hypersphere]

If the partition regions are themselves hyperspheres, the probability of intersection drops sharply. As shown below, ball-tree partitions the space with hyperspheres; with the corners removed, the partition hyperspheres and the query hypersphere intersect far less often, and the algorithm becomes much more efficient, especially when the data dimension is high.


[Figure: k-nearest-neighbor search with a ball-tree]
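Whether this high-dimensional advantage materializes is easy to probe; the sketch below times both trees through scikit-learn on synthetic 50-dimensional data (absolute timings depend entirely on your machine and data distribution, so treat it purely as a probe):

import time
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X = rng.normal(size=(20000, 50))  # 20000 points in 50 dimensions

for algo in ['kd_tree', 'ball_tree']:
    nn = NearestNeighbors(n_neighbors=5, algorithm=algo).fit(X)
    t0 = time.perf_counter()
    nn.kneighbors(X[:500])        # query 500 points
    print(algo, time.perf_counter() - t0)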

3.1 Building a ball-tree

Pseudocode for building the ball-tree follows; see Appendix 4.3 for the full code:

function fit_ball_tree is
    input: x,y, the data points and their labels
    output: node, root node of the constructed ball-tree
    
    if there is only one data point then
        create a leaf node containing that single point:
            node.pivot := x[0]
            node.label := y[0]
            node.son1 := None,
            node.son2 := None,
            node.radius := 0
        return node
    else:
        let c be the widest dimension
        let p1,p2 be the two extreme points along that dimension
        let p be the center along that dimension := (p1+p2)/2
        let radius be the distance from p to the farthest point in x
        let xl be the left subset (all points closer to p1)
        let xr be the right subset (all points closer to p2)
        split the labels into yl, yr accordingly
        create node with two children:
            node.pivot := p
            node.label := None
            node.son1 := fit_ball_tree(xl,yl),
            node.son2 := fit_ball_tree(xr,yr),
            node.radius := radius
        return node
    end if
end function

3.2 ball-tree K-Nearest-Neighbor Search

Pseudocode for the search follows; see Appendix 4.3 for the full code:

function ball_tree_search is
    global:
        Q, buffer of the k nearest points found so far (initially holds one point at infinity)
        q, aligned with Q, the distances from the test point to each point in Q
    input: 
        k, the number of nearest neighbors to find
        t, the test point
        node, the current node
    output: 
        none
    triangle inequality: if the test point's smallest possible distance to the current ball is at least its distance to the farthest point in Q, the ball cannot contain any of the sought neighbors
    if distance(t, node.pivot) - node.radius ≥ max(q) then
        return
    if node is a leaf then
        add node.pivot to Q and update q accordingly
        if Q now holds more than k points, evict the one farthest from the test point and update q accordingly
    else:
        recursively search the current node's left and right children
        ball_tree_search(k,t,node.son1)
        ball_tree_search(k,t,node.son2)
    end if
end function

4. Appendix

4.1 Effect of K on Regression Performance

import numpy as np
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt

X,Y=make_moons(200,noise=0.05,random_state=1)  # 2-D point cloud; x-coordinate is the input, y-coordinate the target
x1=np.arange(-1,2,0.1)  # grid of query inputs
fig=plt.figure(figsize=(9,6))
K=[1,5,10,50,100,200]
for j in range(6):
    ax=fig.add_subplot(2,3,j+1)
    ax.scatter(X[:,0],X[:,1],s=5)

    x2=np.array([])
    k=K[j]
    for i in x1:
        # predict y at i as the mean y-coordinate of the k training points
        # whose x-coordinate lies closest to i
        x2=np.append(x2,np.mean(X[np.argsort(np.abs(X[:,0]-i))[0:k],1]))
        
    ax.plot(x1,x2,c='r')
    ax.title.set_text('k=%d'%k)

4.2 kd-tree Construction and Search

  • Note: after the kd-tree and ball-tree are built, the tree diagrams are drawn with the networkx package. networkx is mainly for building and drawing graph models; laying a tree out properly requires adjusting node positions, which is done here with the two layout functions hierarchy_pos_ugly and hierarchy_pos_beautiful.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import networkx as nx
import random

def hierarchy_pos_ugly(G, root, levels=None, width=1., height=1.):
    """If there is a cycle that is reachable from root, then this will see infinite recursion.
       G: the graph
       root: the root node
       levels: a dictionary
               key: level number (starting from 0)
               value: number of nodes in this level
       width: horizontal space allocated for drawing
       height: vertical space allocated for drawing"""
    TOTAL = "total"
    CURRENT = "current"

    def make_levels(levels, node=root, currentLevel=0, parent=None):
        """Compute the number of nodes for each level
        """
        if not currentLevel in levels:
            levels[currentLevel] = {TOTAL: 0, CURRENT: 0}
        levels[currentLevel][TOTAL] += 1
        neighbors = G.neighbors(node)
        for neighbor in neighbors:
            if not neighbor == parent:
                levels = make_levels(levels, neighbor, currentLevel + 1, node)
        return levels

    def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0):
        dx = 1 / levels[currentLevel][TOTAL]
        left = dx / 2
        pos[node] = ((left + dx * levels[currentLevel][CURRENT]) * width, vert_loc)
        levels[currentLevel][CURRENT] += 1
        neighbors = G.neighbors(node)
        for neighbor in neighbors:
            if not neighbor == parent:
                pos = make_pos(pos, neighbor, currentLevel + 1, node, vert_loc - vert_gap)
        return pos

    if levels is None:
        levels = make_levels({})
    else:
        levels = {l: {TOTAL: levels[l], CURRENT: 0} for l in levels}
    vert_gap = height / (max([l for l in levels]) + 1)
    return make_pos({})

def hierarchy_pos_beautiful(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5):
    '''
    From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
    Licensed under Creative Commons Attribution-Share Alike

    If the graph is a tree this will return the positions to plot this in a
    hierarchical layout.

    G: the graph (must be a tree)

    root: the root node of current branch
    - if the tree is directed and this is not given,
      the root will be found and used
    - if the tree is directed and this is given, then
      the positions will be just for the descendants of this node.
    - if the tree is undirected and not given,
      then a random choice will be used.

    width: horizontal space allocated for this branch - avoids overlap with other branches

    vert_gap: gap between levels of hierarchy

    vert_loc: vertical location of root

    xcenter: horizontal location of root
    '''
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))  # allows back compatibility with nx version 1.11
        else:
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
        '''
        see hierarchy_pos docstring for most arguments

        pos: a dict saying where all nodes go if they have been assigned
        parent: parent of this branch. - only affects it if non-directed

        '''

        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.neighbors(root))
        if not isinstance(G, nx.DiGraph) and parent is not None:
            children.remove(parent)
        if len(children) != 0:
            dx = width / len(children)
            nextx = xcenter - width / 2 - dx / 2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(G, child, width=dx, vert_gap=vert_gap,
                                     vert_loc=vert_loc - vert_gap, xcenter=nextx,
                                     pos=pos, parent=root)
        return pos

    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

'''
To draw the tree diagram:
pos = hierarchy_pos_beautiful(G, "Root")    # compute node positions; the second argument is the root node's name
node_labels = nx.get_node_attributes(G, 'attr')    # extract node labels; the second argument is the attribute name
nx.draw(G, pos, with_labels=True, labels=node_labels)    # draw the tree
plt.show()    # display
'''

X,Y=make_blobs(n_samples=6,
               n_features=2,
               centers=2,
               cluster_std=4,
               random_state=0)

fig=plt.figure(figsize=(5,5))
ax=fig.add_subplot(111)
ax.scatter(X[:,0],X[:,1],c=Y, s=60, cmap='rainbow')

# Pseudocode: see Section 2.1 (fit_kd_tree).

G=nx.Graph()
def fit_kd_tree(x,y,dim=0):
    if x.size==0:
        return None
    idxs=np.argsort(x[:,dim])
    middle_idx=idxs[int(idxs.size/2)]
    p=x[middle_idx]  # p: the median point along dimension dim
    label=y[middle_idx]
    x1,y1,x2,y2=[],[],[],[]
    for i in idxs[0:int(idxs.size/2)]:
        x1.append(x[i])
        y1.append(y[i])
    for i in idxs[int(idxs.size/2)+1:]:
        x2.append(x[i])
        y2.append(y[i])
    x1=np.array(x1)
    y1=np.array(y1)
    x2=np.array(x2)
    y2=np.array(y2)
    
    # recursively build the left and right subtrees
    son1=fit_kd_tree(x1,y1,(dim+1)%x.shape[1])
    son2=fit_kd_tree(x2,y2,(dim+1)%x.shape[1])
    node=dict({'point':p,
               'label':label,
               'son1':son1,
               'son2':son2
                })
    if son1!=None:
        G.add_edge('(%.1f,%.1f)'%tuple(node['point']),
                   '(%.1f,%.1f)'%tuple(node['son1']['point']))
    if son2!=None:
        G.add_edge('(%.1f,%.1f)'%tuple(node['point']),
                   '(%.1f,%.1f)'%tuple(node['son2']['point']))
    return node

root=fit_kd_tree(X,Y)

# Traverse the kd-tree and draw each node's partition line
def plot_partition(node,dim=0,bound=ax.axis()): # bound: the region within which to draw the split line
    line_d=np.arange(bound[(dim+1)%2*2],bound[(dim+1)%2*2+1],0.01)
    line=np.ones((line_d.size,2))
    line[:,(dim+1)%2]=line_d
    line[:,dim]=node['point'][dim]
    plt.plot(line[:,0],line[:,1])
    if node['son1']!=None:
        bound1=list(bound)
        bound1[dim*2+1]=node['point'][dim]
        plot_partition(node['son1'],(dim+1)%2,bound1)
    if node['son2']!=None:
        bound2=list(bound)
        bound2[dim*2]=node['point'][dim]
        plot_partition(node['son2'],(dim+1)%2,bound2)

orig_bound=ax.axis()
plot_partition(root)
ax.axis(orig_bound)

fig2=plt.figure(figsize=(5,5))
pos=hierarchy_pos_ugly(G,root='(%.1f,%.1f)'%tuple(root['point']))
nx.draw(G,pos,with_labels=True,font_size=8,node_size=1500,node_shape='o',node_color='xkcd:light blue')

# Pseudocode: see Section 2.2 (kd_tree_search).

Q=np.array([[np.inf,np.inf]]) # buffer of nearest points found so far (starts with one point at infinity)
q=np.array([np.inf])          # distances from the test point to each point in Q
def kd_tree_search(k,t,node,dim=0):
    global Q,q
    if np.linalg.norm(t-node['point'])<np.max(q):
        # node.point beats the current farthest candidate: add it to Q
        if Q.shape[0]==k:
            Q=np.delete(Q,np.argmax(q),axis=0)
            q=np.delete(q,np.argmax(q))
        Q=np.append(Q,[node['point']],axis=0)
        q=np.append(q,np.linalg.norm(t-node['point']))
    # descend into a child only if the ball of radius max(q) around t
    # overlaps that child's half-space
    if t[dim]-np.max(q)<node['point'][dim] and node['son1']!=None:
        kd_tree_search(k,t,node['son1'],(dim+1)%t.size)
    if t[dim]+np.max(q)>node['point'][dim] and node['son2']!=None:
        kd_tree_search(k,t,node['son2'],(dim+1)%t.size)

k=3
t=np.array([6,3])
kd_tree_search(k,t,root)
print(Q)
fig.axes[0].scatter(t[0],t[1],marker='*',s=500,color='green')
fig.axes[0].scatter(Q[:,0],Q[:,1],marker='*',s=500,facecolors='none',edgecolors='green')

4.3 ball-tree Construction and Search

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import networkx as nx
import random

# hierarchy_pos_ugly and hierarchy_pos_beautiful are identical to the
# definitions in Appendix 4.2 and are omitted here; copy them from 4.2
# to run this script standalone.

X,Y=make_blobs(n_samples=6,
               n_features=2,
               centers=2,
               cluster_std=4,
               random_state=0)

fig=plt.figure(figsize=(5,5))
ax=fig.add_subplot(111)
ax.scatter(X[:,0],X[:,1],c=Y, s=60, cmap='rainbow')

# Pseudocode: see Section 3.1 (fit_ball_tree).

G=nx.Graph()
def fit_ball_tree(x,y):
    if x.shape[0]==1:
        node=dict({'pivot':x[0],
                   'label':y[0],
                   'son1':None,
                   'son2':None,
                   'radius':0
                   })
        return node
    c=np.argmax(np.std(x,axis=0)) # c: the widest (largest-spread) dimension
    p1=x[np.argmin(x[:,c])]
    p2=x[np.argmax(x[:,c])]
    p=(p1+p2)/2 # p: midpoint of the two extreme points along dimension c
    radius=max(np.linalg.norm(x-p,axis=1)) # max distance from p to any point (ball radius)
    x1,y1,x2,y2=[],[],[],[]
    # split x into two subsets by whether each point lies closer to p1 or to p2
    for i in range(x.shape[0]):
        if np.linalg.norm(x[i]-p1)<np.linalg.norm(x[i]-p2):
            x1.append(x[i])
            y1.append(y[i])
        else:
            x2.append(x[i])
            y2.append(y[i])

    # recursively build the two child balls
    son1=fit_ball_tree(np.array(x1),np.array(y1))
    son2=fit_ball_tree(np.array(x2),np.array(y2))
    node=dict({'pivot':p,
               'label':None,
               'son1':son1,
               'son2':son2,
               'radius':radius
               })
    if son1!=None:
        G.add_edge('(%.1f,%.1f)'%tuple(node['pivot']),
                   '(%.1f,%.1f)'%tuple(node['son1']['pivot']))
    if son2!=None:
        G.add_edge('(%.1f,%.1f)'%tuple(node['pivot']),
                   '(%.1f,%.1f)'%tuple(node['son2']['pivot']))
    return node

root=fit_ball_tree(X,Y)

fig2=plt.figure(figsize=(5,5))
pos=hierarchy_pos_ugly(G,root='(%.1f,%.1f)'%tuple(root['pivot']))
nx.draw(G,pos,with_labels=True,font_size=8,node_size=1500,node_shape='o',node_color='xkcd:light blue')

# Pseudocode: see Section 3.2 (ball_tree_search).

Q=np.array([[np.inf,np.inf]]) # buffer of nearest points found so far (starts with one point at infinity)
q=np.array([np.inf])          # distances from the test point to each point in Q
def ball_tree_search(k,t,node):
    global Q,q
    # triangle-inequality pruning: if even the closest possible point of this
    # ball is no closer than the current farthest candidate, skip the ball
    if np.linalg.norm(t-node['pivot'])-node['radius']>=np.max(q):
        return
    if node['son1']==None and node['son2']==None:
        if Q.shape[0]==k:
            Q=np.delete(Q,np.argmax(q),axis=0)
            q=np.delete(q,np.argmax(q))
        Q=np.append(Q,[node['pivot']],axis=0)
        q=np.append(q,np.linalg.norm(t-node['pivot']))
    else:
        ball_tree_search(k,t,node['son1'])
        ball_tree_search(k,t,node['son2'])

k=3
t=np.array([6,3])
ball_tree_search(k,t,root)
print(Q)
fig.axes[0].scatter(t[0],t[1],marker='*',s=500,color='green')
fig.axes[0].scatter(Q[:,0],Q[:,1],marker='*',s=500,facecolors='none',edgecolors='green')
