Time Limit: 30 Sec Memory Limit: 512 MB
Submit: 292 Solved: 150
[Submit][Status][Discuss]
Description
已知平面内 N 个点的坐标,求欧氏距离下的第 K 远点对。
Input
输入文件第一行为用空格隔开的两个整数 N, K。接下来 N 行,每行两个整数 X,Y,表示一个点
的坐标。1 < = N < = 100000, 1 < = K < = 100, K < = N*(N−1)/2 , 0 < = X, Y < 2^31。
Output
输出文件第一行为一个整数,表示第 K 远点对的距离的平方(一定是个整数)。
Sample Input
10 5
0 0
0 1
1 0
1 1
2 0
2 1
1 2
0 2
3 0
3 1
Sample Output
9
先把所有点都放到 kd 树里去,维护一个小根堆,枚举每个点。每次在 kd 树种查询的时候,就相当于当前的最优解是堆顶,像查最远点对一样在 kd 树里查就行了。
有个问题就是一对点会被算 2 遍,所以堆中的元素个数是 2k 个。
#include
#include
#include
#include
#include
using namespace std;
#define LL long long
const int N=100010;
LL now;
int n,root,D,K,siz;
priority_queue q;
struct S{
int l,r;
LL d[2],mi[2],ma[2];
LL &operator [] (int x){
return d[x];
}
bool operator < (const S &x)const{
return d[D]<x.d[D];
}
}tr[N],p[N];
inline int in(){
int x=0;char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return x;
}
inline void update(int k){
int l=tr[k].l,r=tr[k].r,i;
for(i=0;i<=1;++i){
tr[k].mi[i]=tr[k].ma[i]=tr[k][i];
if(l){
tr[k].mi[i]=min(tr[k].mi[i],tr[l].mi[i]);
tr[k].ma[i]=max(tr[k].ma[i],tr[l].ma[i]);
}
if(r){
tr[k].mi[i]=min(tr[k].mi[i],tr[r].mi[i]);
tr[k].ma[i]=max(tr[k].ma[i],tr[r].ma[i]);
}
}
}
#define mid (l+r)/2
inline int build(int l,int r,int flag){
if(l>r) return 0;
D=flag;nth_element(p+l,p+mid,p+r+1);
tr[mid]=p[mid];
tr[mid].l=build(l,mid-1,flag^1);
tr[mid].r=build(mid+1,r,flag^1);
update(mid);
return mid;
}
inline LL my_abs(LL x){return x<0?-x:x;}
inline LL calc(LL x1,LL y1,LL x2,LL y2){
return (x1-x2)*(x1-x2)+(y1-y2)*(y1-y2);
}
inline void work(int k,LL x,LL y){
if(!k) return ;
int l=tr[k].l,r=tr[k].r;
LL value=calc(tr[k][0],tr[k][1],x,y);
if(value>now){
if(siz==K<<1) q.pop();
else ++siz;
q.push(-value);
now=(siz==K<<1)?-q.top():0LL;
}
LL o0=max(my_abs(x-tr[l].mi[0]),my_abs(x-tr[l].ma[0]));
LL o1=max(my_abs(y-tr[l].mi[1]),my_abs(y-tr[l].ma[1]));
LL ans1=o0*o0+o1*o1;
o0=max(my_abs(x-tr[r].mi[0]),my_abs(x-tr[r].ma[0]));
o1=max(my_abs(y-tr[r].mi[1]),my_abs(y-tr[r].ma[1]));
LL ans2=o0*o0+o1*o1;
if(ans1>=ans2){
if(ans1>now) work(tr[k].l,x,y);
if(ans2>now) work(tr[k].r,x,y);
}
else{
if(ans2>now) work(tr[k].r,x,y);
if(ans1>now) work(tr[k].l,x,y);
}
}
int main(){
int i,x,y,j,o=0;
n=in();K=in();
for(i=1;i<=n;++i){
x=in();y=in();
p[i][0]=(LL)x;p[i][1]=(LL)y;
for(j=0;j<=1;++j)
p[i].mi[j]=p[i].ma[j]=p[i][j];
}
root=build(1,n,0);
for(i=1;i<=n;++i)
work(root,p[i][0],p[i][1]);
printf("%I64d\n",now);
}