K Nearest Neighbor问题的解决——KD-TREE Implementation

命题一:  
已知的1000个整数的数组,给定一个整数,要求查证是否在数组中出现? 

命题二:  
已知1000个整数的数组,给定一个整数,要求查找数组中与之最接近的数字? 

命题三:  
已知1000个Point(包含X与Y坐标)结构的数组,给定一个Point,要求查找数组中与之最接近(比如:欧氏距离最短)的点。 

命题四:  
已知1,000,000个向量,每个向量为128维;给定一个向量,要求查找数组中与之最接近的K个向量 

  • 对于命题一,如果不考虑桶式、哈希等方式,常用的方法应该是排序后,使用折半查找。
  • 对于命题二,与命题一类似,比较折半查找得出的结果,以及附近的各一个元素,即可。整个过程相当于是把这个包含1000个数组的数据结构做成一颗二叉树,最后只需比较叶子节点与其父节点即可。
  • 对于命题三、四其中命题三和四就是所谓的Nearest Neighbor问题。一种近似解决的方法就是KD-TREE


高维向量的KNN检索问题,在图像等多媒体内容搜索中是相当关键的。关于高维向量的讨论,网上资料比较少;在此,我将一些心得分享给大家。 
与二叉树相比,KD-TREE也采用类似的划分方式,只不过树中的各节点均是高维向量,因此划分的方式,采用随机或指定的方式选取一个维度,在该指定维度上进行划分;整体的思想就是采用多个超平面对数据集空间进行两两切分,这一点,有点类似于数据挖掘中的决策树。 

一个运用KD-TREE分割二维平面的DEMO如下: 

K Nearest Neighbor问题的解决——KD-TREE Implementation_第1张图片  

KD-Tree build的代码如下: 
Java代码   收藏代码
  1. private ClusterKDTree(Clusterable[] points, int height, boolean randomSplit){  
  2.     if ( points.length == 1 ){  
  3.         cluster = points[0];  
  4.     }  
  5.     else {  
  6.         splitIndex = chooseSplitDimension//选取切分维度  
  7.             (points[0].getLocation().length,height,randomSplit);  
  8.         splitValue = chooseSplit(points,splitIndex);//选取切分值  
  9.               
  10.         Vector<Clusterable> left = new Vector<Clusterable>();  
  11.         Vector<Clusterable> right = new Vector<Clusterable>();  
  12.         for ( int i = 0; i < points.length; i++ ){  
  13.             double val = points[i].getLocation()[splitIndex];  
  14.             if ( val == splitValue && cluster == null ){  
  15.                 cluster = points[i];  
  16.             }  
  17.             else if ( val >= splitValue ){  
  18.                 right.add(points[i]);  
  19.             } else {  
  20.                 left.add(points[i]);  
  21.             }  
  22.         }  
  23.               
  24.         if ( right.size() > 0 ){  
  25.             this.right = new ClusterKDTree(right.toArray(new  
  26.             Clusterable[right.size()]),  
  27.             randomSplit ? splitIndex : height+1, randomSplit);  
  28.         }  
  29.         if ( left.size() > 0 ){  
  30.             this.left = new ClusterKDTree(left.toArray(new  
  31.             Clusterable[left.size()]),randomSplit ? splitIndex : height+1,  
  32.             randomSplit);  
  33.         }  
  34.     }  
  35. }  
  36.   
  37. private int chooseSplitDimension(int dimensionality,int height,boolean random){  
  38.     if ( !random ) return height % dimensionality;  
  39.     int rand = r.nextInt(dimensionality);  
  40.     while ( rand == height ){  
  41.         rand = r.nextInt(dimensionality);  
  42.     }  
  43.     return rand;  
  44. }  
  45.       
  46. private double chooseSplit(Clusterable points[],int splitIdx){  
  47.     double[] values = new double[points.length];  
  48.     for ( int i = 0; i < points.length; i++ ){  
  49.     values[i] = points[i].getLocation()[splitIdx];  
  50.     }  
  51.     Arrays.sort(values);  
  52.     return values[values.length/2];//选取中间值以保持树的平衡  
  53. }  


构建完一颗KD-TREE之后,如何使用它来做KNN检索呢?我用下面的图来表示(20s的GIF动画): 

K Nearest Neighbor问题的解决——KD-TREE Implementation_第2张图片

使用KD-TREE,经过一次二分查找可以获得Query的KNN(最近邻)贪心解,代码如下: 
Java代码   收藏代码
  1. private Clusterable restrictedNearestNeighbor(Clusterable point, SizedPriorityQueue<ClusterKDTree> values){  
  2.     if ( splitIndex == -1 ) {  
  3.         return cluster; //已近到叶子节点  
  4.     }  
  5.           
  6.     double val = point.getLocation()[splitIndex];  
  7.     Clusterable closest = null;  
  8.     if ( val >= splitValue && right != null || left == null ){  
  9.         //沿右边路径遍历,并将左边子树放进队列  
  10.         if ( left != null ){  
  11.             double dist = val - splitValue;  
  12.             values.add(left,dist);  
  13.         }  
  14.         closest = right.restrictedNearestNeighbor(point,values);  
  15.     }  
  16.     else if ( val < splitValue && left != null || right == null ) {  
  17.         //沿左边路径遍历,并将右边子树放进队列  
  18.         if ( right != null ){  
  19.             double dist = splitValue - val;  
  20.             values.add(right,dist);  
  21.         }  
  22.         closest = left.restrictedNearestNeighbor(point,values);  
  23.     }  
  24.     //current distance of the 'ideal' node  
  25.     double currMinDistance = ClusterUtils.getEuclideanDistance(closest,point);  
  26.     //check to see if the current node we've backtracked to is closer  
  27.     double currClusterDistance = ClusterUtils.getEuclideanDistance(cluster,point);  
  28.     if ( closest == null || currMinDistance > currClusterDistance ){  
  29.         closest = cluster;  
  30.         currMinDistance = currClusterDistance;  
  31.     }  
  32.     return closest;  
  33. }  


事实上,仅仅一次的遍历会有不小的误差,因此采用了一个优先级队列来存放每次决定遍历走向时,另一方向的节点。SizedPriorityQueue代码的实现,可参考我的另一篇文章: 
http://grunt1223.iteye.com/blog/909739  

一种减少误差的方法(BBF:Best Bin First)是回溯一定数量的节点: 
Java代码   收藏代码
  1. public Clusterable restrictedNearestNeighbor(Clusterable point, int numMaxBinsChecked){  
  2.     SizedPriorityQueue<ClusterKDTree> bins = new SizedPriorityQueue<ClusterKDTree>(50,true);  
  3.     Clusterable closest = restrictedNearestNeighbor(point,bins);  
  4.     double closestDist = ClusterUtils.getEuclideanDistance(point,closest);  
  5.     //System.out.println("retrieved point: " + closest + ", dist: " + closestDist);  
  6.     int count = 0;  
  7.     while ( count < numMaxBinsChecked && bins.size() > 0 ){  
  8.         ClusterKDTree nextBin = bins.pop();  
  9.     //System.out.println("Popping of next bin: " + nextBin);  
  10.     Clusterable possibleClosest = nextBin.restrictedNearestNeighbor(point,bins);  
  11.         double dist = ClusterUtils.getEuclideanDistance(point,possibleClosest);  
  12.         if ( dist < closestDist ){  
  13.         closest = possibleClosest;  
  14.         closestDist = dist;  
  15.     }  
  16.     count++;  
  17.     }  
  18.     return closest;  
  19. }  


可以用如下代码进行测试: 
Java代码   收藏代码
  1. public static void main(String args[]){  
  2.     Clusterable clusters[] = new Clusterable[10];  
  3.     clusters[0] = new Point(0,0);  
  4.     clusters[1] = new Point(1,2);  
  5.     clusters[2] = new Point(2,3);  
  6.     clusters[3] = new Point(1,5);  
  7.     clusters[4] = new Point(2,5);  
  8.     clusters[5] = new Point(1,1);  
  9.     clusters[6] = new Point(3,3);  
  10.     clusters[7] = new Point(0,2);  
  11.     clusters[8] = new Point(4,4);  
  12.     clusters[9] = new Point(5,5);  
  13.     ClusterKDTree tree = new ClusterKDTree(clusters,true);  
  14.     //tree.print();  
  15.     Clusterable c = tree.restrictedNearestNeighbor(new Point(4,4),1000);  
  16.     System.out.println(c);  
  17. }  

你可能感兴趣的:(K Nearest Neighbor问题的解决——KD-TREE Implementation)