经过这几天研究kd-tree,我可以说kd-tree就是按照基本的思路随便写就可以了吗?
以二维平面为例,在二维平面上有若干点,我们如何建立kd-tree?
第一层以x坐标的中位数将所有点分为两部分,分别放到左右子树上,第二层以y坐标的为中位数再将当前的集合平分,第三层依然以x坐标……如此递归下去。
以查询最近点为例,我们在当前节点上找到查询点应被分在哪一个子树上,优点遍历这个子树,如果在此子树中找到的最有答案小于这个点到分割线的距离,就不需要遍历另一个子树了,据说期望复杂度为O(n^(D-1)/D),然而我并不知道为什么
三个例题写了三个风格完全不同的代码。。。
1、hdu4347
查找k维空间中距离某个点最近的m个点,距离定义为欧几里得距离距离。
kd-tree的基础题,随便写写就可以了。。。。
#include <queue> #include <cstdio> #include <vector> #include <cstring> #include <algorithm> using namespace std; const int Maxn = 100005; int n,K,Q,m,D,i,root; int son[Maxn][2]; struct Point { int x[5]; void read() { for (int i=0;i<K;i++) scanf("%d",&x[i]); } void print() { for (int i=0;i<K-1;i++) printf("%d ",x[i]); printf("%d\n",x[K-1]); } bool operator <(const Point &a)const { return x[D] < a.x[D]; } } dot[Maxn], P; typedef pair<int,int> PR; #define di first #define id second priority_queue <PR> heap; vector <int> ans; int build(int l,int r,int now){ if (l>r) return 0; D = now; int mid = (l+r)>>1; nth_element(dot+l,dot+mid,dot+r+1); //L[mid] = l; R[mid] = r; son[mid][0] = build(l,mid-1,(now+1)%K); son[mid][1] = build(mid+1,r,(now+1)%K); return mid; } #define sqr(x) ((x)*(x)) int Distance(Point P1,Point P2){ int ret = 0; for (int i=0;i<K;i++) ret += sqr(P1.x[i]-P2.x[i]); return ret; } void query(int cur,int now){ if (cur==0) return; PR nd(Distance(dot[cur],P), cur); int x = son[cur][0], y = son[cur][1]; if (dot[cur].x[now]<P.x[now]) swap(x,y); query(x,(now+1)%K); if (heap.size()<m) heap.push(nd); else { if (nd.di<heap.top().di) heap.pop(), heap.push(nd); } if ( sqr(dot[cur].x[now]-P.x[now]) <= heap.top().di ) query(y,(now+1)%K); } int main(){ freopen("hdu4347.in","r",stdin); freopen("hdu4347.out","w",stdout); while (~scanf("%d%d",&n,&K)){ for (i=1;i<=n;i++) dot[i].read(); root = build(1,n,0); scanf("%d",&Q); while (Q--){ P.read(); scanf("%d",&m); printf("the closest %d points are:\n",m); query(root, 0); while (!heap.empty()) { ans.push_back(heap.top().id); heap.pop(); } while (!ans.empty()) { dot[ans.back()].print(); ans.pop_back(); } } } return 0; }
求曼哈顿距离最近点。
看的hzwer的程序,那个calc判断查询点距离那个部分近还是比较精妙的
但是这题暴力插入不要重建就可以过还是比较良心
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int Maxn = 1000005; const int INF = 1e9; int n,m,D,ans,root; int son[Maxn][2]; struct Point { int x[2], mn[2], mx[2]; void read() { scanf("%d%d",&x[0],&x[1]); } bool operator <(const Point &a)const { return x[D]<a.x[D]; } } dot[Maxn], P; int Dis(Point P0,Point P1) { return abs(P0.x[0]-P1.x[0]) + abs(P0.x[1]-P1.x[1]); } void update(int p) { Point L = dot[son[p][0]]; Point R = dot[son[p][1]]; for (int i=0;i<2;i++){ if (son[p][0]) { dot[p].mn[i]=min(dot[p].mn[i], L.mn[i]); dot[p].mx[i]=max(dot[p].mx[i], L.mx[i]); } if (son[p][1]) { dot[p].mn[i]=min(dot[p].mn[i], R.mn[i]); dot[p].mx[i]=max(dot[p].mx[i], R.mx[i]); } } } int build(int l,int r,int now) { if (l>r) return 0; D = now; int mid = (l+r)>>1; nth_element(dot+l,dot+mid,dot+r+1); for (int i=0;i<2;i++) dot[mid].mn[i] = dot[mid].mx[i] = dot[mid].x[i]; son[mid][0] = build(l,mid-1,now^1); son[mid][1] = build(mid+1,r,now^1); update(mid); return mid; } void insert(int p,int now) { if (dot[p].x[now]>P.x[now]){ if (son[p][0]) insert(son[p][0],now^1); else { son[p][0]=++n; for (int i=0;i<2;i++) dot[n].x[i] = dot[n].mn[i] = dot[n].mx[i] = P.x[i]; } } else { if (son[p][1]) insert(son[p][1],now^1); else { son[p][1]=++n; for (int i=0;i<2;i++) dot[n].x[i] = dot[n].mn[i] = dot[n].mx[i] = P.x[i]; } } update(p); } int calc(Point Q) { int ret=0; for (int i=0;i<2;i++) { ret+=max(0,Q.mn[i]-P.x[i]); ret+=max(0,P.x[i]-Q.mx[i]); } return ret; } void query(int p,int now) { int dl=INF, dr=INF; ans=min(ans, Dis(dot[p], P)); if (son[p][0]) dl=calc(dot[son[p][0]]); if (son[p][1]) dr=calc(dot[son[p][1]]); if (dl<=dr) { if (ans>dl) query(son[p][0],now^1); if (ans>dr) query(son[p][1],now^1); } else { if (ans>dr) query(son[p][1],now^1); if (ans>dl) query(son[p][0],now^1); } } int main(){ freopen("2648.in","r",stdin); freopen("2648.out","w",stdout); scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) dot[i].read(); root = build(1,n,0); while (m--){ int type; scanf("%d",&type); P.read(); if (type==1) insert(root,0); else { ans = INF; query(root, 0); printf("%d\n",ans); } } return 0; }
出看题目:这不是cdq分治?
然后你会注意到空间只有20M
如果你是卡空间狂魔仍然认为这题可以用cdq水过,你就会发现这题强制在线。。。。
正解是kd-tree!?
对于棋盘上点权的问题可以用kd-tree直接在线维护
不过这个需要重建kd-tree以防止过于整棵树不平衡,可以设一个比例fac,当某棵小子树的大小大于这棵子树的大小的fac倍时,重建整棵子树。
听说有人不重建水过了?
询问不讲大家都懂的。。。。
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int Maxn = 200005; int x1,x2,y1,y2,D; int n,N,st,ans,rt; int son[Maxn][2]; int pos[Maxn]; struct Point { int x[2], mn[2], mx[2], w, sum, size; void read() { scanf("%d%d%d",&x[0],&x[1],&w); x[0]^=ans; x[1]^=ans; w^=ans; for (int i=0;i<2;i++) mn[i] = mx[i] = x[i]; sum = w; size = 1; } bool operator <(const Point &a)const { return x[D] < a.x[D]; } } dot[Maxn]; const double fac = 0.65; bool cmp(const int &a,const int &b) { return dot[a].x[D] < dot[b].x[D]; } void update(int p){ int L = son[p][0], R = son[p][1]; dot[p].sum=dot[p].w; dot[p].size=1; dot[p].sum += dot[L].sum; dot[p].size += dot[L].size; dot[p].sum += dot[R].sum; dot[p].size += dot[R].size; for (int i=0;i<2;i++){ dot[p].mn[i] = dot[p].mx[i] = dot[p].x[i]; if (L) { dot[p].mn[i] = min(dot[p].mn[i], dot[L].mn[i]); dot[p].mx[i] = max(dot[p].mx[i], dot[L].mx[i]); } if (R) { dot[p].mn[i] = min(dot[p].mn[i], dot[R].mn[i]); dot[p].mx[i] = max(dot[p].mx[i], dot[R].mx[i]); } } } int build(int l,int r,int now){ if (l>r) return 0; D = now; int mid = (l+r)>>1; nth_element(pos+l,pos+mid,pos+r+1,cmp); son[pos[mid]][0] = build(l,mid-1,now^1); son[pos[mid]][1] = build(mid+1,r,now^1); update(pos[mid]); return pos[mid]; } void dfs(int p){ pos[++st] = p; if (son[p][0]) dfs(son[p][0]); if (son[p][1]) dfs(son[p][1]); } int ins(int p,int now){ if (!p) return N; int t = (dot[p].x[now]<=dot[N].x[now]); if (dot[son[p][t]].size+1 > (dot[p].size+1)*fac) { pos[st=1]=N; dfs(p); p = build(1,st,now); } else son[p][t] = ins(son[p][t],now^1), update(p); return p; } int query(int p,int now){ if (p==0) return 0; if (dot[p].mx[0]<x1||dot[p].mn[0]>x2) return 0; if (dot[p].mx[1]<y1||dot[p].mn[1]>y2) return 0; if (dot[p].mn[0]>=x1&&dot[p].mx[0]<=x2&&dot[p].mn[1]>=y1&&dot[p].mx[1]<=y2) return dot[p].sum; int ret = 0; if (dot[p].x[0]>=x1&&dot[p].x[0]<=x2&&dot[p].x[1]>=y1&&dot[p].x[1]<=y2) ret += dot[p].w; if (son[p][0]) ret += query(son[p][0],now^1); if (son[p][1]) ret += query(son[p][1],now^1); return ret; } int main(){ freopen("4066.in","r",stdin); freopen("4066.out","w",stdout); scanf("%d",&n); ans = 0; int type; while (~scanf("%d",&type)){ if (type==1){ dot[++N].read(); rt = ins(rt, 0); } else if (type==2){ scanf("%d%d",&x1,&y1); x1 ^= ans; y1 ^= ans; scanf("%d%d",&x2,&y2); x2 ^= ans; y2 ^= ans; ans = query(rt, 0); printf("%d\n",ans); } else break; } return 0; }