【java】KDTree,实现个java版本,留着日后可能用得上

Java版本KDTree


在KDTree中,只有当 n >> 2^xn 时,才明显地有 KDTCount << n,其中:

n 是点的个数,

xn 是点的维数,

KDTCount 是在KDTree搜索时计算距离的次数统计。


package main;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;

/**
 * Nearest-neighbour search with a k-d tree, checked against a brute-force linear scan.
 *
 * <p>The tree only prunes well when n >> 2^xn (n = number of points, xn = dimension).
 * Otherwise {@link #KDTCount} approaches n and the search is no faster than brute force.
 */
public class KDTreeMain {
	/** Number of point-to-point distances computed by {@link #query} (statistics only). */
	public static int KDTCount = 0;

	/**
	 * Builds a k-d tree over random points in [0, 1)^xn.
	 *
	 * <p>It then finds the distance from a random target point to its nearest neighbour and
	 * checks the result against a brute-force scan.
	 *
	 * <p>Optional arguments: {@code args[0]} = number of points n (default 50000),
	 * {@code args[1]} = dimension xn (default 10).
	 *
	 * <p>Sample output recorded by the original author (times in ms), presumably with the
	 * LinkedList-based version of this class:
	 * <pre>
	 * n = 40000, xn = 10: build 30760, kdt query  1404 (KDTCount =  4488), brute 3317, best = 0.3296984744447501
	 * n = 50000, xn = 10: build 49664, kdt query   558 (KDTCount =  2056), brute 5557, best = 0.3355435846472523
	 * n = 50000, xn = 20: build 63560, kdt query 15136 (KDTCount = 37500), brute 5791, best = 0.8319764077450744
	 * </pre>
	 * In every sample the k-d tree and brute-force results agree.
	 *
	 * @throws IllegalArgumentException if n is negative or xn is smaller than 1
	 * @throws NumberFormatException if an argument is not an integer
	 */
	public static void main(String[] args) {
		int n = args.length > 0 ? Integer.parseInt(args[0]) : 50000;	// number of sample points
		int xn = args.length > 1 ? Integer.parseInt(args[1]) : 10;	// dimension of each point
		if (n < 0 || xn < 1) {
			throw new IllegalArgumentException("expected n >= 0 and xn >= 1, got n = " + n + ", xn = " + xn);
		}
		int deep = 0;	// splitting axis of the root

		// Random training points. Use ArrayList rather than LinkedList because insert() relies on
		// indexed get(), which costs O(n) per call on a LinkedList.
		List<Point> pointList = new ArrayList<Point>(n);
		for (int i = 0; i < n; i++) {
			pointList.add(new Point(randomVector(xn)));
		}

		// Build the tree.
		System.out.println("begin insert...");
		long t1 = System.currentTimeMillis();
		KDTreeMain kdt = new KDTreeMain();
		Node root = new Node();
		kdt.insert(root, pointList, deep);
		long t2 = System.currentTimeMillis();
		System.out.println("build kdt time = " + (t2 - t1));
		// To print the tree (small n only): showKDTree(root, new char[32], 0);

		// Random target point.
		Point p = new Point(randomVector(xn));

		// k-d tree search.
		long t3 = System.currentTimeMillis();
		double best = query(root, p, Double.MAX_VALUE, deep);
		long t4 = System.currentTimeMillis();
		System.out.println("\nquery kdt time = " + (t4 - t3));
		System.out.println("best  = " + best);
		System.out.println("KDTCount = " + KDTCount);

		// Brute force, as a reference for the tree search.
		long t5 = System.currentTimeMillis();
		double best2 = Double.MAX_VALUE;
		for (Point candidate : pointList) {
			double dist = getDist(p, candidate);
			if (dist < best2) {
				best2 = dist;
			}
		}
		long t6 = System.currentTimeMillis();
		System.out.println("\nquery brute time = " + (t6 - t5));
		System.out.println("best2 = " + best2);
	}

	// Returns a vector of the given length whose components come from Math.random(), i.e. lie in [0, 1).
	private static double[] randomVector(int length) {
		double[] v = new double[length];
		for (int j = 0; j < length; j++) {
			v[j] = Math.random();
		}
		return v;
	}

	/**
	 * Recursively builds the k-d (sub)tree rooted at {@code root} from {@code pointList}.
	 *
	 * <p>The points are sorted by coordinate {@code deep}. Every point whose coordinate equals
	 * the median's is stored in {@code root.pointList}. Points with a strictly smaller
	 * coordinate go to the left subtree and points with a strictly larger one go to the right.
	 * Both subtrees split on the next axis. The caller's list is not modified.
	 * An empty list leaves {@code root} as an empty leaf.
	 *
	 * @param root node to fill; its {@code l}/{@code r} children are created as needed
	 * @param pointList points to store; all must have the same dimension
	 * @param deep splitting axis for {@code root}
	 */
	private void insert(Node root, List<Point> pointList, int deep) {
		if (pointList.isEmpty()) {
			return;
		}
		// Sort a private copy so every node owns an independent list instead of a chain of
		// subList views, and the caller's list keeps its order.
		List<Point> sorted = new ArrayList<Point>(pointList);
		Collections.sort(sorted, byAxis(deep));

		int size = sorted.size();
		int mid = size / 2;
		double split = sorted.get(mid).x[deep];
		// Widen [lo, hi) to the whole run of points sharing the median coordinate, so the
		// left part is strictly below the split value and the right part strictly above it.
		int lo = mid;
		while (lo > 0 && sorted.get(lo - 1).x[deep] == split) {
			lo--;
		}
		int hi = mid + 1;
		while (hi < size && sorted.get(hi).x[deep] == split) {
			hi++;
		}

		root.pointList = new ArrayList<Point>(sorted.subList(lo, hi));
		int nextAxis = (deep + 1) % sorted.get(0).x.length;
		if (lo > 0) {
			root.l = new Node();
			insert(root.l, sorted.subList(0, lo), nextAxis);
		}
		if (hi < size) {
			root.r = new Node();
			insert(root.r, sorted.subList(hi, size), nextAxis);
		}
	}

	// Orders points by a single coordinate. This replaces the shared static Point.deep that
	// Point's natural ordering depends on.
	private static Comparator<Point> byAxis(final int axis) {
		return new Comparator<Point>() {
			@Override
			public int compare(Point a, Point b) {
				return Double.compare(a.x[axis], b.x[axis]);
			}
		};
	}

	/**
	 * Searches the subtree rooted at {@code root} for the point nearest to {@code p}.
	 *
	 * <p>Returns {@code min(best, getDist(q, p))} over every point q in the subtree.
	 * Far-side subtrees are skipped when the splitting plane is farther from {@code p} than
	 * the current best, since none of their points can be closer. Each distance computed
	 * increments {@link #KDTCount}.
	 *
	 * @param root subtree to search; may be null
	 * @param p query point
	 * @param best best distance found so far, used for pruning
	 * @param deep splitting axis of {@code root}; must match the axis insert() used at this depth
	 * @return the smaller of {@code best} and the nearest distance found in the subtree
	 */
	private static double query(Node root, Point p, double best, int deep) {
		if (root == null) {
			// Keep the caller's candidate. Returning Double.MAX_VALUE here would discard it
			// and weaken pruning further down.
			return best;
		}
		// Points stored at this node. Checking them before descending tightens pruning early.
		for (Point candidate : root.pointList) {
			KDTCount++;
			double dist = getDist(candidate, p);
			if (dist < best) {
				best = dist;
			}
		}
		if (root.l == null && root.r == null) {
			return best;
		}

		// insert() gives every internal node a non-empty list sharing the split coordinate.
		double split = root.pointList.get(0).x[deep];
		int nextAxis = (deep + 1) % p.x.length;
		boolean nearIsLeft = p.x[deep] <= split;
		// Search the side of the splitting plane that contains p first...
		best = query(nearIsLeft ? root.l : root.r, p, best, nextAxis);
		// ...then the other side, only if the plane is within the current best distance.
		if (best >= Math.abs(p.x[deep] - split)) {
			best = query(nearIsLeft ? root.r : root.l, p, best, nextAxis);
		}
		return best;
	}

	/**
	 * Debug helper that prints the subtree, one node per line.
	 *
	 * <p>Each line has the form {@code "<depth># <turns> <points>"}, where turns are the L/R
	 * steps taken from the root. Only the first {@code pi} entries of {@code path} are
	 * meaningful; the buffer is enlarged automatically when the tree is deeper than it.
	 */
	private static void showKDTree(Node root, char[] path, int pi) {
		if (root == null) {
			return;
		}
		System.out.print(pi + "# ");
		for (int i = 0; i < pi; i++) {
			System.out.print(path[i] + " ");
		}
		showList(root.pointList);
		if (pi >= path.length) {
			char[] grown = new char[2 * path.length + 1];
			System.arraycopy(path, 0, grown, 0, path.length);
			path = grown;
		}
		path[pi] = 'L';
		showKDTree(root.l, path, pi + 1);
		path[pi] = 'R';
		showKDTree(root.r, path, pi + 1);
	}

	/**
	 * Euclidean distance between {@code p1} and {@code p2}. {@code p2} needs at least as many
	 * coordinates as {@code p1}.
	 *
	 * <p>Identical points yield {@code Double.MAX_VALUE} instead of 0. NOTE(review): presumably
	 * this is so a point is never reported as its own nearest neighbour; confirm. Both the
	 * tree search and the brute-force check share this rule.
	 */
	private static double getDist(Point p1, Point p2) {
		double sum = 0;
		for (int i = 0; i < p1.x.length; i++) {
			double diff = p1.x[i] - p2.x[i];
			sum += diff * diff;
		}
		if (sum == 0) {
			return Double.MAX_VALUE;
		}
		return Math.sqrt(sum);
	}

	/**
	 * Quickselect (O(n) on average) on coordinate {@code deep} over {@code pointList[l..r]}.
	 *
	 * <p>Afterwards, index {@code l + k} holds the value it would hold if the range were
	 * sorted. Nothing before it is larger and nothing after it is smaller.
	 *
	 * <p>Currently unused: insert() sorts instead. Note that it moves the coordinate arrays
	 * ({@code x}) between Point objects rather than moving the Point objects. It uses the
	 * first element as pivot, so recursion depth can reach O(n) on already sorted input.
	 */
	private void getMedian(List<Point> pointList, int l, int r, int k, int deep) {
		if (l == r && k == 0) return;
		int pl = l;
		int pr = r;
		double[] tmp = pointList.get(l).x;	// pivot; index l becomes the first "hole"
		while (pl < pr) {
			while (pl < pr && pointList.get(pr).x[deep] > tmp[deep]) pr--;
			if (pl >= pr) break;
			pointList.get(pl++).x = pointList.get(pr).x;	// fill the left hole, pr becomes the hole
			while (pl < pr && pointList.get(pl).x[deep] < tmp[deep]) pl++;
			if (pl >= pr) break;
			pointList.get(pr--).x = pointList.get(pl).x;	// fill the right hole, pl becomes the hole
		}
		pointList.get(pl).x = tmp;	// pivot lands in its final position

		if (pl - l == k) return;
		if (pl - l > k) {
			getMedian(pointList, l, pl - 1, k, deep);
		} else {
			getMedian(pointList, pl + 1, r, k - (pl - l + 1), deep);
		}
	}

	// Prints every point of the list as "c1,c2,...," followed by " / ", then a newline.
	private static void showList(List<Point> pointList) {
		for (Point pt : pointList) {
			for (double c : pt.x) {
				System.out.print(c + ",");
			}
			System.out.print(" / ");
		}
		System.out.println();
	}
}
// A k-d tree node: the points stored at this node plus its two subtrees.
class Node {
	// Points held by this node. insert() puts here the points whose coordinate on the
	// node's axis equals the median value.
	List<Point> pointList;
	// Subtree with coordinates below this node's split value, or null if there is none.
	Node l;
	// Subtree with coordinates above this node's split value, or null if there is none.
	Node r;

	Node() {
		this.pointList = new LinkedList<Point>();
		this.l = null;
		this.r = null;
	}
}
// 数据点
class Point implements Comparable<Point>{
	public static int deep = 0;
	double[] x;
	public Point(double[] d) {
		x = new double[d.length];
		for (int i = 0; i < d.length; i++) {
			x[i] = d[i];
		}
	}
	public int compareTo(Point o) {
		// return (int)(this.x[deep] == other.x[deep]); 出错,因为x的值在0~1之间,那么int都是0了
		Point other = (Point)o;
		if (this.x[deep] == other.x[deep]) return 0;
		if (this.x[deep] >  other.x[deep]) return 1;
		return -1;
	}
}


你可能感兴趣的:(【java】KDTree,实现个java版本,留着日后可能用得上)