题意:首先给出n(n<=50000)个点,每个点最多k(k<=10)维。然后是T组查询,每次查询给定一个点,求距离这个点最近的m个点。按照距离从小到大输出。
思路:KD树,讲解可以参考统计学习教材K近邻那章。我的代码是参照(http://blog.csdn.net/wxfwxf328/article/details/8158187)写的。
首先是建树,难点不多,用到了nth_element函数,正好符合我们的所需:将中间点固定,小于这个点的弄到左边去,大于的弄到右边去。这里的一个技巧是point结构体的小于号重构,里面居然可以用一个全局变量当做变量。
查询的时候一次将整个树遍历一遍(中间有大量剪枝)。
很多博客写的在k维kd树上查询最邻近点的最坏情况复杂度为O(k * N^(1-1/k)).
#include <cstdio> #include <queue> #include <algorithm> #include <cstring> using namespace std; #define INF 0x3fffffff #define clr(s,t) memset(s,t,sizeof(s)); #define N 50005 #define Q 10005 int k,n,m,T,idx,son[N<<2]; struct point{ int s[7]; bool operator<(const point &b)const{ return s[idx]<b.s[idx]; } }p[N],base,kdt[N<<2]; struct node{ point pp; double dis; bool operator<(const node &b)const{ return dis < b.dis; } }res[12]; priority_queue<struct node> q; double dist(point a,point b){ double res = 0; for(int i = 0;i<k;i++) res += (a.s[i]-b.s[i])*(a.s[i]-b.s[i]); return res; } void build(int r,int a,int b,int d){ if(a>b) return; idx = d%k; //此次按照第几个维度来进行二分 int mid = (a+b)>>1; son[r] = b; son[r*2] = son[r*2+1] = -1; //如果两个儿子都没有结点,当前结点当然就是叶节点 nth_element(p+a, p+mid, p+b+1); kdt[r] = p[mid]; build(r*2, a, mid-1, d+1); build(r*2+1, mid+1, b, d+1); } void query(int r,int d){ int x,y,id = d%k,flag = 1; if(son[r] == -1) return; node tmp; //表示当前子树根结点 tmp.pp = kdt[r]; tmp.dis = dist(tmp.pp,base);//根节点到待查结点的距离 x = r*2; y = r*2+1; if(tmp.pp.s[id] < base.s[id])//先去待查结点所在的那一边去查找(为了下文方便,永远是先找x) swap(x,y); query(x, d+1); if(q.size()<m)//如果现在找到的结点还不到m个,那么将这个根节点加进去啦 q.push(tmp); else{ if(q.top().dis > tmp.dis){//根节点到base的距离比目前找到的最大距离要小,那么根结点加进去 q.pop(); q.push(tmp); } if((tmp.pp.s[id]-base.s[id])*(tmp.pp.s[id]-base.s[id]) > q.top().dis)//这一步是剪枝,相当于教材里描述的画那个圆 flag = 0; } if(flag) query(y, d+1); } int main(){ while(scanf("%d %d",&n,&k)!=EOF){ int i,j; idx = 0; for(i = 1;i<=n;i++) for(j = 0;j<k;j++) scanf("%d",&p[i].s[j]); build(1,1,n,0); scanf("%d",&T); while(T--){ for(j = 0;j<k;j++) scanf("%d",&base.s[j]); scanf("%d",&m); query(1,0); for(i = 1;i<=m;i++){ res[i] = q.top(); q.pop(); } printf("the closest %d points are:\n",m); for(i = m;i>=1;i--){ for(j = 0;j<k-1;j++) printf("%d ",res[i].pp.s[j]); printf("%d\n",res[i].pp.s[k-1]); } } } return 0; }