第一次写点分治啊,果然还是黄学长的代码框架好。。。
这绝对是点分治的最经典的题目了。大概写一下自己的理解吧。
首先,对于一棵树,求出其重心并作为根节点。然后链就可以分为两类:经过根节点的和不经过根节点的。对于不经过根节点的,在子树中递归调用即可。由于每次根节点都取树的重心,所以递归一次点的个数至少除以2,递归层数不超过logN层。另一方面,每一层都可以大致看成有O(N)级别个点,对于这些点的操作时间是O(NlogN)级别的。因此总的时间复杂度O(Nlog^2N)。
对于经过根节点的链,我们进行如下操作。首先将根节点深度定为0,然后求出子树中每个节点的深度dep[i]。然后将dep[]排序,如果dep[x]+dep[y]<=k,显然就是满足条件的链。排序后O(N)扫一遍即可。
但是上述操作仍有许多值得考究的地方。首先,仅仅只有dep[x]+dep[y]<=k是不能充分说明经过根节点的,还有可能x和y是在用一个子树中的。但如果直接强行求出经过根节点即x和y在不同子树同时满足dep[x]+dep[y]<=k的话是很满发的。所以可以先求出满足dep[x]+dep[y]<=k的链的个数sum,然后对于每一刻子树,再求一遍满足dep[x]+dep[y]<=k的个数tmp,在sum中减去所有的tmp即可。这样算法就被正确地实现了。
可是O(Nlog^2N)的时间复杂度好像不是很好看啊。实际上,多出来的一个logN主要是耗时在sort上。在求平面最近点对时我们就可以用在子节点(差不多这个意思)中排序然后再归并的方法。这里对于点分治的dep数组也是一样的,可以在子树中排完序然后再归并。这样连求dep的数组都不需要了。不过这个归并实现比较麻烦,因为有多个子树,可能要用堆维护,再考虑到常数的差距可能反而更慢,以及实现的复杂度。我这么弱暂时不考虑。
下面给出AC代码:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define inf 1000000000 #define N 300005 using namespace std; int n,m,cnt,tot,rt,sum,ans,fst[N],pnt[N],len[N],nxt[N],c[N],d[N],sz[N],f[N]; bool vis[N]; int read(){ int x=0; char ch=getchar(); while (ch<'0' || ch>'9') ch=getchar(); while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); } return x; } void add(int aa,int bb,int cc){ pnt[++tot]=bb; nxt[tot]=fst[aa]; len[tot]=cc; fst[aa]=tot; } void dfs(int x,int last){ sz[x]=f[x]=1; int p; for (p=fst[x]; p; p=nxt[p]){ int q=pnt[p]; if (q==last || vis[q]) continue; dfs(q,x); sz[x]+=sz[q]; f[x]=max(f[x],sz[q]); } f[x]=max(f[x],sum-sz[x]); if (f[x]<f[rt]) rt=x; } void getdep(int x,int last){ c[++cnt]=d[x]; int p; for (p=fst[x]; p; p=nxt[p]){ int q=pnt[p]; if (q==last || vis[q]) continue; d[q]=d[x]+len[p]; getdep(q,x); } } int work(int x,int dep){ d[x]=dep; cnt=0; getdep(x,0); sort(c+1,c+cnt+1); int tmp=0,l,r=cnt; for (l=1; l<r; l++){ while (l<r && c[l]+c[r]>m) r--; tmp+=r-l; } return tmp; } void solve(int x){ ans+=work(x,0);vis[x]=1; int p; for (p=fst[x]; p; p=nxt[p]){ int q=pnt[p]; if (vis[q]) continue; ans-=work(q,len[p]); rt=0; sum=sz[q]; dfs(q,x); solve(rt); } } int main(){ while (n=read()){ m=read(); int i; tot=ans=0; memset(fst,0,sizeof(fst)); memset(vis,0,sizeof(vis)); for (i=1; i<n; i++){ int x=read(),y=read(),z=read(); add(x,y,z); add(y,x,z); } sum=n; f[rt=0]=inf; dfs(1,0); solve(rt); printf("%d\n",ans); } return 0; }
2015.11.15
by lych