题意理解
单位正方形内有一组点,需要支持两种查询:找出落在给定矩形内(含边界)的所有点(范围查找),以及找出离给定点 p 最近的点(最近邻查找)。这里 kd 树的维度 k=2。
问题分析
kd 树是一种按坐标轴交替划分空间的二叉搜索树,统计学习中常用它来加速 k 近邻搜索。
暴力思路(PointSET):遍历所有点,逐个判断是否在矩形内;对最近邻查询,计算每个点到 p 的距离,取距离最小的点。两种查询的时间复杂度均为 O(n)。
kd 树思路(KdTree):具体细节不展开。通过空间划分剪掉不可能包含答案的子区域,对分布较均匀的点,插入和查询的平均复杂度可降到对数级别(最坏情况仍可能退化为线性)。
最终得分:90
其他
kd树实现参考了https://github.com/mingyueanyao/algorithms-princeton-coursera/blob/master/Codes%20of%20Programming%20Assignments/part1/pa5-kdtree/KdTree.java
链接
/**
 * Brute-force 2-D point set built on an algs4 {@code SET}.
 * Range and nearest-neighbour queries scan every stored point, so each query is linear
 * in the number of points. Serves as a reference implementation for {@code KdTree}.
 */
public class PointSET {
    private final SET<Point2D> points;

    /** Constructs an empty set of points. */
    public PointSET() {
        points = new SET<>();
    }

    /** Returns {@code true} if the set contains no points. */
    public boolean isEmpty() {
        return points.isEmpty();
    }

    /** Returns the number of points in the set. */
    public int size() {
        return points.size();
    }

    /**
     * Adds the point to the set if it is not already present.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        // SET.add already ignores keys that are present, so no contains() pre-check is needed.
        points.add(p);
    }

    /**
     * Returns whether the set contains point {@code p}.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        return points.contains(p);
    }

    /** Draws all points to standard draw. */
    public void draw() {
        for (Point2D p : points) {
            p.draw();
        }
        StdDraw.show();
    }

    /**
     * Returns all points inside {@code rect}, including points on its boundary.
     *
     * @throws IllegalArgumentException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }
        Stack<Point2D> inside = new Stack<>();
        for (Point2D p : points) {
            if (rect.contains(p)) {
                inside.push(p);
            }
        }
        return inside;
    }

    /**
     * Returns a point in the set nearest to {@code p}, or {@code null} if the set is empty.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        Point2D best = null;
        // Squared distances order points the same way as distances and avoid a sqrt per point.
        double bestDist = Double.POSITIVE_INFINITY;
        for (Point2D candidate : points) {
            double d = p.distanceSquaredTo(candidate);
            if (d < bestDist) {
                best = candidate;
                bestDist = d;
            }
        }
        return best;
    }

    public static void main(String[] args) {
    }
}
kd树实现:
/**
 * 2-d tree of points. Tree levels alternate between splitting on x (a vertical line,
 * starting at the root) and splitting on y (a horizontal line). A point whose splitting
 * coordinate equals the node's goes into the right subtree.
 *
 * <p>The root region is hard-coded to the unit square, so points are assumed to lie in
 * [0, 1] x [0, 1].
 */
public class KdTree {
    /** Region covered by the root; every node's region is a sub-rectangle of it. */
    private static final RectHV UNIT_SQUARE = new RectHV(0.0, 0.0, 1.0, 1.0);

    private Node root;
    private int size;

    private static class Node {
        private final Point2D p;
        private final boolean vertical; // true: splits on x (vertical line); false: splits on y
        private Node left;              // strictly left of / below the splitting line
        private Node right;             // right of / above the line, including ties on the split coordinate

        Node(Point2D p, boolean vertical) {
            this.p = p;
            this.vertical = vertical;
        }
    }

    /** Constructs an empty set of points. */
    public KdTree() {
        root = null;
        size = 0;
    }

    /** Returns {@code true} if the set contains no points. */
    public boolean isEmpty() {
        return root == null;
    }

    /** Returns the number of distinct points in the set. */
    public int size() {
        return size;
    }

    /**
     * Adds the point to the set if it is not already present.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public void insert(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        root = insert(root, p, true); // the root splits vertically
    }

    /** Inserts {@code p} below {@code h}; a new leaf gets orientation {@code vertical}. */
    private Node insert(Node h, Point2D p, boolean vertical) {
        if (h == null) {
            size++;
            return new Node(p, vertical);
        }
        int cmp = compare(p, h);
        if (cmp < 0) {
            h.left = insert(h.left, p, !h.vertical);
        } else if (cmp > 0) {
            h.right = insert(h.right, p, !h.vertical);
        }
        // cmp == 0: p is already stored at h; nothing to do.
        return h;
    }

    /**
     * Returns whether the set contains point {@code p}.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public boolean contains(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        Node h = root;
        while (h != null) {
            int cmp = compare(p, h);
            if (cmp == 0) {
                return true;
            }
            h = cmp < 0 ? h.left : h.right;
        }
        return false;
    }

    /**
     * Decides where {@code p} belongs relative to node {@code h}: negative for the left
     * subtree, positive for the right subtree, 0 when {@code p} equals the node's point.
     * Ties on the splitting coordinate (with a different other coordinate) go right.
     */
    private static int compare(Point2D p, Node h) {
        double key = h.vertical ? p.x() : p.y();
        double split = h.vertical ? h.p.x() : h.p.y();
        if (key < split) {
            return -1;
        }
        if (key > split) {
            return 1;
        }
        boolean samePoint = p.x() == h.p.x() && p.y() == h.p.y();
        return samePoint ? 0 : 1;
    }

    /** Returns the part of {@code region} covered by {@code h}'s left subtree. */
    private static RectHV leftRect(Node h, RectHV region) {
        return h.vertical
                ? new RectHV(region.xmin(), region.ymin(), h.p.x(), region.ymax())
                : new RectHV(region.xmin(), region.ymin(), region.xmax(), h.p.y());
    }

    /** Returns the part of {@code region} covered by {@code h}'s right subtree. */
    private static RectHV rightRect(Node h, RectHV region) {
        return h.vertical
                ? new RectHV(h.p.x(), region.ymin(), region.xmax(), region.ymax())
                : new RectHV(region.xmin(), h.p.y(), region.xmax(), region.ymax());
    }

    /** Draws all points (black), vertical splits (red) and horizontal splits (blue). */
    public void draw() {
        draw(root, UNIT_SQUARE);
    }

    private void draw(Node h, RectHV region) {
        if (h == null) {
            return;
        }
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        h.p.draw();
        StdDraw.setPenRadius();
        if (h.vertical) {
            StdDraw.setPenColor(StdDraw.RED);
            StdDraw.line(h.p.x(), region.ymin(), h.p.x(), region.ymax());
        } else {
            StdDraw.setPenColor(StdDraw.BLUE);
            StdDraw.line(region.xmin(), h.p.y(), region.xmax(), h.p.y());
        }
        draw(h.right, rightRect(h, region));
        draw(h.left, leftRect(h, region));
    }

    /**
     * Returns all points inside {@code rect}, including points on its boundary.
     *
     * @throws IllegalArgumentException if {@code rect} is null
     */
    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) {
            throw new IllegalArgumentException();
        }
        Stack<Point2D> found = new Stack<>();
        range(root, UNIT_SQUARE, rect, found);
        return found;
    }

    /** Collects points of the subtree at {@code h} (covering {@code region}) that lie in {@code query}. */
    private void range(Node h, RectHV region, RectHV query, Stack<Point2D> found) {
        if (h == null || !region.intersects(query)) {
            return; // empty subtree, or its region cannot contain any query point
        }
        if (query.contains(h.p)) {
            found.push(h.p);
        }
        range(h.right, rightRect(h, region), query, found);
        range(h.left, leftRect(h, region), query, found);
    }

    /**
     * Returns a point in the set nearest to {@code p}, or {@code null} if the set is empty.
     *
     * @throws IllegalArgumentException if {@code p} is null
     */
    public Point2D nearest(Point2D p) {
        if (p == null) {
            throw new IllegalArgumentException();
        }
        if (isEmpty()) {
            return null;
        }
        return nearest(root, UNIT_SQUARE, p, root.p);
    }

    /**
     * Returns the point closest to {@code query} among {@code best} and the subtree at
     * {@code h} (whose points lie in {@code region}). {@code best} is replaced only by
     * strictly closer points.
     */
    private Point2D nearest(Node h, RectHV region, Point2D query, Point2D best) {
        if (h == null) {
            return best;
        }
        double bestDist = query.distanceSquaredTo(best);
        // No point in this region can be strictly closer than the current best: prune.
        if (region.distanceSquaredTo(query) >= bestDist) {
            return best;
        }
        if (query.distanceSquaredTo(h.p) < bestDist) {
            best = h.p;
        }
        RectHV leftRegion = leftRect(h, region);
        RectHV rightRegion = rightRect(h, region);
        // Search the side containing the query first; it tends to shrink bestDist fastest,
        // which lets the other side be pruned more often.
        boolean queryOnRight = h.vertical ? query.x() >= h.p.x() : query.y() >= h.p.y();
        if (queryOnRight) {
            best = nearest(h.right, rightRegion, query, best);
            best = nearest(h.left, leftRegion, query, best);
        } else {
            best = nearest(h.left, leftRegion, query, best);
            best = nearest(h.right, rightRegion, query, best);
        }
        return best;
    }

    public static void main(String[] args) {
    }
}