CV笔记(一)——KDTree算法的Java实现

版权声明:本文为原创文章,未经博主允许不得用于商业用途。
kd-tree算法的原理可参考知乎上的相关文章(原链接缺失)。这里用Java实现了二维kd树,主要代码如下:

/**
 * A two-dimensional kd-tree over {@code dPoint}s supporting k-nearest-neighbour queries.
 *
 * <p>Levels alternate the splitting axis, starting with x at the top level. {@code ROOT} is a
 * sentinel node: the real tree hangs off {@code ROOT.leftChild}, which is left unset when the tree
 * is built from an empty list.
 *
 * <p>Queries keep their working state in instance fields, so a single instance must not run
 * concurrent {@link #KNN(int, dPoint)} calls.
 */
class KDTree{

    /** Sentinel node; the actual tree is {@code ROOT.leftChild}. */
    protected KDNode ROOT;
    /** Current nearest-neighbour nodes of the running query (unordered). */
    protected ArrayList<KDNode> nn;
    /** Squared distances from the query point, parallel to {@link #nn}. */
    protected double[] nnsdist;
    /** Number of neighbours requested by the running query. */
    protected int k;

    /**
     * Builds a balanced kd-tree from the given points.
     *
     * @param pointlist points to index; the list itself is not reordered (a copy is sorted), and
     *     the tree stores new {@code dPoint}s holding the same coordinates
     */
    public KDTree(ArrayList<dPoint> pointlist){
        ROOT = new KDNode();
        // BuildTree sorts sub-ranges in place; work on a copy so the caller's list is untouched.
        ArrayList<dPoint> points = new ArrayList<>(pointlist);
        BuildTree(points, ROOT, 0, points.size(), true, true);
    }

    /** Prints every stored point as {@code "x,y"}, one per line, in pre-order; empty tree prints nothing. */
    public void DFS()
    {
        DFS(ROOT.leftChild);
    }

    /**
     * Returns the {@code k} stored points closest (Euclidean distance) to {@code center}.
     *
     * <p>The returned points are copies and are in no particular distance order. If the tree
     * holds fewer than {@code k} points, all of them are returned.
     *
     * @param k number of neighbours to find; 0 yields an empty list
     * @param center query point
     * @return up to {@code k} nearest points
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public ArrayList<dPoint> KNN(int k, dPoint center){
        if (k < 0) {
            throw new IllegalArgumentException("k must be non-negative: " + k);
        }
        this.k = k;
        this.nn = new ArrayList<>();
        this.nnsdist = new double[k];
        KNN(center, ROOT.leftChild);
        ArrayList<dPoint> knn = new ArrayList<>(nn.size());
        for (KDNode x : nn) {
            knn.add(x.dividePoint.Copy());
        }
        return knn;
    }

    /**
     * Searches the subtree rooted at {@code root}, updating {@link #nn} and {@link #nnsdist} in
     * place with any node closer than the current k-th neighbour.
     *
     * @param center query point
     * @param root subtree to search; {@code null} is a no-op
     */
    protected void KNN(dPoint center, KDNode root)
    {
        if (root == null || k <= 0) {
            return;
        }
        double dx = center.x - root.dividePoint.x;
        double dy = center.y - root.dividePoint.y;
        // Descend first into the half-space that contains the query point.
        boolean leftpart = root.xaxis ? center.x < root.dividePoint.x : center.y < root.dividePoint.y;
        KNN(center, leftpart ? root.leftChild : root.rightChild);

        double cdist = dx * dx + dy * dy;  // squared Euclidean distance
        if (nn.size() < k) {
            nnsdist[nn.size()] = cdist;
            nn.add(root);
        } else {
            // If nearer, replace the current farthest neighbour with this node.
            int maxidx = farthestIndex();
            if (nnsdist[maxidx] > cdist) {
                nn.set(maxidx, root);
                nnsdist[maxidx] = cdist;
            }
        }
        if (nn.size() == k) {
            // The far side can only contain a closer point if the splitting line is nearer than the
            // current k-th neighbour. Both sides of the comparison are squared distances, and the
            // farthest neighbour is re-scanned because the replacement above may have changed it.
            double axisDist = root.xaxis ? dx : dy;
            if (nnsdist[farthestIndex()] < axisDist * axisDist) {
                return;
            }
        }
        KNN(center, leftpart ? root.rightChild : root.leftChild);
    }

    /** Returns the index in {@link #nnsdist} of the farthest neighbour held; requires {@code nn} non-empty. */
    private int farthestIndex()
    {
        int maxidx = 0;
        for (int i = 1; i < nn.size(); i++) {
            if (nnsdist[i] > nnsdist[maxidx]) {
                maxidx = i;
            }
        }
        return maxidx;
    }

    /**
     * Pre-order traversal printing each node's point as {@code "x,y"}.
     *
     * @param root subtree to print; {@code null} prints nothing
     */
    protected void DFS(KDNode root)
    {
        if (root == null) {
            return;
        }
        System.out.println(root.dividePoint.x + "," + root.dividePoint.y);
        DFS(root.leftChild);
        DFS(root.rightChild);
    }

    /**
     * Recursively builds the subtree for the range {@code [start, end)} of {@code pointlist} and
     * attaches it to {@code root}.
     *
     * <p>The range is sorted in place by the current axis. The element at the middle index becomes
     * the splitting node (upper median for even-sized ranges). Lower indices form the left subtree
     * and higher indices the right subtree, with the axis alternating at each level.
     *
     * @param pointlist working list; its {@code [start, end)} range is reordered
     * @param root parent to attach the new node to
     * @param start first index of the range (inclusive)
     * @param end last index of the range (exclusive)
     * @param xaxis whether this level splits on x (otherwise on y)
     * @param left attach as {@code root.leftChild} if true, otherwise as {@code root.rightChild}
     */
    protected void BuildTree(ArrayList<dPoint> pointlist,KDNode root , int start, int end, boolean xaxis, boolean left){
        if (start >= end) {
            return;
        }
        int mid = start;
        if (start + 1 < end) {
            // subList is a view, so sorting it reorders that range of pointlist itself.
            if (xaxis) {
                pointlist.subList(start, end).sort((dPoint p1, dPoint p2) -> Double.compare(p1.x, p2.x));
            } else {
                pointlist.subList(start, end).sort((dPoint p1, dPoint p2) -> Double.compare(p1.y, p2.y));
            }
            mid = (start + end) >>> 1;  // overflow-safe midpoint
        }
        KDNode node = new KDNode();
        node.dividePoint = new dPoint(pointlist.get(mid).x, pointlist.get(mid).y);
        node.xaxis = xaxis;
        if (left) {
            root.leftChild = node;
        } else {
            root.rightChild = node;
        }
        BuildTree(pointlist, node, start, mid, !xaxis, true);
        BuildTree(pointlist, node, mid + 1, end, !xaxis, false);
    }
}

可视化效果:
[图1:KDTree可视化效果(图片缺失)]

完整代码(包括此处未列出的KDNode和dPoint类)见GitHub(原链接缺失)。

你可能感兴趣的:(算法原理)