KNN(k-nearest neighbor algorithm)--从原理到实现


分类: 机器学习 代码   767人阅读  评论(0)  收藏  举报

目录(?)[+]

零.广告

        本文所有代码实现均可以在 DML 找到,不介意的话请大家在github里给我点个Star大笑

一.引入 

K近邻算法作为数据挖掘十大经典算法之一,其算法思想可谓是intuitive,就是从训练集里找离预测点最近的K个样本来预测分类

        因为算法思想简单,你可以用很多方法实现它,这时效率就是我们需要慎重考虑的事情,最简单的自然是求出测试样本和训练集所有点的距离然后排序选择前K个,这个是O(nlogn)的,而其实从N个数据找前K个数据是一个很常见的算法题,可以用最大堆(最小堆)实现,其效率是O(nlogk)的,而最广泛的算法是使用kd树来减少扫描的点,这也就是这篇文章的主要内容,本文偏实现,详细理论教程见july的文章 ,不得不服,july这篇文章巨细无遗!

二.前提:堆的实现

        堆是一种二叉树,用一个数组存储,对于k号元素,k*2号是其左儿子,k*2+1号是其右儿子

        而大根堆就是跟比左儿子和右儿子都大,小根堆反之。

        要满足这个条件我们需要通过up( index )操作和down( index )维护它的结构

        当然讲这个的文章实在有些多了,随便搜一篇大家看看:点击打开链接

        大小根堆的作用是

              a) 优先队列:因为第一个元素是最大或者最小的元素,所以可以实现优先队列

              b) 前K个最大(最小)值:这里限制堆的大小为k,来获得O( n log k)的效率,但注意此时小根堆是获得前K个最大值,大根堆是获得前K个最小值,插入的时候先把元素和堆顶比较再决定是否插入。

        因为事先KD-tree+BBF 要同时用到这两个东西,所以把它们实现在了同一个类里,感觉代码略漂亮,贴出来观赏一下:

        此代码是dml / tool / heap.py

  

[python]  view plain copy
  1. from __future__ import division  
  2. import numpy as np  
  3. import scipy as sp  
  4. def heap_judge(a,b):  
  5.         return a>b  
  6.   
  7. class Heap:  
  8.         def __init__(self,K=None,compare=heap_judge):  
  9.                 ''''' 
  10.                         'K'                 is the parameter to restrict the length of Heap 
  11.                                                 !!! when K is confirmed,the Min heap contain Max K elements 
  12.                                                                   while Max heap contain Min K elements 
  13.                         'compare'         is the compare function which return a BOOL when pass two variable 
  14.                                                 default is Max heap 
  15.                 '''  
  16.                 self.K=K  
  17.                 self.compare=compare  
  18.                 self.heap=['#']  
  19.                 self.counter=0  
  20.         def insert(self,a):  
  21.                 #print self.heap  
  22.                 if self.K!=None:  
  23.                         print a.x,'==='  
  24.                 if self.K==None:  
  25.                         self.heap.append(a)  
  26.                         self.counter+=1  
  27.                         self.up(self.counter)  
  28.                 else:  
  29.                         if self.counter<self.K:  
  30.                                 self.heap.append(a)  
  31.                                 self.counter+=1  
  32.                                 self.up(self.counter)  
  33.                         else:  
  34.                                 if (not self.compare(a,self.heap[1])):  
  35.                                         self.heap[1]=a  
  36.                                         self.down(1)  
  37.                 return  
  38.         def up(self,index):  
  39.                 if (index==1):  
  40.                         return  
  41.                 ''''' 
  42.                 print index 
  43.                 for t in range(index+1): 
  44.                         if t==0: 
  45.                                 continue 
  46.                         print self.heap[t].x 
  47.                 print  
  48.                 '''  
  49.                 if self.compare(self.heap[index],self.heap[int(index/2)]):  
  50.                         #fit the condition  
  51.                         self.heap[index],self.heap[int(index/2)]=self.heap[int(index/2)],self.heap[index]  
  52.                         self.up(int(index/2))  
  53.                 return  
  54.         def down(self,index):  
  55.                 if 2*index>self.counter:  
  56.                         return  
  57.                 tar_index=0  
  58.                 if 2*index<self.counter:  
  59.                         if self.compare(self.heap[index*2],self.heap[index*2+1]):  
  60.                                 tar_index=index*2  
  61.                         else:  
  62.                                 tar_index=index*2+1  
  63.                 else:  
  64.                         tar_index=index*2  
  65.                 if not self.compare(self.heap[index],self.heap[tar_index]):  
  66.                         self.heap[index],self.heap[tar_index]=self.heap[tar_index],self.heap[index]  
  67.                         self.down(tar_index)  
  68.                 return  
  69.   
  70.         def delete(self,index):  
  71.                 self.heap[index],self.heap[self.counter]=self.heap[self.counter],self.heap[index]  
  72.                 self.heap.pop()  
  73.                 self.counter-=1  
  74.                 self.down(index)  
  75.                 pass  
  76.   
  77.         def delete_ele(self,a):  
  78.                 try:  
  79.                         t=self.heap.index(a)  
  80.                 except ValueError:  
  81.                         t=None  
  82.                 if t!=None:  
  83.                         self.delete(t)  
  84.                 return t  
           传入的时候不设置K就是正常的优先队列,设置了K就是限制堆的大小了

          compare参数是比较大小的,默认是“数”的大根堆,你可以往堆里传任何类,只要有相适应的compare参数,比如我们KD-tree传的就是KD-Node

        

三.KD-BFF的原理:

          首先从KD-Tree的创建说起:(直接贴《统计学习方法》的内容了)

          KNN(k-nearest neighbor algorithm)--从原理到实现_第1张图片

          事实上从选择哪一个feature开始切割,还可以选择方差最大的那个参数,但是考虑到简便,以及我们可以选择更多的相似性度量方法,还是用《统计学习方法》里面的选择方式了。

          然后是KD-tree搜索的方法:(来自《统计学习方法》,但注意这里是最近邻,也就是k=1的时候)

         KNN(k-nearest neighbor algorithm)--从原理到实现_第2张图片

          那么我们要K近邻要怎么做呢?就是用堆的第二个应用,用大根堆保持K个最小的距离,然后用根的距离(也就是其中最大的一个)来作为判断的依据是否有更近的点不在结果中,这一点很重要!

          同时摘录july博客的一段读者留言讲得非常好的:

              在某一层,分割面是第ki维,分割值是kv,那么 abs(q[ki]-kv) 就是没有选择的那个分支的优先级,也就是计算的是那一维上的距离; 同时,从优先队列里面取节点只在某次搜索到叶节点后才发生,计算过距离的节点不会出现在队列的,比如1~10这10个节点,你第一次搜索到叶节点的路径是1-5-7,那么1,5,7是不会出现在优先队列的。换句话说,优先队列里面存的都是查询路径上节点对应的相反子节点,比如:搜索左子树,就把对应这一层的右节点存进队列。

          大致这就是我们实现的基本思路了

四.KD-BFF的实现:

        知道原理了,并且有了堆这个工具之后我们就可以着手实现这个算法了:(终于要贴代码了)

       代码~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~此代码是 dml / KNN / kd.py

         

[python]  view plain copy
  1. from __future__ import division  
  2. import numpy as np  
  3. import scipy as sp  
  4. from operator import itemgetter  
  5. from scipy.spatial.distance import euclidean  
  6. from dml.tool import Heap  
  7. class KDNode:  
  8.     def __init__(self,x,y,l):  
  9.         self.x=x  
  10.         self.y=y  
  11.         self.l=l  
  12.         self.F=None  
  13.         self.Lc=None  
  14.         self.Rc=None  
  15.         self.distsToNode=None  
  16.   
  17. class KDTree:  
  18.     def __init__(self,X,y=None,dist=euclidean):  
  19.         self.X=X  
  20.         self.k=X.shape[0#N  
  21.         self.y=y  
  22.         self.dist=dist  
  23.         self.P=self.maketree(X,y,0)  
  24.         self.P.F=None  
  25.     def maketree(self,data,y,deep):  
  26.         if data.size==0:  
  27.             return None  
  28.         lenght = data.shape[0]  
  29.         case = data.shape[1]  
  30.         p=int((case)/2)  
  31.         l = (deep%self.k)  
  32.         #print data  
  33.         data=np.vstack((data,y))  
  34.         data=np.array(sorted(data.transpose(),key=itemgetter(l))).transpose()  
  35.         #print data  
  36.         y=data[lenght,:]  
  37.         data=data[:lenght,:]  
  38.           
  39.         v=data[l,p]  
  40.         rP=KDNode(data[:,p],y[p],l)  
  41.         #print data[:,p],y[p],l  
  42.         if case>1:  
  43.             ldata=data[:,data[l,:]<v]  
  44.             ly=y[data[l,:]<v]  
  45.             data[l,p]=v-1  
  46.             rdata=data[:,data[l,:]>=v]  
  47.             ry=y[data[l,:]>=v]  
  48.             data[l,p]=v  
  49.             rP.Lc=self.maketree(ldata,ly,deep+1)  
  50.             if rP.Lc!=None:  
  51.                 rP.Lc.F=rP  
  52.             rP.Rc=self.maketree(rdata,ry,deep+1)  
  53.             if rP.Rc!=None:  
  54.                 rP.Rc.F=rP  
  55.         return rP  
  56.   
  57.     def search_knn(self,P,x,k,maxiter=200):  
  58.         def pf_compare(a,b):  
  59.             return self.dist(x,a.x)<self.dist(x,b.x)  
  60.         def ans_compare(a,b):  
  61.             return self.dist(x,a.x)>self.dist(x,b.x)  
  62.         pf_seq=Heap(compare=pf_compare)  
  63.         pf_seq.insert(P)    #prior sequence  
  64.         ans=Heap(k,compare=ans_compare)  #ans sequence  
  65.         while pf_seq.counter>0:  
  66.             t=pf_seq.heap[1]  
  67.             pf_seq.delete(1)  
  68.             flag=True  
  69.             if ans.counter==k:  
  70.                 now=t.F  
  71.                 #print ans.heap[1].x,'========'  
  72.                 if now != None:  
  73.                     q=x.copy()  
  74.                     q[now.l]=now.x[now.l]  
  75.                     length=self.dist(q,x)  
  76.                     if length>self.dist(ans.heap[1].x,x):  
  77.                         flag=False  
  78.                     else:  
  79.                         flag=True  
  80.                 else:  
  81.                     flag=True  
  82.             if flag:  
  83.                 tp,pf_seq,ans=self.to_leaf(t,x,pf_seq,ans)  
  84.             #print "============="  
  85.             #ans.insert(tp)  
  86.         return ans  
  87.   
  88.   
  89.     def to_leaf(self,P,x,pf_seq,ans):  
  90.         tp=P  
  91.         if tp!=None:  
  92.             ans.insert(tp)  
  93.             if tp.x[tp.l]>x[tp.l]:  
  94.                 if tp.Rc!=None:  
  95.                     pf_seq.insert(tp.Rc)  
  96.                 if tp.Lc==None:  
  97.                     return tp,pf_seq,ans  
  98.                 else:  
  99.                     return self.to_leaf(tp.Lc,x,pf_seq,ans)  
  100.             if tp.Lc!=None:  
  101.                 pf_seq.insert(tp.Lc)  
  102.             if tp.Rc==None:  
  103.                     return tp,pf_seq,ans  
  104.             else:  
  105.                     return self.to_leaf(tp.Rc,x,pf_seq,ans)  

       

        然后KNN就是对上面这个类的一个包装

        代码~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~此代码是 dml / KNN / knn.py

    

[python]  view plain copy
  1. #coding:utf-8   
  2. import numpy as np  
  3. import scipy as sp  
  4. from scipy.spatial.distance import cdist  
  5. from scipy.spatial.distance import euclidean  
  6. from dml.KNN.kd import KDTree  
  7.   
  8. #import pylab as py  
  9. class KNNC:  
  10.     """docstring for KNNC"""  
  11.     def __init__(self,X,K,labels=None,dist=euclidean):  
  12.         ''''' 
  13.             X is a N*M matrix where M is the case  
  14.             labels is prepare for the predict. 
  15.             dist is the similarity measurement way, 
  16.  
  17.             The distance function can be ‘braycurtis’, ‘canberra’,  
  18.             ‘chebyshev’, ‘cityblock’, ‘correlation’, ‘cosine’,  
  19.             ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’,  
  20.             ‘mahalanobis’,  
  21.  
  22.         '''  
  23.         self.X = np.array(X)  
  24.         if labels==None:  
  25.             np.zeros((1,self.X.shape[1]))  
  26.         self.labels = np.array(labels)  
  27.         self.K = K  
  28.         self.dist = dist  
  29.         self.KDTrees=KDTree(X,labels,self.dist)  
  30.   
  31.     def predict(self,x,k):  
  32.         ans=self.KDTrees.search_knn(self.KDTrees.P,x,k)  
  33.         dc={}  
  34.         maxx=0  
  35.         y=0  
  36.         for i in range(ans.counter+1):  
  37.             if i==0:  
  38.                 continue  
  39.             dc.setdefault(ans.heap[i].y,0)  
  40.             dc[ans.heap[i].y]+=1  
  41.             if dc[ans.heap[i].y]>maxx:  
  42.                 maxx=dc[ans.heap[i].y]  
  43.                 y=ans.heap[i].y  
  44.         return y  
  45.     def pred(self,test_x,k=None):  
  46.         ''''' 
  47.             test_x is a N*TM matrix,and indicate TM test case 
  48.             you can redecide the k 
  49.         '''  
  50.         if k==None:  
  51.             k=self.K  
  52.         test_case=np.array(test_x)  
  53.         y=[]  
  54.         for i in range(test_case.shape[1]):  
  55.             y.append(self.predict(test_case[:,i].transpose(),k))  
  56.         return y  

       因为KNN毕竟是一个分类算法,所以我在predict是加上了分类的代码,如果只想检验Kd-tree的话,你可以直接用for_point()找最近k个点


五.测试+后记

       测试:

       我们选取《统计学习方法》上面的例子:

      KNN(k-nearest neighbor algorithm)--从原理到实现_第3张图片

       使用代码:

          

[python]  view plain copy
  1. X=np.array([[2,5,9,4,8,7],[3,4,6,7,1,2]])  
  2. y=np.array([2,5,9,4,8,7])  
  3. knn=KNNC(X,1,y)  
  4. print knn.for_point([[6.5],[7]],1)  
这里y是label,是用来预测的,这个例子里没有实际作用,这是用来分类的

       输出中后面带了“===”的是扫描过的点,最后的是搜索的结果:

       KNN(k-nearest neighbor algorithm)--从原理到实现_第4张图片

       我们可以看到的确避免扫描了(2,3),Bingo!!

       我们再knn.for_point([[2],[2]]):可以看到避免扫了很多点!!!

      

      

       后记:

       从实现写此文前后耗时两天,昨天写代码写到熄灯且刚好测试通过,怎一个爽字了得!!最后,再在github上求个Star

reference:

【1】从K近邻算法、距离度量谈到KD树、SIFT+BBF算法 http://blog.csdn.net/v_july_v/article/details/8203674

【2】《统计学习方法》 李航

【3】最大堆的插入/删除/调整/排序操作(图解+程序)  http://www.java3z.com/cwbwebhome/article/article1/1362.html?id=4745


你可能感兴趣的:(代码,机器学习)