题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1007
题目大意:在平面上有N个点,求出两点之间距离的最小值/2,就是结果.
算法详细介绍:http://blog.csdn.net/guyulongcs/article/details/6841550,这里讲得很清楚。
也就是一个很裸的算法题吧,要求用O(nlogn)的算法求出最近点对,翻了翻算法导论,看完了上面用分治法解答,自己实现的时候有几个点需要注意:
1.如果每次递归求解的时候,开出新的数组来保存,集合Sy的划分,那么一定会MLE,看了别人的代码后,发现可以用先分解,再归并,那么只需要增加一个辅助数组,解决了MLE的问题。
2.关于如何分解原问题,开始用的是算法导论上说的,以Sx的中间点的横坐标来划分,分解成<=Sx[mid].x的和>=Sx[mid].x的,这样只用一个辅助数组的方法不适用,因为可能有很多点的横坐标相同,再想到分解成<=Sx[mid].x的和>Sx[mid].x的,这样也行不通,因为可能导致>Sx[mid].x的部分没有点,在某算法模板上看到,可以按照元素在最开始的Sx中的顺序,为每个元素增加一个域index来记录该元素在Sx中处于什么位置,这样就可以把问题分解成<=Sx[mid].index的和>Sx[mid].index的,没有问题。
3.总结最近点对的整体算法框架:
(1)预处理,用两个集合Sx,Sy保存所有的点,不同的是Sx按照x升序排列,Sy按照y升序排列。
(2)递归求解:边界条件为,当前待处理点集的点的个数<=3
递归框架为,分解Sx,分解Sy,并保证Sy的两个子集按y的升序排列,递归求解,用归并将Sy还原到没分解以前,合并子问题。
4.关于合并子问题的正确性证明还有待研究。
AC代码:
#include<iostream> #include<cstring> #include<cstdio> #include<cmath> #include<algorithm> using namespace std; const int MAXN = 100010; const double inf = 10e100; #define SQL(x) (x)*(x) struct Point { double x,y; int index; }sx[MAXN],sy[MAXN],st[MAXN]; bool x_cmp(const Point &p1,const Point &p2) { return p1.x<p2.x; } bool y_cmp(const Point &p1,const Point &p2) { return p1.y<p2.y; } double dis(Point p1,Point p2) { return sqrt(SQL(p1.x-p2.x)+SQL(p1.y-p2.y)); } double merge(Point sy[],Point st[],int l,int r,double dist,double L) { int Len = 0; for(int i=l;i<=r;i++) { if(sy[i].x>(L-dist)&&sy[i].x<(L+dist))st[Len++]=sy[i]; } for(int i=0;i<Len;i++) { for(int j=i+1;j<Len&&j<i+8;j++) { dist = min(dist,dis(st[i],st[j])); } } return dist; } double solve(Point sx[],Point sy[],Point st[],int l,int r) { if(r==l)return inf; else if(r-l==1)return dis(sx[l],sx[r]); else if(r-l==2) { double x1 = dis(sx[l],sx[l+1]); double x2 = dis(sx[l],sx[r]); double x3 = dis(sx[l+1],sx[r]); x1 = min(x1,x2); x1 = min(x1,x3); return x1; } else { int m = (l+r)>>1,i,j,k; double L = sx[m].x, dist; for(i=l,j=l,k=m+1;i<=r;) { if(sy[i].index<=sx[m].index)st[j++]=sy[i++]; else st[k++]=sy[i++]; } //printf("%d %d ~~~\n",j,k); for(i=l;i<=r;i++)sy[i]=st[i];//,printf("%.2lf %.2lf\n",st[i].x,st[i].y);system("pause"); double p = solve(sx,sy,st,l,m),q = solve(sx,sy,st,m+1,r); //printf("%.2lf %.2lf\n",p,q);system("pause"); dist = min(p,q); for(i=l,j=l,k=m+1;j<=m&&k<=r;) { if(sy[j].y<sy[k].y)st[i++]=sy[j++]; else if(sy[j].y>sy[k].y)st[i++]=sy[k++]; else { if(sy[j].index<sy[k].index)st[i++]=sy[j++]; else st[i++] = sy[k++]; } } while(j<=m)st[i++]=sy[j++]; while(k<=r)st[i++]=sy[k++]; for(i=l;i<=r;i++)sy[i]=st[i]; //system("pause"); dist = merge(sy,st,l,r,dist,L); //printf("%.2lf\n",dist); return dist; } } int main() { int n; //freopen("in.txt","r",stdin); //freopen("out2.txt","w",stdout); while(scanf("%d",&n),n) { for(int i=0;i<n;i++) scanf("%lf%lf",&sx[i].x,&sx[i].y); stable_sort(sx,sx+n,x_cmp); for(int i=0;i<n;i++) sx[i].index = i,sy[i]=sx[i]; stable_sort(sy,sy+n,y_cmp); //for(int i=0;i<n;i++)printf("%.2lf %.2lf %d\n",sx[i].x,sx[i].y,sx[i].index); //printf("\n"); //for(int i=0;i<n;i++)printf("%.2lf %.2lf %d\n",sy[i].x,sy[i].y,sy[i].index); printf("%.2lf\n",solve(sx,sy,st,0,n-1)/2); } }