感谢orz神·小黑的指导
kd-tree就是用来计算若干维空间k近/远点的数(shou)据(suo)结(you)构(hua)
建树
假设题目是k维的点
第deep层就是用deep%k+1维把所有点分为两块
取deep%k+1维中位数的点做为当前子树的根节点
再把该维比这个点小的点扔到左子树 比这个点大的扔到右子树 递归处理
详见代码
1 void Sort(ll l,ll r,ll k){ cmpp=k; sort(kd+l,kd+r+1,cmp); } 2 ll build(ll l,ll r,ll deep){ 3 if (l==r){ 4 kd[l].lc=kd[r].rc=0; 5 return l; 6 } 7 Sort(l,r,deep); 8 ll mid=(l+r)/2; 9 if (l<mid) kd[mid].lc=build(l,mid-1,deep%k+1); 10 else kd[mid].lc=0; 11 if (mid<r) kd[mid].rc=build(mid+1,r,deep%k+1); 12 else kd[mid].rc=0; 13 return mid; 14 }
查询
询问离点S的前m近点 说它是搜索优化就是因为这里- -
维护大根堆记录答案 当元素个数小于m时直接push
反正判断有木有 比最大值小 有就pop再push
当搜索当t点是先用该点到S的距离维护堆
再判断如果S的deep%k+1维 比t点该维小就先搜索左子树 否则搜索右子树
搜索完一颗子树后 判断如果S到t点deep%k+1维的距离就≥ans显然继续搜索没用 就不继续搜索 否则搜索另一颗子树
代码
1 void push(ll t){ 2 ll dis=getdis(S,poi[t]); 3 if (size==m){ 4 if (dis>que.top().dis) return; 5 else{ 6 que.pop(); 7 que.push(info(dis,t)); 8 } 9 }else{ 10 ++size; 11 que.push(info(dis,t)); 12 } 13 } 14 void makeans(ll t,ll deep){ 15 if (!t) return; 16 push(kd[t].t); 17 if (S.d[deep]<=kd[t].p.d[deep]){ 18 makeans(kd[t].lc,deep%k+1); 19 if (size<m || que.top().dis>sqr(S.d[deep]-kd[t].p.d[deep])) makeans(kd[t].rc,deep%k+1); 20 }else{ 21 makeans(kd[t].rc,deep%k+1); 22 if (size<m || que.top().dis>sqr(S.d[deep]-kd[t].p.d[deep])) makeans(kd[t].lc,deep%k+1); 23 } 24 }
最远点
这里讲的都是m近点- -
如果是m远点其实是差不多的 只是维护的东西不太一样
需要维护每维的min和max
询问的时候基本同理yy下即可
求k位距离S的m近点代码
1 #include <cstdio> 2 #include <algorithm> 3 #include <queue> 4 typedef long long ll; 5 using namespace std; 6 const ll N=50001; 7 struct inpo{ 8 ll d[6]; 9 }poi[N],S,ans[11]; 10 struct inkd{ 11 ll t,lc,rc; 12 inpo p; 13 inkd(const ll a=0,const ll b=0,const ll c=0): 14 t(a),lc(b),rc(c){} 15 }kd[N]; 16 struct info{ 17 ll dis,t; 18 info(const ll a=0,const ll b=0): 19 dis(a),t(b){} 20 }; 21 priority_queue <info> que; 22 ll root,n,k,m,t,cmpp,size; 23 inline bool operator <(info a,info b){ return a.dis<b.dis; } 24 inline bool cmp(inkd a,inkd b){ return a.p.d[cmpp]<b.p.d[cmpp]; } 25 void Sort(ll l,ll r,ll k){ cmpp=k; sort(kd+l,kd+r+1,cmp); } 26 ll sqr(ll x){ return x*x; } 27 ll getdis(inpo a,inpo b){ 28 ll res=0; 29 for (ll i=1;i<=k;i++) res+=sqr(a.d[i]-b.d[i]); 30 return res; 31 } 32 ll build(ll l,ll r,ll deep){ 33 if (l==r){ 34 kd[l].lc=kd[r].rc=0; 35 return l; 36 } 37 Sort(l,r,deep); 38 ll mid=(l+r)/2; 39 if (l<mid) kd[mid].lc=build(l,mid-1,deep%k+1); 40 else kd[mid].lc=0; 41 if (mid<r) kd[mid].rc=build(mid+1,r,deep%k+1); 42 else kd[mid].rc=0; 43 return mid; 44 } 45 void push(ll t){ 46 ll dis=getdis(S,poi[t]); 47 if (size==m){ 48 if (dis>que.top().dis) return; 49 else{ 50 que.pop(); 51 que.push(info(dis,t)); 52 } 53 }else{ 54 ++size; 55 que.push(info(dis,t)); 56 } 57 } 58 void makeans(ll t,ll deep){ 59 if (!t) return; 60 push(kd[t].t); 61 if (S.d[deep]<=kd[t].p.d[deep]){ 62 makeans(kd[t].lc,deep%k+1); 63 if (size<m || que.top().dis>sqr(S.d[deep]-kd[t].p.d[deep])) makeans(kd[t].rc,deep%k+1); 64 }else{ 65 makeans(kd[t].rc,deep%k+1); 66 if (size<m || que.top().dis>sqr(S.d[deep]-kd[t].p.d[deep])) makeans(kd[t].lc,deep%k+1); 67 } 68 } 69 int main(){ 70 freopen("hdu4347.in","r",stdin); 71 freopen("hdu4347.out","w",stdout); 72 while (~scanf("%I64d%I64d",&n,&k)){ 73 for (ll i=1;i<=n;i++){ 74 for (ll j=1;j<=k;j++) scanf("%I64d",&poi[i].d[j]); 75 kd[i].t=i,kd[i].p=poi[i]; 76 } 77 root=build(1,n,1); 78 scanf("%I64d",&t); 79 for (;t;t--){ 80 size=0; 81 for (ll i=1;i<=k;i++) scanf("%I64d",&S.d[i]); 82 scanf("%I64d",&m); 83 printf("the closest %I64d points are:\n",m); 84 makeans(root,1); 85 for (ll i=1;i<=m;i++){ 86 ans[i]=poi[que.top().t]; 87 que.pop(); 88 } 89 for (ll i=m;i;i--){ 90 for (ll j=1;j<=k;j++){ 91 printf("%I64d",ans[i].d[j]); 92 if (j<k) printf(" "); 93 } 94 puts(""); 95 } 96 } 97 } 98 fclose(stdin); 99 fclose(stdout); 100 }