2012 Multi-University Training Contest 5 HDU4347 The Closest M Points
写个kd-tree模板...无它
UPD: 以前写的代码太难看啦,趁去南京赛区之前整理模板重写了一下...
允许转载,转载请注明出处:
UPD: 以前写的代码太难看啦,趁去南京赛区之前整理模板重写了一下...
1
//
hdu4347 KD Tree
2 // 询问k维空间中距某点最近的m个点
3
4 #include < cstdio >
5 #include < queue >
6 #include < algorithm >
7
8 using namespace std;
9
10 const int N = 200010 ;
11 const int DIM = 5 ;
12
13 inline double sqr( double x) { return x * x; }
14
15 int k, n; // k为维数, n为点数
16
17 struct Point {
18 int x[DIM];
19 double cald(Point o) {
20 double ret = 0 ;
21 for ( int i = 0 ; i < k; i ++ ) {
22 ret += sqr(x[i] - o.x[i]);
23 }
24 return ret;
25 }
26 void print() {
27 for ( int i = 0 ; i < k; i ++ ) {
28 printf( " %d " , x[i]);
29 i == k - 1 ? puts( "" ) : printf( " " );
30 }
31 }
32 }point[N];
33
34 struct Heap_t {
35 Point p; double dis;
36 Heap_t() {}
37 Heap_t(Point _p, double _dis) : p(_p), dis(_dis) {}
38 bool operator < ( const Heap_t & o) const {
39 return dis < o.dis;
40 }
41 };
42
43 priority_queue < Heap_t > q; // 维护前m大数
44
45 // 笨方法,用于对动态第dim维排序,通过修改全局变量dim来达到目的
46 int dim;
47 bool cmp(Point a, Point b) {
48 return a.x[dim] < b.x[dim];
49 }
50
51 struct KDTree {
52 struct Node {
53 Point p;
54 int size;
55 }t[N];
56 inline int LC( int x) { return x << 1 ; }
57 inline int RC( int x) { return x << 1 | 1 ; }
58 // 建树,采取wiki上的建树方法
59 void build(Point p[], int l, int r, int rt, int dep) {
60 if (l > r) return ;
61 t[rt].size = r - l;
62 t[LC(rt)].size = t[RC(rt)].size = - 1 ;
63 dim = dep % k;
64 int m = (l + r) >> 1 ;
65 nth_element(p + l, p + m, p + r + 1 , cmp);
66 t[rt].p = p[m];
67 build(p, l, m - 1 , LC(rt), dep + 1 );
68 build(p, m + 1 , r, RC(rt), dep + 1 );
69 }
70 void insert(Heap_t h) {
71 if (h.dis < q.top().dis) {
72 q.pop(); q.push(h);
73 }
74 }
75 // 询问前m近的点。
76 // 与最近邻相似,先一路搜到叶子,然后如果当前得到的点数<m时,要搜索所有的子树。
77 // 得到m个点之后就维护一个大小为m的堆,当前节点距离<堆顶元素距离时,将当前节点加入,堆顶元素弹出。
78 // 其余与最近邻询问相似。
79 void query(Point p, int rt, int dep, int m) {
80 if (t[rt].size == - 1 ) return ;
81 Heap_t h(t[rt].p, t[rt].p.cald(p));
82 int dim = dep % k;
83 if (p.x[dim] < t[rt].p.x[dim]) {
84 query(p, LC(rt), dep + 1 , m);
85 if (q.size() < m) {
86 q.push(h);
87 query(p, RC(rt), dep + 1 , m);
88 } else {
89 insert(h);
90 // 如果要查询的点与超平面的距离 < 堆顶元素的距离,则要到另一边超平面去查询
91 if (sqr(p.x[dim] - t[rt].p.x[dim]) < q.top().dis) {
92 query(p, RC(rt), dep + 1 , m);
93 }
94 }
95 } else {
96 query(p, RC(rt), dep + 1 , m);
97 if (q.size() < m) {
98 q.push(h);
99 query(p, LC(rt), dep + 1 , m);
100 } else {
101 insert(h);
102 if (sqr(p.x[dim] - t[rt].p.x[dim]) < q.top().dis) {
103 query(p, LC(rt), dep + 1 , m);
104 }
105 }
106 }
107 }
108 }kdt;
109
110 int main() {
111 while ( ~ scanf( " %d%d " , & n, & k)) {
112 for ( int i = 0 ; i < n; i ++ ) {
113 for ( int j = 0 ; j < k; j ++ ) {
114 scanf( " %d " , & point[i].x[j]);
115 }
116 }
117 kdt.build(point, 0 , n - 1 , 1 , 0 );
118 int t; scanf( " %d " , & t);
119 for ( int i = 0 ; i < t; i ++ ) {
120 Point ask;
121 for ( int j = 0 ; j < k; j ++ ) {
122 scanf( " %d " , & ask.x[j]);
123 }
124 int m; scanf( " %d " , & m);
125 kdt.query(ask, 1 , 0 , m);
126 Point p[ 10 ];
127 for ( int j = 0 ; ! q.empty(); j ++ ) {
128 p[j] = q.top().p; q.pop();
129 }
130 printf( " the closest %d points are:\n " , m);
131 for ( int j = m - 1 ; j >= 0 ; j -- ) {
132 p[j].print();
133 }
134 }
135 }
136 return 0 ;
137 }
2 // 询问k维空间中距某点最近的m个点
3
4 #include < cstdio >
5 #include < queue >
6 #include < algorithm >
7
8 using namespace std;
9
10 const int N = 200010 ;
11 const int DIM = 5 ;
12
13 inline double sqr( double x) { return x * x; }
14
15 int k, n; // k为维数, n为点数
16
17 struct Point {
18 int x[DIM];
19 double cald(Point o) {
20 double ret = 0 ;
21 for ( int i = 0 ; i < k; i ++ ) {
22 ret += sqr(x[i] - o.x[i]);
23 }
24 return ret;
25 }
26 void print() {
27 for ( int i = 0 ; i < k; i ++ ) {
28 printf( " %d " , x[i]);
29 i == k - 1 ? puts( "" ) : printf( " " );
30 }
31 }
32 }point[N];
33
34 struct Heap_t {
35 Point p; double dis;
36 Heap_t() {}
37 Heap_t(Point _p, double _dis) : p(_p), dis(_dis) {}
38 bool operator < ( const Heap_t & o) const {
39 return dis < o.dis;
40 }
41 };
42
43 priority_queue < Heap_t > q; // 维护前m大数
44
45 // 笨方法,用于对动态第dim维排序,通过修改全局变量dim来达到目的
46 int dim;
47 bool cmp(Point a, Point b) {
48 return a.x[dim] < b.x[dim];
49 }
50
51 struct KDTree {
52 struct Node {
53 Point p;
54 int size;
55 }t[N];
56 inline int LC( int x) { return x << 1 ; }
57 inline int RC( int x) { return x << 1 | 1 ; }
58 // 建树,采取wiki上的建树方法
59 void build(Point p[], int l, int r, int rt, int dep) {
60 if (l > r) return ;
61 t[rt].size = r - l;
62 t[LC(rt)].size = t[RC(rt)].size = - 1 ;
63 dim = dep % k;
64 int m = (l + r) >> 1 ;
65 nth_element(p + l, p + m, p + r + 1 , cmp);
66 t[rt].p = p[m];
67 build(p, l, m - 1 , LC(rt), dep + 1 );
68 build(p, m + 1 , r, RC(rt), dep + 1 );
69 }
70 void insert(Heap_t h) {
71 if (h.dis < q.top().dis) {
72 q.pop(); q.push(h);
73 }
74 }
75 // 询问前m近的点。
76 // 与最近邻相似,先一路搜到叶子,然后如果当前得到的点数<m时,要搜索所有的子树。
77 // 得到m个点之后就维护一个大小为m的堆,当前节点距离<堆顶元素距离时,将当前节点加入,堆顶元素弹出。
78 // 其余与最近邻询问相似。
79 void query(Point p, int rt, int dep, int m) {
80 if (t[rt].size == - 1 ) return ;
81 Heap_t h(t[rt].p, t[rt].p.cald(p));
82 int dim = dep % k;
83 if (p.x[dim] < t[rt].p.x[dim]) {
84 query(p, LC(rt), dep + 1 , m);
85 if (q.size() < m) {
86 q.push(h);
87 query(p, RC(rt), dep + 1 , m);
88 } else {
89 insert(h);
90 // 如果要查询的点与超平面的距离 < 堆顶元素的距离,则要到另一边超平面去查询
91 if (sqr(p.x[dim] - t[rt].p.x[dim]) < q.top().dis) {
92 query(p, RC(rt), dep + 1 , m);
93 }
94 }
95 } else {
96 query(p, RC(rt), dep + 1 , m);
97 if (q.size() < m) {
98 q.push(h);
99 query(p, LC(rt), dep + 1 , m);
100 } else {
101 insert(h);
102 if (sqr(p.x[dim] - t[rt].p.x[dim]) < q.top().dis) {
103 query(p, LC(rt), dep + 1 , m);
104 }
105 }
106 }
107 }
108 }kdt;
109
110 int main() {
111 while ( ~ scanf( " %d%d " , & n, & k)) {
112 for ( int i = 0 ; i < n; i ++ ) {
113 for ( int j = 0 ; j < k; j ++ ) {
114 scanf( " %d " , & point[i].x[j]);
115 }
116 }
117 kdt.build(point, 0 , n - 1 , 1 , 0 );
118 int t; scanf( " %d " , & t);
119 for ( int i = 0 ; i < t; i ++ ) {
120 Point ask;
121 for ( int j = 0 ; j < k; j ++ ) {
122 scanf( " %d " , & ask.x[j]);
123 }
124 int m; scanf( " %d " , & m);
125 kdt.query(ask, 1 , 0 , m);
126 Point p[ 10 ];
127 for ( int j = 0 ; ! q.empty(); j ++ ) {
128 p[j] = q.top().p; q.pop();
129 }
130 printf( " the closest %d points are:\n " , m);
131 for ( int j = m - 1 ; j >= 0 ; j -- ) {
132 p[j].print();
133 }
134 }
135 }
136 return 0 ;
137 }
允许转载,转载请注明出处:
http://www.blogjava.net/lkjslkjdlk/archive/2012/08/13/385426.html