kdtree

kd树(k-dimensional tree)和geohash一样,都是对空间进行划分的索引结构。区别在于kd树按坐标轴轮流对点集做二分:每一层取当前轴上的中位数作为切分点。kd树可用于范围查询和knn(k-nearest neighbors,k近邻)查询。
实现如下:
KdTree.java:


import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Stack;




/**
 * 功能描述:  KDTree implementation,


 * 
 * do call init(int dimension) to init the comparator if vector more than two dimension


 *
 * @author : zhaogang.lv


 *
 * @version 2.2 Feb 29, 2012
 *
 * @since dp-searcher 2.2
 */
public class KdTree
{
    private KdTreeNode rootNode ;
    
    private static HPointComparator[] pointComparators ;
    
    public static void init(int dimension)
    {
        pointComparators = new HPointComparator[dimension];
        for(int i=0;i         {
            pointComparators[i] = new HPointComparator(i);
        }
    }
    
    static
    {
        init(2);
    }
    
    protected static double sqrdist(double[] a, double[] b)
    {
        double dist = 0;


        for (int i = 0; i < a.length; ++i)
        {
            double diff = (a[i] - b[i]);
            dist += diff * diff;
        }


        return dist;
    }
    
    public KdTree(List> pointList)
    {
        rootNode = build_kdtree(pointList,0);
    }
    
    private KdTreeNode build_kdtree(List> point_list,int  depth)
    {
        
        if (null==point_list || point_list.size()<=0)
            return null;


        //# select axis based on depth so that axis cycles through all valid values
        final int axis = depth % point_list.get(0).coord.length; //# assumes all points have the same dimension


        //# sort point list and choose median as pivot point,
        Collections.sort(point_list, pointComparators[axis] );
        
        int median = point_list.size()/2 ; // # choose median


        //# create node and recursively construct subtrees
        KdTreeNode node = new KdTreeNode(point_list.get(median), null,
                                         build_kdtree(point_list.subList(0, median), depth+1),
                                         build_kdtree(point_list.subList(median+1,point_list.size()), depth+1)); // XXX: 
        return node;
    }
    
    
    protected static void nnbr(KdTreeNode kd, HPoint target, HRect hr, double max_dist_sqd, int lev, int K,
                                   NearestNeighborList> nnl)
    {


        // 1. if kd is empty then set dist-sqd to infinity and exit.
        if (kd == null)
        {
            return;
        }


        // 2. s := split field of kd
        int s = lev % K;


        HPoint pivot = kd.point;
        double pivot_to_target = sqrdist(pivot.coord, target.coord); 


        // 4. Cut hr into to sub-hyperrectangles left-hr and right-hr.
        // The cut plane is through pivot and perpendicular to the s
        // dimension.
        HRect left_hr = hr; // optimize by not cloning
        HRect right_hr = (HRect) hr.clone();
        left_hr.max.coord[s] = pivot.coord[s];
        right_hr.min.coord[s] = pivot.coord[s];


        // 5. target-in-left := target_s <= pivot_s
        boolean target_in_left = target.coord[s] < pivot.coord[s];


        KdTreeNode nearer_kd;
        HRect nearer_hr;
        KdTreeNode further_kd;
        HRect further_hr;


        // 6. if target-in-left then
        // 6.1. nearer-kd := left field of kd and nearer-hr := left-hr
        // 6.2. further-kd := right field of kd and further-hr := right-hr
        if (target_in_left)
        {
            nearer_kd = kd.left;
            nearer_hr = left_hr;
            further_kd = kd.right;
            further_hr = right_hr;
        }
        //
        // 7. if not target-in-left then
        // 7.1. nearer-kd := right field of kd and nearer-hr := right-hr
        // 7.2. further-kd := left field of kd and further-hr := left-hr
        else
        {
            nearer_kd = kd.right;
            nearer_hr = right_hr;
            further_kd = kd.left;
            further_hr = left_hr;
        }


        // 8. Recursively call Nearest Neighbor with paramters
        // (nearer-kd, target, nearer-hr, max-dist-sqd), storing the
        // results in nearest and dist-sqd
        nnbr(nearer_kd, target, nearer_hr, max_dist_sqd, lev + 1, K, nnl);


//        KdTreeNode nearest = nnl.getHighest();
        double dist_sqd;


        if (!nnl.isCapacityReached())
        {
            dist_sqd = Double.MAX_VALUE;
        }
        else
        {
            dist_sqd = nnl.getMaxPriority();
        }


        // 9. max-dist-sqd := minimum of max-dist-sqd and dist-sqd
        max_dist_sqd = Math.min(max_dist_sqd, dist_sqd);


        // 10. A nearer point could only lie in further-kd if there were some
        // part of further-hr within distance max-dist-sqd of  target.
        HPoint closest = further_hr.closest(target);
        if (sqrdist(closest.coord, target.coord) < max_dist_sqd) 
        {
            // 10.1 if (pivot-target)^2 < dist-sqd then
            if (pivot_to_target < dist_sqd)
            {
                // 10.1.1 nearest := (pivot, range-elt field of kd)
//                nearest = kd;


                // 10.1.2 dist-sqd = (pivot-target)^2
                dist_sqd = pivot_to_target;


//                // add to nnl
//                if (!kd.deleted && ((checker == null) || checker.usable(kd.v)))
//                {
                    nnl.insert(kd, dist_sqd);
//                }


                // 10.1.3 max-dist-sqd = dist-sqd
                // max_dist_sqd = dist_sqd;
                if (nnl.isCapacityReached())
                {
                    max_dist_sqd = nnl.getMaxPriority();
                }
                else
                {
                    max_dist_sqd = Double.MAX_VALUE;
                }
            }


            // 10.2 Recursively call Nearest Neighbor with parameters
            // (further-kd, target, further-hr, max-dist_sqd),
            // storing results in temp-nearest and temp-dist-sqd
            nnbr(further_kd, target, further_hr, max_dist_sqd, lev + 1, K, nnl);
        }
    }


    private NearestNeighborList> getnbrs(double[] key, int n)
    {
        NearestNeighborList> nnl = new NearestNeighborList>(n);


        // initial call is with infinite hyper-rectangle and max distance
        HRect hr = HRect.infiniteHRect(key.length);
        double max_dist_sqd = Double.MAX_VALUE;
        HPoint keyp = new HPoint(key,null);  //XXX
        
        nnbr(rootNode, keyp, hr, max_dist_sqd, 0, key.length, nnl);


        return nnl;


    }
    
    /**
     * Find KD-tree node whose key is nearest neighbor to key.
     * 
     * @param key
     *            key for KD-tree node
     * @return object at node nearest to key, or null on failure
     * @throws KeySizeException
     *             if key.length mismatches K
     */
    public HPoint nearest(double[] key) 
    {


        List> nbrs = nearest(key, 1);
        return nbrs.get(0);
    }
    
    /**
     * Find KD-tree nodes whose keys are n nearest neighbors to key. Uses
     * algorithm above. Neighbors are returned in ascending order of distance to
     * key.
     * 
     * @param key
     *            key for KD-tree node
     * @param n
     *            how many neighbors to find
     * @param checker
     *            an optional object to filter matches
     * @return objects at node nearest to key, or null on failure
     * @throws KeySizeException
     *             if key.length mismatches K
     * @throws IllegalArgumentException
     *             if n is negative or exceeds tree size
     */
    public List> nearest(double[] key, int n)
    {


        if (n <= 0)
        {
            return new LinkedList>();
        }


        NearestNeighborList> nnl = getnbrs(key, n);


        n = nnl.getSize();
        Stack> nbrs = new Stack>();


        for (int i = 0; i < n; ++i)
        {
            KdTreeNode kd = nnl.removeHighest();
            nbrs.push(kd.point);
        }


        return nbrs;
    }
    
    
}


KdTreeNode.java:
public class KdTreeNode
{
    protected HPoint point;
    
//    protected T v;


    protected KdTreeNode left, right;
    
   public KdTreeNode(HPoint point, T val,KdTreeNode  left, KdTreeNode  right)
    {
//        this.v = val;
        this.left = left;
        this.right = right;
        this.point = point;
    }
    
   
   public int getDimension()
   {
       return point.coord.length;
   }
    


}


HPointComparator.java:
import java.util.Comparator;


/**
 * 功能描述:  


 * 
 *
 * @author : zhaogang.lv


 *
 * @version 2.2 Feb 29, 2012
 *
 * @since dp-searcher 2.2
 */
public class HPointComparator implements Comparator
{
    private int axis;
    
    public HPointComparator(int axis)
    {
        this.axis = axis;
    }
    
    @Override
    public int compare(HPoint o1, HPoint o2)
    {
        return Double.compare(o1.coord[axis],o2.coord[axis]);
    }
}


HPoint.java:
/**
 * A point in n-dimensional space with an optional attached value.
 *
 * @param <T> type of the attached value
 */
public class HPoint<T>
{
    /** Value attached to this point; may be null. */
    protected T v;

    /** Coordinates, one per dimension. Public and mutable. */
    public double[] coord;

    /** Creates a point with {@code n} zero coordinates and no value. */
    protected HPoint(int n)
    {
        coord = new double[n];
    }

    /** Creates a point from a defensive copy of {@code x}, with no attached value. */
    public HPoint(double[] x)
    {
        coord = x.clone();
    }

    /** Creates a point from a defensive copy of {@code x}, carrying value {@code v}. */
    protected HPoint(double[] x, T v)
    {
        this.v = v;
        coord = x.clone();
    }

    /** Returns a copy with its own coordinate array and the same attached value. */
    @Override
    protected Object clone()
    {
        return new HPoint<T>(coord, v);
    }

    /**
     * Coordinate-wise equality; the attached value is not compared.
     * Points of different dimension, or a null argument, are never equal. Previously a
     * shorter argument threw ArrayIndexOutOfBoundsException, and a longer argument with a
     * matching prefix compared equal.
     */
    protected boolean equals(HPoint<?> p)
    {
        if (p == null || p.coord.length != coord.length)
        {
            return false;
        }
        for (int i = 0; i < coord.length; ++i)
        {
            if (coord[i] != p.coord[i])
            {
                return false;
            }
        }
        return true;
    }

    /** Formats as each coordinate followed by a space, then {@code ":"} and the value. */
    @Override
    public String toString()
    {
        StringBuilder sb = new StringBuilder();
        for (double c : coord)
        {
            sb.append(c).append(' ');
        }
        return sb.append(':').append(v).toString();
    }
}


NearestNeighborList.java:
/*
 * Create Author  : zhaogang.lv
 * Create Date     : Feb 29, 2012
 * Project            : dp-searcher
 * File Name        : NearestNeighborList.java
 *
 * Copyright (c) 2010-2015 by Shanghai HanTao Information Co., Ltd.
 * All rights reserved.
 *
 */
package com.dp.arts.lucenex.spatial;


/**
 * Bounded collection that keeps the {@code capacity} entries with the smallest priority
 * (for example, squared distance) seen so far.
 *
 * <p>Backed by a {@link java.util.PriorityQueue} whose ordering is reversed. Its head is
 * therefore the entry with the LARGEST priority, i.e. the current worst candidate, which
 * is what gets evicted.
 *
 * @param <T> element type
 *
 * @author : zhaogang.lv
 * @version 2.2 Feb 29, 2012
 * @since dp-searcher 2.2
 */
public class NearestNeighborList<T>
{
    /**
     * Upper bound on the queue's initial array size. The queue grows on demand, so a large
     * requested capacity (e.g. Integer.MAX_VALUE) no longer pre-allocates a huge array.
     */
    private static final int MAX_INITIAL_CAPACITY = 1024;

    /** Element plus its priority, ordered in reverse of priority. */
    static class NeighborEntry<T> implements Comparable<NeighborEntry<T>>
    {
        final T data;

        final double value;

        public NeighborEntry(final T data, final double value)
        {
            this.data = data;
            this.value = value;
        }

        @Override
        public int compareTo(NeighborEntry<T> t)
        {
            // Reversed on purpose: the queue head becomes the largest value.
            return Double.compare(t.value, this.value);
        }
    }

    java.util.PriorityQueue<NeighborEntry<T>> m_Queue;

    int m_Capacity = 0;

    /**
     * @param capacity maximum number of entries kept; 0 yields a list that accepts nothing
     * @throws IllegalArgumentException if {@code capacity} is negative
     */
    public NearestNeighborList(int capacity)
    {
        if (capacity < 0)
        {
            throw new IllegalArgumentException("capacity must be >= 0: " + capacity);
        }
        m_Capacity = capacity;
        // PriorityQueue rejects initial capacities below 1; cap the initial allocation too.
        int initial = Math.max(1, Math.min(capacity, MAX_INITIAL_CAPACITY));
        m_Queue = new java.util.PriorityQueue<NeighborEntry<T>>(initial);
    }

    /** Returns the largest priority held, or positive infinity when empty. */
    public double getMaxPriority()
    {
        NeighborEntry<T> p = m_Queue.peek();
        return (p == null) ? Double.POSITIVE_INFINITY : p.value;
    }

    /**
     * Offers {@code object} with {@code priority}. When full, the entry is rejected if its
     * priority exceeds the current maximum; otherwise the current maximum is evicted.
     *
     * @return false if the entry was rejected, true otherwise
     */
    public boolean insert(T object, double priority)
    {
        if (m_Capacity <= 0)
        {
            // Nothing can be kept.
            return false;
        }
        if (isCapacityReached())
        {
            if (priority > getMaxPriority())
            {
                // Every kept entry has a lower priority; do not insert.
                return false;
            }
            m_Queue.add(new NeighborEntry<T>(object, priority));
            // Evict the entry with the highest priority.
            m_Queue.poll();
        }
        else
        {
            m_Queue.add(new NeighborEntry<T>(object, priority));
        }
        return true;
    }

    /** True when the list holds at least {@code capacity} entries. */
    public boolean isCapacityReached()
    {
        return m_Queue.size() >= m_Capacity;
    }

    /** Returns the element with the highest priority without removing it, or null if empty. */
    public T getHighest()
    {
        NeighborEntry<T> p = m_Queue.peek();
        return (p == null) ? null : p.data;
    }

    /** True when no entries are held. */
    public boolean isEmpty()
    {
        return m_Queue.isEmpty();
    }

    /** Number of entries held. */
    public int getSize()
    {
        return m_Queue.size();
    }

    /** Removes and returns the element with the highest priority, or null if empty. */
    public T removeHighest()
    {
        NeighborEntry<T> p = m_Queue.poll();
        return (p == null) ? null : p.data;
    }
}


如果要查找某个点周围5公里内的所有点,用kd树的性能会如何?需要注意两点。第一,上面的实现只提供了knn查询,没有实现范围查询。第二,它计算的是坐标间的欧氏距离,如果直接用经纬度作为坐标,距离单位是"度"而不是公里。
或者,如果只用kd树的knn算法按层级找到对应的geohash格子,性能是不是会更好?

你可能感兴趣的:(算法)