KNN算法与Kd树(转载+代码详细解释)

最近邻法和k-近邻法
  下面图片中只有三种豆,有三个豆是未知的种类,如何判定他们的种类?

KNN算法与Kd树(转载+代码详细解释)_第1张图片

提供一种思路,即:未知的豆离哪种豆最近就认为未知豆和该豆是同一种类。由此,我们引出最近邻算法的定义:为了判定未知样本的类别,以全部训练样本作为代表点,计算未知样本与所有训练样本的距离,并以最近邻者的类别作为决策未知样本类别的唯一依据。但是,最近邻算法明显是存在缺陷的,比如下面的例子:有一个未知形状(图中绿色的圆点),如何判断它是什么形状?
  KNN算法与Kd树(转载+代码详细解释)_第2张图片

显然,最近邻算法的缺陷——对噪声数据过于敏感,为了解决这个问题,我们可以可以把未知样本周边的多个最近样本计算在内,扩大参与决策的样本量,以避免个别数据直接决定决策结果。由此,我们引进K-最近邻算法。K-最近邻算法是最近邻算法的一个延伸。
  
基本思路是:
  选择未知样本一定范围内确定个数的K个样本,该K个样本大多数属于某一类型,则未知样本判定为该类型。如何选择一个最佳的K值取决于数据。一般情况下,在分类时较大的K值能够减小噪声的影响,但会使类别之间的界限变得模糊。待测样本(绿色圆圈)既可能分到红色三角形类,也可能分到蓝色正方形类。如果k取3,从图可见,待测样本的3个邻居在实线的内圆里,按多数投票结果,它属于红色三角形类。但是如果k取5,那么待测样本的最邻近的5个样本在虚线的圆里,按表决法,它又属于蓝色正方形类。在实际应用中,K先取一个比较小的数值,再采用交叉验证法来逐步调整K值,最终选择适合该样本的最优的K值。

KNN算法实现 
算法基本步骤:

1)计算待分类点与已知类别的点之间的距离

2)按照距离递增次序排序

3)选取与待分类点距离最小的k个点

4)确定前k个点所在类别的出现次数

5)返回前k个点出现次数最高的类别作为待分类点的预测分类

下面是一个按照算法基本步骤用python实现的简单例子,根据已分类的4个样本点来预测未知点(图中的灰点)的分类:
  KNN算法与Kd树(转载+代码详细解释)_第3张图片

KNN.py代码如下:

#-*- encoding:utf-8 -*-
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
from numpy import *  
# create a dataset which contains 4 samples with 2 classes  



#
# 算法基本步骤(☆☆☆☆☆☆☆☆):
# 1)计算待分类点与已知类别的点之间的距离
# 2)按照距离递增次序排序
# 3)选取与待分类点(这个也就是测试集)距离最小的k个点
# 4)确定前k个点所在各自类别的出现次数
# 5)返回前k个点出现次数最高的类别作为待分类点的预测分类
def createDataSet():  
    # create a matrix: each row as a sample  
    group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])  
    labels = ['A', 'A', 'B', 'B'] # 这里是在给训练集打标签  
    return group, labels
# classify using kNN (k Nearest Neighbors )  
# Input:      newInput: 1 x N
#             dataSet:  M x N (M samples N, features)
#             labels:   1 x M   
#             k: number of neighbors to use for comparison  
# Output:     the most popular class label   


def kNNClassify(newInput, dataSet, labels, k):  
    #labels是个列表
    print "dataSet=",dataSet#dataSet这里就是训练集

    print"newInput=",newInput#newInput这里就是测试集(这里比较特殊,只有一个数据)
    numSamples = dataSet.shape[0] # shape[0]用来读取矩阵第一维的长度
    print"dataSet.shape[0]=",dataSet.shape[0] 



    ## step 1: calculate Euclidean distance (欧氏距离) 
    # tile(A, reps): Construct an array by repeating A reps times  
    # the following copy numSamples rows for dataSet  
    diff = tile(newInput, (numSamples, 1)) - dataSet # Subtract element-wise  
    squaredDiff = diff ** 2 # 测试数据与训练集中的每个点的差值的平方,这些平方组成的一个数组  
    print"-"*50
    print"squareDiff=",squaredDiff
    print"-"*50
    squaredDist = sum(squaredDiff, axis = 1) # 将矩阵的每一行向量相加,也即是说x^2+y^2
    print"squaredDist=",squaredDist
    distance = squaredDist ** 0.5 #这里是在计算平方根 
    print"distance=",distance

  
    ## step 2: sort the distance  
    # argsort() returns the indices that would sort an array in a ascending order  
    sortedDistIndices = argsort(distance)
    print"sortedDistIndices=",sortedDistIndices#根据被测点与测试集中各个点的距离的不同进行排序。

###################################################################  
  
    classCount = {} # define a dictionary (can be append element)  
    for i in xrange(k):  #这里应该是取得前面k个距离最近的点。
        ## step 3: choose the min k distance  
        voteLabel = labels[sortedDistIndices[i]] #遍历sortedDistIndices的前面k个点,然后获取这k个点的分类标签
        print"voteLabel=",voteLabel 
  
        ## step 4: count the times labels occur  
        # when the key voteLabel is not in dictionary classCount, get()  
        # will return 0  
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1  #对k个数据中,具备各种抱歉的数据进行统计,然后写入
        #上面这句话中get(voteLabel, 0)中之所以有个0的原因是:如果指定键的值不在字典中返回指定值0,
        #如果存在,那么就当前的键值+1
    ## step 5: the max voted class will return  
    maxCount = 0  
    for key, value in classCount.items():  
        if value > maxCount:  
            maxCount = value
            maxIndex = key  
              
  
    return maxIndex#这个返回的就是最大类别。最后其实就是k个最近的点中,哪个类别多,那么就被判别为哪一类。(少数服从多数)   
    
if __name__== "__main__":    
    dataSet, labels = createDataSet()  

###############第1例测试
    testX = array([1.2, 1.0])#查找点(1.2,1.0)待会儿查下这个预测准确吗?
    k = 3  
    outputLabel = kNNClassify(testX, dataSet, labels, 3)  
    print "Your input is:", testX, "and classified to class: ", outputLabel  
 ###############第2例测试    
    testX = array([0.1, 0.3])#查找点(0.1,0.3)  
    outputLabel = kNNClassify(testX, dataSet, labels, 3)  
    print "Your input is:", testX, "and classified to class: ", outputLabel

结果如下:
Your input is: [ 1.2 1. ] and classified to class: A
Your input is: [ 0.1 0.3] and classified to class: B

OpenCV中也提供了机器学习的相关算法,其中KNN算法的最基本例子opencv_KNN.py如下:

#-*- encoding:utf-8 -*-
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import numpy as np
import matplotlib.pyplot as plt
import cv2
# Feature set containing (x,y) values of 25 known/training data
trainData = np.random.randint(0,100,(25,2)).astype(np.float32)
# Labels each one either Red or Blue with numbers 0 and 1
responses = np.random.randint(0,2,(25,1)).astype(np.float32)
# Take Red families and plot them
red = trainData[responses.ravel()==0]
plt.scatter(red[:,0],red[:,1],80,'r','^')
# Take Blue families and plot them
blue = trainData[responses.ravel()python==1]
plt.scatter(blue[:,0],blue[:,1],80,'b','s')
# Testing data
newcomer = np.random.randint(0,100,(1,2)).astype(np.float32)
plt.scatter(newcomer[:,0],newcomer[:,1],80,'g','o')

knn = cv2.KNearest()
knn.train(trainData,responses) # Trains the model
# Finds the neighbors and predicts responses for input vectors.
ret, results, neighbours ,dist = knn.find_nearest(newcomer, 3)
print "result: ", results,"\n"print "neighbours: ", neighbours,"\n"print "distance: ", dist
plt.show()

result: [[ 0.]]
neighbours: [[ 0. 0. 0.]]
distance: [[ 65. 145. 178.]]

可以看到KNN算法将未知点分到第0组(红色三角形组),从上图中也可看出3个距离未知点最近的样本都属于第0组,因此算法返回分类标签也为0。

KNN算法的缺陷
  观察下面的例子,我们看到对于样本X,通过KNN算法,我们显然可以得到X应属于红点,但对于样本Y,通过KNN算法我们似乎得到了Y应属于蓝点的结论,而这个结论直观来看并没有说服力。

由上面的例子可见:该算法在分类时有个重要的不足是,当样本不平衡时,即:一个类的样本容量很大,而其他类样本数量很小时,很有可能导致当输入一个未知样本时,该样本的K个邻居中大数量类的样本占多数。 但是这类样本并不接近目标样本,而数量小的这类样本很靠近目标样本。这个时候,我们有理由认为该位置样本属于数量小的样本所属的一类,但是,KNN却不关心这个问题,它只关心哪类样本的数量最多,而不去把距离远近考虑在内,因此,我们可以采用权值的方法来改进。和该样本距离小的邻居权值大,和该样本距离大的邻居权值则相对较小,由此,将距离远近的因素也考虑在内,避免因一个样本过大导致误判的情况。

从算法实现的过程可以发现,该算法存两个严重的问题,第一个是需要存储全部的训练样本,第二个是计算量较大,因为对每一个待分类的样本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。KNN算法的改进方法之一是分组快速搜索近邻法。其基本思想是:将样本集按近邻关系分解成组,给出每组质心的位置,以质心作为代表点,和未知样本计算距离,选出距离最近的一个或若干个组,再在组的范围内应用一般的KNN算法。由于并不是将未知样本与所有样本计算距离,故该改进算法可以减少计算量,但并不能减少存储量。

KD树
  实现k近邻法时,主要考虑的问题是如何对训练数据进行快速k近邻搜索。这在特征空间的维数大及训练数据容量大时尤其必要。k近邻法最简单的实现是线性扫描(穷举搜索),即要计算输入实例与每一个训练实例的距离。计算并存储好以后,再查找K近邻。当训练集很大时,计算非常耗时。为了提高kNN搜索的效率,可以考虑使用特殊的结构存储训练数据,以减小计算距离的次数。

kd树(K-dimension tree)是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是是一种二叉树,表示对k维空间的一个划分,构造kd树相当于不断地用垂直于坐标轴的超平面将K维空间切分,构成一系列的K维超矩形区域。kd树的每个结点对应于一个k维超矩形区域。利用kd树可以省去对大部分数据点的搜索,从而减少搜索的计算量。
  KNN算法与Kd树(转载+代码详细解释)_第4张图片

对一个三维空间,kd树按照一定的划分规则把这个三维空间划分了多个空间,如下图所示
  KNN算法与Kd树(转载+代码详细解释)_第5张图片

类比“二分查找”:给出一组数据:[9 1 4 7 2 5 0 3 8],要查找8。
  如果挨个查找(线性扫描),那么将会把数据集都遍历一遍。而如果排一下序那数据集就变成了:[0 1 2 3 4 5 6 7 8 9],按前一种方式我们进行了很多没有必要的查找,现在如果我们以5为分界点,那么数据集就被划分为了左右两个“簇” [0 1 2 3 4]和[6 7 8 9]。
  因此,根本久没有必要进入第一个簇,可以直接进入第二个簇进行查找。把二分查找中的数据点换成k维数据点,这样的划分就变成了用超平面对k维空间的划分。空间划分就是对数据点进行分类,“挨得近”的数据点就在一个空间里面。

构造kd树的方法如下:
  构造根结点,使根结点对应于K维空间中包含所有实例点的超矩形区域;通过下面的递归的方法,不断地对k维空间进行切分,生成子结点。在超矩形区域上选择一个坐标轴和在此坐标轴上的一个切分点,确定一个超平面,这个超平面通过选定的切分点并垂直于选定的坐标轴,将当前超矩形区域切分为左右两个子区域(子结点);这时,实例被分到两个子区域,这个过程直到子区域内没有实例时终止(终止时的结点为叶结点)。在此过程中,将实例保存在相应的结点上。通常,循环的择坐标轴对空间切分,选择训练实例点在坐标轴上的中位数为切分点,这样得到的kd树是平衡的(平衡二叉树:它是一棵空树,或其左子树和右子树的深度之差的绝对值不超过1,且它的左子树和右子树都是平衡二叉树)。

KD树中每个节点是一个向量,和二叉树按照数的大小划分不同的是,KD树每层需要选定向量中的某一维,然后根据这一维按左小右大的方式划分数据。在构建KD树时,关键需要解决2个问题:
  (1)选择向量的哪一维进行划分;
  (2)如何划分数据。第一个问题简单的解决方法可以是选择随机选择某一维或按顺序选择,但是更好的方法应该是在数据比较分散的那一维进行划分(分散的程度可以根据方差来衡量)。好的划分方法可以使构建的树比较平衡,可以每次选择中位数来进行划分,这样问题2也得到了解决。

构造平衡kd树算法:
输入:kk维空间数据集 T = { x 1 , x 2 , . . . , x N } T=\{x_1,x_2,...,x_N\} T={x1,x2,...,xN},其中 x i = ( x i ( 1 ) , x i ( 2 ) , . . . , x i ( k ) ) , i = 1 , 2 , . . . , N ; x_i=(x_i^{(1)},x_i^{(2)},...,x_i^{(k)}),i=1,2,...,N; xi=(xi(1),xi(2),...,xi(k)),i=1,2,...,N;
输出:kd树

(1)开始:构造根结点,根结点对应于包含 T T T k k k维空间的超矩形区域。选择 x ( 1 ) x^{(1)} x(1)为坐标轴,以T中所有实例的 x ( 1 ) x^{(1)} x(1)坐标的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴 x ( 1 ) x^{(1)} x(1)垂直的超平面实现。由根结点生成深度为1的左、右子结点:左子结点对应坐标 x ( 1 ) x^{(1)} x(1)小于切分点的子区域,右子结点对应于坐标 x ( 1 ) x^{(1)} x(1)大于切分点的子区域。将落在切分超平面上的实例点保存在根结点。

(2)重复。对深度为j的结点,选择 x ( l ) x^{(l)} x(l)为切分的坐标轴, l = j % k + 1 l=j\%k+1 l=j%k+1,以该结点的区域中所有实例的 x ( l ) x^{(l)} x(l)坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴 x ( l ) x^{(l)} x(l)垂直的超平面实现。由该结点生成深度为j+1的左、右子结点:左子结点对应坐标 x ( l ) x^{(l)} x(l)小于切分点的子区域,右子结点对应坐标 x ( l ) x^{(l)} x(l)大于切分点的子区域。将落在切分超平面上的实例点保存在该结点。

下面用一个简单的2维平面上的例子来进行说明。

例. 给定一个二维空间数据集: T = { ( 2 , 3 ) , ( 5 , 4 ) , ( 9 , 6 ) , ( 4 , 7 ) , ( 8 , 1 ) , ( 7 , 2 ) } T=\{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)\} T={(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)},构造一个平衡kd树。

解:
  根结点对应包含数据集T的矩形,选择 x ( 1 ) x^{(1)} x(1)轴,6个数据点的 x ( 1 ) x^{(1)} x(1)坐标中位数是6,这里选最接近的(7,2)点,以平面 x ( 1 ) x^{(1)} x(1)=7将空间分为左、右两个子矩形(子结点);
  接着左矩形以 x ( 2 ) x^{(2)} x(2)=4分为两个子矩形(左矩形中{(2,3),(5,4),(4,7)}点的 x ( 2 ) x^{(2)} x(2)坐标中位数正好为4),右矩形以 x ( 2 ) x^{(2)} x(2)=6分为两个子矩形,如此递归,最后得到如下图所示的特征空间划分和kd树。

KNN算法与Kd树(转载+代码详细解释)_第6张图片

下面的代码用递归的方式构建了kd树,通过前序遍历可以进行验证。这里只是简单地采用坐标轮换方式选取分割轴,为了更高效的分割空间,也可以计算所有数据点在每个维度上的数值的方差,然后选择方差最大的维度作为当前节点的划分维度。方差越大,说明这个维度上的数据越不集中(稀疏、分散),也就说明了它们就越不可能属于同一个空间,因此需要在这个维度上进行划分。

KdTree.py

# -*- coding: utf-8 -*-
#from operator import itemgetter
import sys
reload(sys)
sys.setdefaultencoding('utf8')


# kd-tree每个结点中主要包含的数据结构如下 
class KdNode(object):
    def __init__(self, dom_elt, split, left, right):
        self.dom_elt = dom_elt  # k维向量节点(k维空间中的一个样本点)
        self.split = split      # 整数(进行分割维度的序号),奇数层根据x来比较,偶数层根据y来比较
        #这里所谓的分割维度其实就是指根据当前根节点的第几个坐标的值来进行比大小。

        self.left = left        # 该结点分割超平面左子空间构成的kd-tree
        self.right = right      # 该结点分割超平面右子空间构成的kd-tree

class KdTree(object):
    def __init__(self, data):
        k = len(data[0])  # 数据维度,这里其实就是2
        
        def CreateNode(split, data_set): # 按第split维划分数据集exset创建KdNode
        #这里所谓的split维度,讲人话就是(x,y)的第几个坐标,二维的话,split的值分别是0和1



            if not data_set:    # 数据集为空
                return None
            # key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较
            # operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号
            #data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序
            data_set.sort(key=lambda x: x[split])#根据第split个坐标进行排序,然后选“中位数”作为下一级子二叉树的根节点
            print"data_set=",data_set
            print"*"*50
            split_pos = len(data_set) // 2      # //为Python中的整数除法
            print"split_pos=",split_pos
            median = data_set[split_pos]        # 中位数分割点    
            print"median=",median        
            split_next = (split + 1) % k        # cycle coordinates(循环坐标)
            #这里循环坐标的意思是,一开始split按照从左往右的方式继续拧比较,从第一个坐标开始比较
            #当split+1等于k的时候,split_next会重新恢复到0
            print"split_next=",split_next
            print"-"*50
            
            # 递归的创建kd树
            return KdNode(median, split, 
                          CreateNode(split_next, data_set[:split_pos]),     # 创建左子树
                          CreateNode(split_next, data_set[split_pos + 1:])) # 创建右子树
            #上买的代码中,一般split_next比split_pos大一,一轮循环结束后,那么会有split_next=0,split_pos在末尾坐标的情况。
        self.root = CreateNode(0, data)         # 从第0维分量开始构建kd树,返回根节点

#KdTree调用了KdNode

# KDTree的前序遍历
def preorder(root):  

    print root.dom_elt  
    if root.left:      # 节点不为空
        preorder(root.left)  
    if root.right:  
        preorder(root.right)  
      
      
if __name__ == "__main__":
    data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
    kd = KdTree(data)
    print"*"*100
    preorder(kd.root)

进行前序遍历(前序遍历首先访问根结点然后遍历左子树,最后遍历右子树)的结果如下,可见已经正确构建了kd树:

搜索kd树

利用kd树可以省去对大部分数据点的搜索,从而减少搜索的计算量。下面以搜索最近邻点为例加以叙述:给定一个目标点,搜索其最近邻,首先找到包含目标点的叶节点;然后从该叶结点出发,依次回退到父结点;不断查找与目标点最近邻的结点,当确定不可能存在更近的结点时终止。这样搜索就被限制在空间的局部区域上,效率大为提高。

用kd树的最近邻搜索:  
输入: 已构造的kd树;目标点xx;
输出:xx的最近邻。

(1) 在kd树中找出包含目标点xx的叶结点:从根结点出发,递归的向下访问kd树。若目标点当前维的坐标值小于切分点的坐标值,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止;

(2) 以此叶结点为“当前最近点”;

(3) 递归的向上回退,在每个结点进行以下操作:

(a) 如果该结点保存的实例点比当前最近点距目标点更近,则以该实例点为“当前最近点”;

(b) 当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一个子结点对应的区域是否有更近的点。具体的,检查另一个子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。如果相交,可能在另一个子结点对应的区域内存在距离目标更近的点,移动到另一个子结点。接着,递归的进行最近邻搜索。如果不相交,向上回退。

(4) 当回退到根结点时,搜索结束。最后的“当前最近点”即为xx的最近邻点。

以先前构建好的kd树为例,查找目标点(3,4.5)的最近邻点。
  同样先进行二叉查找,先从(7,2)查找到(5,4)节点,在进行查找时是由y = 4为分割超平面的,由于查找点为y值为4.5,因此进入右子空间查找到(4,7),形成搜索路径:(7,2)→(5,4)→(4,7),取(4,7)为当前最近邻点。
  以目标查找点为圆心,目标查找点到当前最近点的距离2.69为半径确定一个红色的圆。
  然后回溯到(5,4),计算其与查找点之间的距离为2.06,则该结点比当前最近点距目标点更近,以(5,4)为当前最近点。
  用同样的方法再次确定一个绿色的圆,可见该圆和y = 4超平面相交,所以需要进入(5,4)结点的另一个子空间进行查找。(2,3)结点与目标点距离为1.8,比当前最近点要更近,所以最近邻点更新为(2,3),最近距离更新为1.8,同样可以确定一个蓝色的圆。
  接着根据规则回退到根结点(7,2),蓝色圆与x=7的超平面不相交,因此不用进入(7,2)的右子空间进行查找。至此,搜索路径回溯完,返回最近邻点(2,3),最近距离1.8。

KNN算法与Kd树(转载+代码详细解释)_第7张图片

如果实例点是随机分布的,kd树搜索的平均计算复杂度是 O ( l o g N ) O(logN) O(logN),这里N是训练实例数。kd树更适用于训练实例数远大于空间维数时的k近邻搜索。当空间维数接近训练实例数时,它的效率会迅速下降,几乎接近线性扫描。

下面的代码KdTree_search.py对构建好的kd树进行搜索,寻找与目标点最近的样本点:

#-*- encoding:utf-8 -*-
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
from KdTree import *
#这个代码其实是采用kd tree实现的是1NN,不是KNN

from math import sqrt
from collections import namedtuple

# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple", "nearest_point  nearest_dist  nodes_visited")#就是一个可以塞进去各种类型的元组
#然后括号里面的这四个其实是对象中的成员变量名(例如nearest_point)
#dist,max_dist,temp_dist

#这里面的nodes_visited只是起到一个统计的作用,没有其他的意义。



#下面算法描述来自:
#https://blog.csdn.net/liqiutuoyuan/article/details/77073689

# 回溯查找:
# 根据得到的搜索路径栈,栈顶的元素为‘当前最近点’,
# 将该元素出栈,并计算该点与x的距离d。对于当前栈顶的元素,
# 首先将元素出栈,以x为圆心,d为半径画圆,如果与该元素对应的分割超平面相交,
# 计算该元素和x的距离,如果小于d,则将该元素更新为‘当前最近点’,d也需要更新;
# 如果不相交,则继续对搜索路径的栈顶元素重复相同的操作。
# 同时对元素的另一半子空间对应的子树进行步骤2,搜索的点加入搜索路径。(这里对应于temp2)
# 直到搜索路径栈为空。
# 此时得到的‘当前最近点’即为x的最邻近点,d为最邻近距离。

#注意:理解算法的时候,要把BST的一个分支理解为一个栈

def find_nearest(tree, point):
    k = len(point) # 数据维度
    def travel(kd_node, target, max_dist):#这个的target就是测试数据

        if kd_node!=None:
            print "当前访问节点=",kd_node.dom_elt
        if kd_node is None:#这里是travel的递归结束处,当kd_node为空的时候,结束递归,同时,这里的结果会赋值给temp1和temp2
            return result([0] * k, float("inf"), 0) # python中用float("inf")和float("-inf")表示正负无穷
 
        nodes_visited = 1#访问节点数量统计,对于实现这个算法而言,没啥用。
        
        s = kd_node.split        # 进行分割的维度
        pivot = kd_node.dom_elt  # 这个应该是中位数,同时,pivot是个列表
        
        #注意,下面的比较是建立在kd树已经建立完成的基础上的。
        #经过比较来确定哪个节点离测试点target更近
        if target[s] <= pivot[s]:           # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)
            nearer_node  = kd_node.left     # 下一个访问节点为左子树根节点
            further_node = kd_node.right    # 同时记录下右子树
        else:                               # 目标离右子树更近
            nearer_node  = kd_node.right    # 下一个访问节点为右子树根节点
            further_node = kd_node.left

        print"#########进入temp1##########"
        temp1 = travel(nearer_node, target, max_dist)  # 进行遍历找到包含目标点的区域,这里就是画一个圆了。
        print"##########离开temp1########"
        nearest = temp1.nearest_point       # 以此叶结点作为“当前最近点”
        #这里使用一个列表来保存一个坐标
        dist = temp1.nearest_dist           # 更新最近距离,这里的这个nearest_dist来自上面的namedtuple
        nodes_visited += temp1.nodes_visited
#####################以上是遍历到最下面的叶节#################################################################
        if dist < max_dist:     
            max_dist = dist    # 最近点将在以目标点为球心,max_dist为半径的超球体内
            #一点点缩小max_dist的上限
            
        temp_dist = abs(pivot[s] - target[s])    # 第s维上目标点与分割超平面的距离
        if  max_dist < temp_dist:                # 判断超球体是否与超平面相交
            return result(nearest, dist, nodes_visited) # 不相交则可以直接返回,不用继续判断
            #这里max_dist是球体半径,temp_dist是超平面

#判断球体与超平面是否相割的那个超平面就是根节点的切割维所在的平面。

#为什么这里一会儿是欧氏距离,一会儿又是仅仅计算某个维之间的差值呢(例如两个点之间的xi之间的差值)?
#这是因为欧式距离在整个算法中是用来修正球的半径的,而坐标系之差是用来判断是否相割的。

        #----------------------------------------------------------------------  
        # 计算目标点target与分割点pivot的欧氏距离  
        temp_dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target)))     
        
        if temp_dist < dist:         # 如果“更近”
            nearest = pivot          # 更新最近点
            dist = temp_dist         # 更新最近距离
            max_dist = dist          # 更新超球体半径
        
        # 检查另一个子结点对应的区域是否有更近的点
        print"☆☆☆☆☆☆☆☆进入temp2☆☆☆☆☆☆☆☆☆☆☆"
        temp2 = travel(further_node, target, max_dist) #
        print"☆☆☆☆☆☆☆☆离开temp2☆☆☆☆☆☆☆☆☆☆☆"
        
        nodes_visited += temp2.nodes_visited#访问节点数量统计,没什么用
        if temp2.nearest_dist < dist:        # 如果另一个子结点内存在更近距离
            nearest = temp2.nearest_point    # 更新最近点
            dist = temp2.nearest_dist        # 更新最近距离
 
        return result(nearest, dist, nodes_visited)
 
    return travel(tree.root, point, float("inf"))  # 从根节点开始递归


from time import clock
from random import random

# 产生一个k维随机向量,每维分量值在0~1之间
def random_point(k):
    return [random() for _ in range(k)]
 
# 产生n个k维随机向量 
def random_points(k, n):
    return [random_point(k) for _ in range(n)]       
      
if __name__ == "__main__":
#测试案例2
    data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]  # samples
    kd = KdTree(data)#先建立一个kd树
    ret = find_nearest(kd, [3,4.5])#然后再在kd树中寻找离该点最近的点

#测试案例2
    # N = 40
    # t0 = clock()
    # kd2 = KdTree(random_points(3, N))            # 构建包含四十个3维空间样本点的kd树
    # ret2 = find_nearest(kd2, [0.1,0.5,0.8])      # 四十万个样本点中寻找离目标最近的点
    # t1 = clock()
    # print "time: ",t1-t0, "s"
    # print ret2

#可以参考以下链接来理解:
#http://www.cnblogs.com/eyeszjwang/articles/2429382.html

下面结合前面写的代码来进行一下测试:

from time import clock
from random import random

# 产生一个k维随机向量,每维分量值在0~1之间
def random_point(k):
    return [random() for _ in range(k)]
 
# 产生n个k维随机向量 
def random_points(k, n):
    return [random_point(k) for _ in range(n)]       
      
if __name__ == "__main__":
    data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]  # samples
    
    kd = KdTree(data)
    
    ret = find_nearest(kd, [3,4.5])
    print ret

    N = 400000
    t0 = clock()
    kd2 = KdTree(random_points(3, N))            # 构建包含四十万个3维空间样本点的kd树
    ret2 = find_nearest(kd2, [0.1,0.5,0.8])      # 四十万个样本点中寻找离目标最近的点
    t1 = clock()
    print "time: ",t1-t0, "s"
    print ret2

结果如下图所示。先是测试了之前例子中距离(3,4.5)最近的点,可以看出正确返回了最近点(2,3)以及最近距离。然后随机生成了四十万个三维空间样本点,并构建kd树,然后搜索离(0.1,0.5,0.8)最近的样本点,并测试用时。为了进行对比我先是使用numpy算出全部四十万个距离后寻找最近点,结果耗时0.5s左右!!!怎么能这么快(⊙▽⊙),然后不用numpy自己在python中计算全部距离,结果耗时2s左右,还是比自己写的KD树要快得多…

可能是这种使用递归方式创建和搜索的kd树本身效率就不是很高(知乎:为什么说递归效率低?)。而且深层递归一定要尽量避免,一是不安全,容易导致栈溢出;二是调用代价高(递归函数调用的代价)。可以考虑转换为循环结构。循环结构的kd树实现参考:KDTree example in scipy

参考:

Python手写数字识别-knn算法应用

机器学习算法与Python实践之(一)k近邻(KNN)

《统计学习方法》 李航 第3章 k近邻法

k-d树算法

Kd Tree算法原理和开源实现代码

KD-tree的原理以及构建与查询操作的python实现

从K近邻算法、距离度量谈到KD树、SIFT+BBF算法

KD树详解及KD树最近邻算法

KNN之KD树实现

http://rosettacode.org/wiki/K-d_tree

http://www.cnblogs.com/chuxiuhong/p/5982580.html

#############################################################################################

文章阅读总结:

kd树来实现KNN,
总的来说就是:
KNN利用Kd树来获取离目标点最近的k个测试点,
然后少数服从多数原则,这k个测试点中,哪一类的类别数目最多,
就认为被测试点符合该类别。

这篇文章分别讲了四个代码:
KNN.py 一般KNN算法的思想
opencv_KNN.py opencv中的KNN包的使用
KdTree.py 树的构建
KdTree_search.py 根据构建的树来寻找最近的点,并没有找最近的k个点

文章最后没有把“建立kd树-在kd树中查找最近的k个点-根据标签预测类别”整个流程串起来实现一遍,而是分开实现的。
也就是说,文章最后实现的是“最近邻”(1NN),而不是k近邻(KNN)

你可能感兴趣的:(机器学习算法)