【KD-TREE介绍】在SYC1999大神的“蛊惑”下,我开始接触这种算法。
首先,大概的概念可以去百度百科。具体实现,我是看RZZ的代码长大的。
我们可以想象在平面上有N个点。首先,按横坐标排序找到最中间的那个点。然后水平划一条线,把平面分成左右两个部分。再递归调用左右两块。注意,在第二次(偶数次)调用的时候,是找到纵坐标中最中间的点,并垂直画一条线。
这样效率看上去很好。维护的时候有点像线段树。每个点记录它的坐标、它辖管的区间4个方向的极值、它的左右(或上下)的两个点的标号。递归两个子树时,注意要up更新这个点辖管的范围。
inline int cmp(arr a,arr b){return a.d[D]<b.d[D]||a.d[D]==b.d[D]&&a.d[D^1]<b.d[D^1];} inline void up(int k,int s) { a[k].min[0]=min(a[k].min[0],a[s].min[0]); a[k].max[0]=max(a[k].max[0],a[s].max[0]); a[k].min[1]=min(a[k].min[1],a[s].min[1]); a[k].max[1]=max(a[k].max[1],a[s].max[1]); } int build(int l,int r,int dd) { D=dd;int mid=(l+r)>>1; nth_element(a+l+1,a+mid+1,a+r+1,cmp); a[mid].min[0]=a[mid].max[0]=a[mid].d[0]; a[mid].min[1]=a[mid].max[1]=a[mid].d[1]; if (l!=mid) a[mid].l=build(l,mid-1,dd^1); if (mid!=r) a[mid].r=build(mid+1,r,dd^1); if (a[mid].l) up(mid,a[mid].l); if (a[mid].r) up(mid,a[mid].r); return mid; }
上述代码很好理解。
然后先在我要支持加入点,也是类似于线段树的思想:
void insert(int k) { int p=root;D=0; while (orzSYC) { up(p,k); if (a[k].d[D]<=a[p].d[D]){if (!a[p].l) {a[p].l=k;return;} p=a[p].l;} else {if (!a[p].r) {a[p].r=k;return;} p=a[p].r;} D^=1; } }
为什么我忽然觉得是splay的insert操作?就是每次往某个点的左或右(或者上或下)过去。
比如我们要查询与(x,y)最近的点(曼哈顿距离)与其的距离。
int getdis(int k) { int res=0; if (x<a[k].min[0]) res+=a[k].min[0]-x; if (x>a[k].max[0]) res+=x-a[k].max[0]; if (y<a[k].min[1]) res+=a[k].min[1]-y; if (y>a[k].max[1]) res+=y-a[k].max[1]; return res; } void ask(int k) { int d0=abs(a[k].d[0]-x)+abs(a[k].d[1]-y); if (d0<ans) ans=d0; int dl=(a[k].l)?getdis(a[k].l):INF; int dr=(a[k].r)?getdis(a[k].r):INF; if (dl<dr){if (dl<ans) ask(a[k].l);if (dr<ans) ask(a[k].r);} else {if (dr<ans) ask(a[k].r);if (dl<ans) ask(a[k].l);} }
getdis有点像Astar中的“估价函数”。计算(x,y)与当前点范围的差距有多少,然后按顺序遍历左二子和右儿子。这样,如果更新到最优值,就能及时退出。这种算法在随机数据上是lg的,但是在构造数据上约是sqrt的。
【BZOJ2716&2648】双倍经验。就是裸的K-D TREE模板套套。无压力1A~。
【BZOJ3053】哎,说多了都是泪。这道题调了不知道多少时间。首先,它拓展到了K维空间上。这样,cmp就只需判断某一位的大小就行了。然后要查询前m优值。因为m<=10,我为了效率,直接一遍做,开了一个数组记录最优值。然后判断最优值的时候裸O(n)(均摊)的更新答案。
对于那个估计函数也要稍稍改一下(因为是欧几里得距离),怎么方便怎么来!(反正只会影响到效率)
调了半天后,总算小数据对拍没有问题了~~浪交!T了。。。
后来我估计在更新答案时速度太慢,于是一咬牙,把10个最优解开成了队列......
然后大数据对拍~~什么,秒WA?这下真的调了一个下午(因为我是刚学的),后来发现:RZZ的博客里的nth过程用错了。比如从l到r,中间是mid(默认数组下标从1开始),应该是a+l,a+mid,a+r+1。最后一个+1因为是虚指针。但是前面都不用+1的(上面的代码已经修改过了)!!!!
最后又是RE。果断要数据!——发现只有一个测试点。我先写了个程序,把测试点拆成了好几个。然后一测:全过!和在一起:RE!原来,l和r要及时清零!!!呵呵,多么痛的领悟!
【截取程序】
var ss:string; cnt,n,m,a,i,j:longint; begin assign(input,'T.in'); reset(input); while (not(eof)) do begin inc(cnt); str(cnt,ss); ss:='T'+ss+'.in'; assign(output,ss); rewrite(output); readln(n,m); writeln(n,' ',m); for i:=1 to n do begin for j:=1 to m do begin read(a); write(a,' '); end; writeln; end; readln(n); writeln(n); for i:=1 to n do begin for j:=1 to m do begin read(a); write(a,' '); end; writeln; read(a); writeln(a); end; close(output); end; end.
【对拍造数据】
#include<cstdio> #include<cstdlib> #include<ctime> using namespace std; int main() { freopen("T.in","w",stdout); srand((int)time(0)); int n=50000,m=4,i,j; printf("%d %d\n",n,m); for (i=1;i<=n;i++) { for (j=1;j<=m;j++) printf("%d ",rand()%10000+1); printf("\n"); } int Q=1000; printf("%d\n",Q); while (Q--) { for (i=1;i<=m;i++) printf("%d ",rand()%10000+1); printf("%d\n",rand()%5+1); } return 0; }
【AC代码】
#include<cstdio> #include<algorithm> #include<queue> #define N 50005 #define INF 21390627567143.0 using namespace std; const int orzSYC=1; struct arr { int d[5],max[15],min[15],l,r,id; arr() {l=0;r=0;id=0;} }a[N*4],aa[N]; struct pop { double x;int id; friend bool operator < (const pop &a,const pop &b){return a.x<b.x;} }; priority_queue<pop>ans; int n,m,Q,i,j,t,x[15],D,temp[21],root,opt,P,flag; inline int Read() { int x=0;char ch=getchar();bool positive=1; for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') positive=0; for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0'; return positive?x:-x; } inline int cmp(arr a,arr b) { return a.d[D]<b.d[D]; } inline void up(int k,int s) { for (int i=0;i<m;i++) { a[k].min[i]=min(a[k].min[i],a[s].min[i]); a[k].max[i]=max(a[k].max[i],a[s].max[i]); } } int build(int l,int r,int dd) { D=dd;int mid=((l+r)>>1); nth_element(aa+l,aa+mid,aa+r+1,cmp); for (int i=0;i<m;i++) a[mid].min[i]=a[mid].max[i]=a[mid].d[i]=aa[mid].d[i]; a[mid].id=mid; if (l<mid) a[mid].l=build(l,mid-1,(dd+1)%m);else a[mid].l=0; if (mid<r) a[mid].r=build(mid+1,r,(dd+1)%m);else a[mid].r=0; if (a[mid].l) up(mid,a[mid].l); if (a[mid].r) up(mid,a[mid].r); return mid; } inline double sdis(int k) { double res=0; for (i=0;i<m;i++) { res+=(a[k].d[i]-x[i])*(a[k].d[i]-x[i]); } return res; } void ask(int k,int deep) { int L=a[k].l,R=a[k].r; if (x[deep]>=a[k].d[deep]) swap(L,R); double now=sdis(k); if (L) ask(L,(deep+1)%m); int flag=0; if (ans.size()<P) {ans.push((pop){now,k});flag=1;} else { if (now<ans.top().x) ans.pop(),ans.push((pop){now,k}); if ((x[deep]-a[k].d[deep])*1.*(x[deep]-a[k].d[deep])<ans.top().x) flag=1; } if (flag&&R) ask(R,(deep+1)%m); } int main() { while (scanf("%d%d",&n,&m)!=EOF) { for (i=1;i<=n;i++) for (j=0;j<m;j++) aa[i].d[j]=Read(); root=build(1,n,0); Q=Read(); while (Q--) { for (i=0;i<m;i++) x[i]=Read(); P=Read(); ask(root,0);int wri=0; printf("the closest %d points are:\n",P); while (!ans.empty()) { temp[++wri]=ans.top().id; ans.pop(); } for (i=wri;i;i--) { for (j=0;j<m-1;j++) printf("%d ",a[temp[i]].d[j]); printf("%d\n",a[temp[i]].d[m-1]); } } } return 0;