算法课程-kd树-作业

题意理解

给定单位正方形内的一组二维点,要求支持两种查询:找出某个查询矩形范围内(含边界)的所有点(range search),以及找出离给定点p最近的点(nearest neighbor)。这里kd树的维度k=2。

问题分析

kd树是一种空间划分数据结构,常用于实现统计学习中的k近邻(k-NN)搜索。

直观的思路(暴力法,即下面的PointSET):遍历所有点,判断每个点是否在矩形中;计算每个点与p之间的距离,距离最短的点即为所求。每次查询的时间都是O(n)。

kd树思路:具体细节不展开。核心是按x、y坐标交替划分空间,查询时剪掉与查询矩形不相交、或不可能包含更近点的子树。在点分布较均匀时,平均查询效率可降到对数级别(最坏情况下仍可能退化为线性)。

最终得分:90

其他

kd树实现参考了https://github.com/mingyueanyao/algorithms-princeton-coursera/blob/master/Codes%20of%20Programming%20Assignments/part1/pa5-kdtree/KdTree.java

链接

/**
 * Brute-force set of 2D points. Every query scans all stored points, so
 * {@code range} and {@code nearest} take time linear in {@link #size()}.
 */
public class PointSET {
    // The type argument is required: iterating a raw SET as Point2D does not compile.
    private final SET<Point2D> pointSet;

    /** Constructs an empty set of points. */
    public PointSET() {
        pointSet = new SET<>();
    }

    /** Returns true if the set contains no points. */
    public boolean isEmpty() {
        return pointSet.isEmpty();
    }

    /** Returns the number of points in the set. */
    public int size() {
        return pointSet.size();
    }

    /**
     * Adds {@code p} to the set if it is not already present.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        if (!pointSet.contains(p)) {
            pointSet.add(p);
        }
    }

    /**
     * Returns true if the set contains {@code p}.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        return pointSet.contains(p);
    }

    /** Draws every point to standard draw, then calls {@code StdDraw.show()}. */
    public void draw() {
        for (Point2D p : pointSet) {
            p.draw();
        }
        StdDraw.show();
    }

    /**
     * Returns all points inside {@code rect} (boundary included, as decided by
     * {@code RectHV.contains}).
     *
     * @throws IllegalArgumentException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }
        Stack<Point2D> inside = new Stack<>();
        for (Point2D p : pointSet) {
            if (rect.contains(p)) {
                inside.push(p);
            }
        }
        return inside;
    }

    /**
     * Returns a point in the set closest to {@code p}, or null if the set is empty.
     * On ties, the first closest point in iteration order is kept.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        Point2D best = null;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (Point2D candidate : pointSet) {
            // Squared distance gives the same ordering as distance and avoids sqrt.
            double d = p.distanceSquaredTo(candidate);
            if (d < bestDistance) {
                best = candidate;
                bestDistance = d;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // intentionally empty
    }
}

kd树实现:

/**
 * 2d-tree symbol set of points. Tree levels alternate between splitting on x
 * (a vertical line through the node's point) and splitting on y (a horizontal
 * line). A point whose splitting coordinate equals the node's goes to the right
 * (or top) subtree. The root's region is the unit square [0,1] x [0,1].
 * NOTE(review): child regions are built with RectHV, so inserted points are
 * assumed to lie inside the unit square.
 */
public class KdTree {

    private Node root;
    private int size;

    private static class Node {
        private final Point2D p;
        private final boolean vertical; // true: splits on x (vertical line); false: splits on y
        private Node left;              // strictly left of / below the splitting line
        private Node right;             // on or right of / above the splitting line

        Node(Point2D p, boolean vertical) {
            this.p = p;
            this.vertical = vertical;
        }
    }

    /** Constructs an empty set of points. */
    public KdTree() {
        root = null;
        size = 0;
    }

    /** Returns true if the set contains no points. */
    public boolean isEmpty() {
        return root == null;
    }

    /** Returns the number of distinct points in the set. */
    public int size() {
        return size;
    }

    /**
     * Adds {@code p} to the set if it is not already present.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        root = insert(root, p, true); // the root splits on x
    }

    /** Inserts p below h; a new node gets orientation {@code vertical}. */
    private Node insert(Node h, Point2D p, boolean vertical) {
        if (h == null) {
            size++;
            return new Node(p, vertical);
        }
        if (samePoint(h.p, p)) {
            return h; // duplicate: set semantics, nothing to do
        }
        if (isLeft(h, p)) {
            h.left = insert(h.left, p, !h.vertical);
        } else {
            h.right = insert(h.right, p, !h.vertical);
        }
        return h;
    }

    /**
     * Returns true if the set contains {@code p}.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        Node h = root;
        while (h != null) {
            if (samePoint(h.p, p)) {
                return true;
            }
            // Same routing rule as insert, so a stored point is always reachable.
            h = isLeft(h, p) ? h.left : h.right;
        }
        return false;
    }

    /**
     * Draws every point (black) and every splitting line: red for vertical
     * splits, blue for horizontal splits.
     */
    public void draw() {
        draw(root, 0.0, 0.0, 1.0, 1.0);
    }

    private void draw(Node h, double xmin, double ymin, double xmax, double ymax) {
        if (h == null) {
            return;
        }
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        h.p.draw();

        StdDraw.setPenRadius();
        if (h.vertical) {
            StdDraw.setPenColor(StdDraw.RED);
            new RectHV(h.p.x(), ymin, h.p.x(), ymax).draw(); // degenerate rect = vertical line
            draw(h.right, h.p.x(), ymin, xmax, ymax);
            draw(h.left, xmin, ymin, h.p.x(), ymax);
        } else {
            StdDraw.setPenColor(StdDraw.BLUE);
            new RectHV(xmin, h.p.y(), xmax, h.p.y()).draw(); // degenerate rect = horizontal line
            draw(h.right, xmin, h.p.y(), xmax, ymax);
            draw(h.left, xmin, ymin, xmax, h.p.y());
        }
    }

    /**
     * Returns all points inside {@code rect} (boundary included, as decided by
     * {@code RectHV.contains}).
     *
     * @throws IllegalArgumentException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }
        Stack<Point2D> found = new Stack<>();
        range(root, new RectHV(0.0, 0.0, 1.0, 1.0), rect, found);
        return found;
    }

    /** Collects points under h (whose region is {@code region}) that lie in {@code query}. */
    private void range(Node h, RectHV region, RectHV query, Stack<Point2D> found) {
        if (h == null || !region.intersects(query)) {
            return; // empty subtree, or its region cannot contain any query point
        }
        if (query.contains(h.p)) {
            found.push(h.p);
        }
        range(h.right, rightRegion(h, region), query, found);
        range(h.left, leftRegion(h, region), query, found);
    }

    /**
     * Returns a point in the set closest to {@code p}, or null if the set is empty.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        if (isEmpty()) {
            return null;
        }
        return nearest(root, new RectHV(0.0, 0.0, 1.0, 1.0), p, root.p);
    }

    /**
     * Returns the closer of {@code champion} and the closest point under h to
     * {@code query}. The subtree on the query's side is searched first. The
     * other side is searched only if its region could hold a strictly closer point.
     */
    private Point2D nearest(Node h, RectHV region, Point2D query, Point2D champion) {
        if (h == null) {
            return champion;
        }
        if (query.distanceSquaredTo(h.p) < query.distanceSquaredTo(champion)) {
            champion = h.p;
        }

        boolean queryOnRight = h.vertical ? query.x() >= h.p.x() : query.y() >= h.p.y();
        Node near = queryOnRight ? h.right : h.left;
        Node far = queryOnRight ? h.left : h.right;
        RectHV nearRegion = queryOnRight ? rightRegion(h, region) : leftRegion(h, region);
        RectHV farRegion = queryOnRight ? leftRegion(h, region) : rightRegion(h, region);

        champion = nearest(near, nearRegion, query, champion);
        if (farRegion.distanceSquaredTo(query) < query.distanceSquaredTo(champion)) {
            champion = nearest(far, farRegion, query, champion);
        }
        return champion;
    }

    /** True if both coordinates of a and b are equal. */
    private static boolean samePoint(Point2D a, Point2D b) {
        return a.x() == b.x() && a.y() == b.y();
    }

    /** True if p belongs in h's left (or bottom) subtree, i.e. strictly before h's split. */
    private static boolean isLeft(Node h, Point2D p) {
        return h.vertical ? p.x() < h.p.x() : p.y() < h.p.y();
    }

    /** Region of h's left/bottom child: the part of {@code region} up to h's splitting line. */
    private static RectHV leftRegion(Node h, RectHV region) {
        return h.vertical
                ? new RectHV(region.xmin(), region.ymin(), h.p.x(), region.ymax())
                : new RectHV(region.xmin(), region.ymin(), region.xmax(), h.p.y());
    }

    /** Region of h's right/top child: the part of {@code region} from h's splitting line on. */
    private static RectHV rightRegion(Node h, RectHV region) {
        return h.vertical
                ? new RectHV(h.p.x(), region.ymin(), region.xmax(), region.ymax())
                : new RectHV(region.xmin(), h.p.y(), region.xmax(), region.ymax());
    }

    public static void main(String[] args) {
        // intentionally empty
    }
}

 

你可能感兴趣的:(算法)