经典的点分做法就不说了(然而我写点分t了。。)
线段树/平衡树启发式合并的话,就是维护子树每个节点到子树跟的距离,打一个整棵子树的标记,然后按dfs/bfs序启发式合并,合并之前先查询一下答案即可。
如果用线段树启发式合并的话,时间复杂度是 O(TNlog2N) 的,与点分一样;但是如果用splay启发式合并的话,时间复杂度存是 O(TNlogN) ,实际跑起来也确实比点分和线段树启发式合并快很多。
有一个定理(这里有论文。。论文(Part I)论文(Part II)但是好像要花钱什么的。。总之我并不知道该怎么看):假如splay中是一个1~n的排列,有m次操作,第i次操作是splay xi ,则总时间复杂度是 O(n+m+∑mi=2log(|xi−xi−1|+1)) 。所以假如我们把一个有a棵节点的树,按顺序插入b中,那么代价是 ∑ai=2log(|xi−xi−1|)≤alogba=a(logb−loga) (均值不等式)。所以总时间复杂度就是 O(nlogn) 的。
如果用其他的平衡树的话维护一些其他的信息也可以做到这个时间复杂度。。不过我不知道具体该怎么做。比如treap可以看这篇论文。
代码(splay):
#include<cstdio>
#include<iostream>
using namespace std;
#include<algorithm>
#include<cstring>
#include<cmath>
const int N=1e4+5;
int k;
int next[N<<1],succ[N<<1],w[N<<1],ptr[N],etot;
void addedge(int from,int to,int wt){
next[etot]=ptr[from],ptr[from]=etot,succ[etot]=to,w[etot++]=wt;
}
char * cp=(char *)malloc(20000000);
void in(int &x){
while(*cp<'0'||*cp>'9')++cp;
for(x=0;*cp>='0'&&*cp<='9';)x=x*10+(*cp++^'0');
}
struct SS{
int ch[2],fa;
int dis;
int size;
}bst[N];
int lazy[N],root[N];
void out(int node){
printf("bst[%d]={ch[0]=%d,ch[1]=%d,fa=%d,dis=%d,size=%d}\n",node,bst[node].ch[0],bst[node].ch[1],bst[node].fa,bst[node].dis,bst[node].size);
}
void pushup(int node){
bst[node].size=bst[bst[node].ch[0]].size+bst[bst[node].ch[1]].size+1;
}
void rot(int node){
int fa=bst[node].fa;
bool dir=bst[bst[node].fa].ch[1]==node;
bst[node].fa=bst[fa].fa;
bst[fa].fa=node;
bst[bst[node].ch[!dir]].fa=fa;
bst[fa].ch[dir]=bst[node].ch[!dir];
bst[node].ch[!dir]=fa;
bst[bst[node].fa].ch[bst[bst[node].fa].ch[1]==fa]=node;
pushup(fa);
}
void splay(int node){
for(int fa;bst[node].fa;rot(node))
if(bst[fa=bst[node].fa].fa)
if((bst[bst[fa].fa].ch[1]==fa)==(bst[fa].ch[1]==node))rot(fa);
else rot(node);
pushup(node);
//printf("splay(%d)\n",node);
//out(node);
}
int ans;
void query(int u,int uplimit){
//printf("query(%d,%d)\n",u,uplimit);
//printf("preans=%d\n",ans);
int node=root[u];
for(;;)
if(bst[node].dis<=uplimit){
ans+=bst[bst[node].ch[0]].size+1;
if(bst[node].ch[1])node=bst[node].ch[1];
else break;
}
else
if(bst[node].ch[0])node=bst[node].ch[0];
else break;
splay(node);
root[u]=node;
//printf("postans=%d\n",ans);
}
void add(int u,int x){
//printf("add(%d,%d)\n",u,x);
//out(x);
int node=root[u];
for(;;)
if(bst[node].dis<bst[x].dis)
if(bst[node].ch[1])node=bst[node].ch[1];
else{
bst[node].ch[1]=x;
break;
}
else
if(bst[node].ch[0])node=bst[node].ch[0];
else{
bst[node].ch[0]=x;
break;
}
bst[x].fa=node;
splay(x);
root[u]=x;
}
int d[N],dtot;
void dfs_bst(int node){
if(bst[node].ch[0])dfs_bst(bst[node].ch[0]);
d[dtot++]=node;
if(bst[node].ch[1])dfs_bst(bst[node].ch[1]);
}
void merge(int fa,int node){
if(bst[root[node]].size>bst[root[fa]].size)swap(node,fa);
//printf("merge %d->%d\n",node,fa);
dtot=0;
dfs_bst(root[node]);
//printf("dfs:");
//for(int i=0;i<dtot;++i)printf("%d ",d[i]);
//puts("");
for(int i=0;i<dtot;++i)query(fa,k-(bst[d[i]].dis+lazy[node])-lazy[fa]);
for(int i=0;i<dtot;++i){
bst[d[i]]=(SS){0,0,0,bst[d[i]].dis+lazy[node]-lazy[fa],1};
add(fa,d[i]);
}
root[node]=root[fa],lazy[node]=lazy[fa];
}
void dfs(int node,int ftr){
bst[node]=(SS){0,0,0,0,1};
root[node]=node;
for(int i=ptr[node];i;i=next[i])
if(succ[i]!=ftr){
dfs(succ[i],node);
lazy[succ[i]]+=w[i];
merge(node,succ[i]);
}
}
int main(){
fread(cp,1,20000000,stdin);
int n;
for(in(n),in(k);n||k;in(n),in(k)){
if(n==0){
puts("0");
continue;
}
memset(ptr,0,sizeof(ptr));
memset(lazy,0,sizeof(lazy));
ans=0;
etot=1;
int u,v,wt;
for(int i=n;--i;){
in(u),in(v),in(wt);
addedge(u,v,wt),addedge(v,u,wt);
}
dfs(1,0);
printf("%d\n",ans);
}
}
代码(线段树):
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<iostream>
#include<cstdlib>
using namespace std;
int next[20000],succ[20000],w[20000],ptr[10005],etot;
inline void addedge(int u,int v,int l){
next[etot]=ptr[u],ptr[u]=etot,w[etot]=l,succ[etot++]=v;
}
char * cp=(char *)malloc(10000000);
inline void in(int &x){
while(*cp<'0'||*cp>'9')++cp;
x=0;
while(*cp>='0'&&*cp<='9')x=x*10+(*cp++^'0');
}
int size[2500000],ls[2500000],rs[2500000],root[10005],stot,delta[260000];
#define maxn 10000000
#define lson ls[node],l,l+r>>1
#define rson rs[node],(l+r>>1)+1,r
int ans,K;
void add(int &node,int l,int r,int x,int A){
if(!node)node=stot++;
size[node]+=A;
if(l!=r)
if(x>l+r>>1)add(rson,x,A);
else add(lson,x,A);
}
int query(int node,int l,int r,int x){
int ans=0;
do{
if(x>l+r>>1){
ans+=size[ls[node]];
l=(l+r>>1)+1,node=rs[node];
}
else r=l+r>>1,node=ls[node];
}while(node&&l!=r);
return ans+size[node];
}
void find1(int node,int l,int r,int root,int to,int len){
if(l==r){
ans+=size[node]*query(to,-maxn,maxn,K-(l+delta[root]+len)-delta[to]);
/*cout<<"Get:"<<l+delta[root]<<endl; cout<<"Query:"<<K-(l+delta[root]+len)-delta[to]<<"->"<<query(to,-maxn,maxn,K-(l+delta[root]+len)-delta[to])<<"\n"; cout<<"Add:"<<l+delta[root]+len-delta[to]<<"\n\n";*/
return;
}
if(ls[node])find1(lson,root,to,len);
if(rs[node])find1(rson,root,to,len);
}
void find2(int node,int l,int r,int root,int to,int len){
if(l==r){
add(to,-maxn,maxn,l+delta[root]+len-delta[to],size[node]);
return;
}
if(ls[node])find2(lson,root,to,len);
if(rs[node])find2(rson,root,to,len);
}
inline void merge(int u,int v,int l){
if(size[root[u]]<size[root[v]]){
//cout<<"Paint("<<v<<") by "<<l<<endl;
delta[root[v]]+=l;
find1(root[u],-maxn,maxn,root[u],root[v],0);
find2(root[u],-maxn,maxn,root[u],root[v],0);
root[u]=root[v];
}
else{
find1(root[v],-maxn,maxn,root[v],root[u],l);
find2(root[v],-maxn,maxn,root[v],root[u],l);
}
}
inline void dfs(int node,int ftr){
for(int i=ptr[node];i;i=next[i])
if(succ[i]!=ftr){
dfs(succ[i],node);
//cout<<"-----Merge:"<<node<<","<<succ[i]<<"------\n";
merge(node,succ[i],w[i]);
}
}
int main(){
fread(cp,1,10000000,stdin);
int N,u,v,i,j,l;
for(in(N),in(K);N||K;in(N),in(K)){
if(!N){
puts("0");
continue;
}
memset(ptr+1,0,sizeof(int)*N);
etot=1;
memset(size,0,sizeof(int)*stot);
memset(ls,0,sizeof(int)*stot);
memset(rs,0,sizeof(int)*stot);
memset(root+1,0,sizeof(int)*N);
stot=1;
ans=0;
for(i=N;--i;){
in(u),in(v),in(l);
addedge(u,v,l),addedge(v,u,l);
}
for(i=N;i;--i){
add(root[i],-maxn,maxn,0,1);
delta[root[i]]=0;
}
dfs(rand()%N+1,0);
printf("%d\n",ans);
}
}
代码(点分 TLE):
#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
using namespace std;
int ans;
int K;
void in(int &x){
char c=getchar();
x=0;
while((c<'0'||c>'9')&&c!=-1)c=getchar();
if(c==-1)exit(0);
for(;c>='0'&&c<='9';c=getchar())x=x*10+c-'0';
}
int next[20005],succ[20005],ptr[10005],w[20005],etot;
void addedge(int from,int to,int l){
next[etot]=ptr[from],ptr[from]=etot,succ[etot]=to,w[etot++]=l;
}
int q[10005],qx[10005],fa[10005];
bool p[10005];
int size[10005],dis[10005];
int h,t;
#define inf 0x7fffffff
void bfs(int node){
for(h=0;h!=t;++h)fa[q[h]]=0;
h=0,t=1;
int i;
qx[node]=0;
q[h]=node;
for(;h!=t;++h){
qx[q[h]]=h;
//printf("%d-x->%d ",q[h],h);
for(i=ptr[q[h]];i;i=next[i])
if(succ[i]!=fa[q[h]]&&!p[succ[i]]){
fa[succ[i]]=q[h];
q[t++]=succ[i];
}
}
//puts("");
}
void vdfs(int node,int fdis){
//printf("----%d,%d-----\n",node,fdis);
//计算-距离
bfs(node);
int i,j=t-1;
//printf("t:%d\n",t);
if(K>=(long long)fdis<<1){
for(h=0;h!=t;++h)
for(i=ptr[q[h]];i;i=next[i])
if(succ[i]!=fa[q[h]]&&!p[succ[i]])
dis[qx[succ[i]]]=dis[h]+w[i];
sort(dis,dis+t);
K-=fdis<<1;
for(i=0;i<t&&dis[i]<=K;++i){
while(dis[i]+dis[j]>K)--j;
ans-=j+1;
}
K+=fdis<<1;
//printf("----------->%d\n",ans);
}
//寻找重心
int Maxson;
while(h--){
size[q[h]]=1;
Maxson=0;
for(i=ptr[q[h]];i;i=next[i])
if(succ[i]!=fa[q[h]]&&!p[succ[i]]){
size[q[h]]+=size[succ[i]];
Maxson=max(Maxson,size[succ[i]]);
}
if(max(Maxson,t-Maxson-1)<=t>>1){
node=q[h];
//printf("CenterV:%d(%d)\n",node,max(Maxson,t-Maxson-1));
break;
}
}
//计算+距离
bfs(node);
for(h=0;h!=t;++h)
for(i=ptr[q[h]];i;i=next[i])
if(succ[i]!=fa[q[h]]&&!p[succ[i]])
dis[qx[succ[i]]]=dis[h]+w[i];
sort(dis,dis+t);
/*for(i=0;i<t;++i)printf("%d ",dis[i]); puts("");*/
j=t-1;
for(i=0;i<t&&dis[i]<=K;++i){
while(dis[j]+dis[i]>K)--j;
ans+=j+1;
}
//printf("+++++++++>%d\n",ans);
//向儿子们进发!
p[node]=1;
for(i=ptr[node];i;i=next[i])
if(!p[succ[i]])
vdfs(succ[i],w[i]);
}
int main(){
int N,a,b,l;
for(in(N),in(K);N||K;in(N),in(K)){
if(N==0){
puts("0");
continue;
}
etot=1;
memset(ptr,0,sizeof(int)*(N+1));
memset(p,0,sizeof(p));
ans=-N;
while(--N){
in(a),in(b),in(l);
addedge(a,b,l),addedge(b,a,l);
}
vdfs(1,2000000000);
printf("%d\n",ans>>1);
}
}
总结:
①splay的时间复杂度是 O(nlogn+∑log(|xi−xi−1|+1)) 。
②splay rot以后要pushup(fa)!这个地方写错了无数次了。