维护树的直径·并查集
花花有一棵带 n 个顶点的树 T,每个节点有一个点权 ai 。
有一天,他认为拥有两棵树更好一些。所以,他从 T 中删去了一条边。
第二天,他认为三棵树或许又更好一些。因此,他又从他拥有的某一棵树中去除了一条边。
如此往复。每一天,花花都会删去一条尚未被删去的边,直到他得到了一个包含了 n 棵只有一个点的树的森林。
定义一条简单路径1的权值为路径上点权之和,一棵树的直径为树上权值最大的简单路径。
花花认为树最重要的特征就是它的直径。所以他想请你算出任一时刻他拥有的所有树的直径
的乘积。因为这个数可能很大,他要求你输出乘积对 109+7 取模之后的结果。
本来是很简单的一道题,然而蒟蒻并没有想到这个:
设两棵树 T1、T2 ,它们的直径的两个端点分别是 a1,b1、a2,b2 ,当他们合并时,新的树的直径的两个端点也必定在这4个点中,只要枚举比较一下6种情况即可。
这个性质很重要!
http://blog.csdn.net/rzo_kqp_orz/article/details/52280811
简单证明一下。
我们看作是 x 所在的连通块通过边(x, y)连向 y 所在的连通块。
若新直径不经过(x, y),则就是原来的两条直径取 max。
若新直径经过(x, y),就要考虑 x 延伸到哪儿、y 延伸到哪儿了。由直径的定义可知,x 能走到的最远点之一是 x 所在连通块直径的端点,y 同理。因此这时新直径的两个端点都是旧直径的端点。
(这个证明也适用于增量法求树的直径,即我给一棵树加一个新点,那么新直径必有一端点是旧直径的端点)
(注意只能是树,普通图没有这些性质)
然后并查集维护,合并的时候检查一下6种情况即可。
Code:
#include
#include
#include
#define D(x) cout<<#x<<" = "<
#define E cout<
using namespace std;
typedef long long ll;
const ll mod = 1e9+7;
const int N = 100005;
const int LOG = 21;
int n,w[N],edge[N][2],q[N],dis[N],fa[N][LOG],dep[N];
ll ans=1; int s[N],top;
ll pow(ll x,int k){
ll res=1;
while(k){
if(k&1) res=res*x%mod;
x=x*x%mod; k>>=1;
}
return res;
}
ll ni(ll x){ return pow(x,mod-2); }
void mul(ll &x,ll y){ x=x*y%mod; }
void div(ll &x,ll y){ x=x*ni(y)%mod; }
struct Edge{ int to,nxt; }e[N<<2]; int head[N],ec;
void add(int a,int b){ e[++ec].to=b; e[ec].nxt=head[a]; head[a]=ec; }
void dfs(int u,int f){
dep[u]=dep[f]+1; fa[u][0]=f; dis[u]=dis[f]+w[u];
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to; if(v==f)continue;
dfs(v,u);
}
}
void initlca(){
for(int j=1;jfor (int i=1;i<=n;i++)
fa[i][j]=fa[fa[i][j-1]][j-1];
}
int lca(int a,int b){
if(dep[a]int cha=dep[a]-dep[b];
for(int j=LOG-1;j>=0;j--)
if((cha>>j)&1) a=fa[a][j];
if(a!=b){
for(int j=LOG-1;j>=0;j--)
if(fa[a][j]!=fa[b][j])
a=fa[a][j],b=fa[b][j];
a=fa[a][0];
}
return a;
}
int dist(int u,int v){ int f=lca(u,v); return dis[u]+dis[v]-dis[f]-dis[fa[f][0]]; }
void upd(int u,int v,int &res,int &x,int &y){
int d=dist(u,v); if(d>res){ res=d; x=u,y=v; }
}
struct MergeSet{
int pa[N],g[N][2],length[N];
void init(){ for(int i=1;i<=n;i++) pa[i]=i, g[i][0]=g[i][1]=i, length[i]=w[i]; }
int find(int x){ if(x!=pa[x])pa[x]=find(pa[x]); return pa[x]; }
int merge(int a,int b){
if(dep[a]>dep[b])swap(a,b);
a=find(a); b=find(b);
int x,y,res=0;
upd(g[a][0],g[a][1],res,x,y);
upd(g[a][0],g[b][0],res,x,y);
upd(g[a][0],g[b][1],res,x,y);
upd(g[a][1],g[b][0],res,x,y);
upd(g[a][1],g[b][1],res,x,y);
upd(g[b][0],g[b][1],res,x,y);
g[a][0]=x; g[a][1]=y; length[a]=res; pa[b]=a;
return res;
}
int len(int a){ a=find(a); return length[a]; }
} ms;
int main(){
freopen("2.in","r",stdin);
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",w+i);
for(int i=1;i"%d%d",edge[i],edge[i]+1);
add(edge[i][0],edge[i][1]),add(edge[i][1],edge[i][0]);
}
for(int i=1;i"%d",q+i);
dfs(1,0); initlca(); ms.init();
for(int i=1;i<=n;i++)mul(ans,w[i]); s[++top]=ans;
for(int i=n-1;i>=1;i--){
int u=edge[q[i]][0], v=edge[q[i]][1];
if(dep[u]>dep[v])swap(u,v);
div(ans,ms.len(u)); div(ans,ms.len(v));
int res=ms.merge(u,v); mul(ans,res);
s[++top]=ans;
}
while(top){ printf("%d\n",s[top--]); }
}