给定相同点集上的两棵生成树 T1 和 T2,节点编号为 1 ∼ N。对于 T1 中的每条边 e1,你需要
求在 T2 中有多少条边 e2 满足:
T1 − e1 + e2(从 T1 中删去 e1 再加上 e2 构成的图)是一棵生成树;
T2 − e2 + e1 也是一棵生成树。
2 ≤ N ≤ 2*10^5
不难发现如果两条边可以互相替代,则必须要满足e2在T2中e1两端点的路径上,反过来也要满足。
那么我们先把T2的每条边在T1上差分一下,就是在两端点处插入,然后在lca处删除,那么限制条件就被去掉了一个。
把边下放到点,现在我们要实现的单点修改,和支持链上权值查询。
可以用树链剖分做到俩log,也可以直接维护入栈出栈序做到一个log。然后把线段树自底向上合并就好了。
复杂度 O(nlogn) O ( n l o g n ) 。
#include
#include
#include
#include
#include
#include
const int N=200005;
int n,tot,sz,in[N],out[N],rt[N],ans[N];
struct tree{int l,r,s;}t[N*60];
std::vector<int> vec[N];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
struct Tree
{
int cnt,last[N],fa[N],dep[N],size[N],top[N];
struct edge{int to,next;}e[N*2];
void addedge(int u,int v)
{
e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}
void dfs1(int x)
{
dep[x]=dep[fa[x]]+1;size[x]=1;
for (int i=last[x];i;i=e[i].next)
{
if (e[i].to==fa[x]) continue;
fa[e[i].to]=x;
dfs1(e[i].to);
size[x]+=size[e[i].to];
}
}
void dfs2(int x,int chain)
{
top[x]=chain;int k=0;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa[x]&&size[e[i].to]>size[k]) k=e[i].to;
if (k) dfs2(k,chain);
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa[x]&&e[i].to!=k) dfs2(e[i].to,e[i].to);
}
void dfs3(int x)
{
in[x]=++tot;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa[x]) dfs3(e[i].to);
out[x]=++tot;
}
int get_lca(int x,int y)
{
while (top[x]!=top[y])
{
if (dep[top[x]]y]]) std::swap(x,y);
x=fa[top[x]];
}
return dep[x]y]?x:y;
}
}t1,t2;
int newnode()
{
sz++;t[sz].l=t[sz].r=t[sz].s=0;return sz;
}
int merge(int x,int y)
{
if (!x||!y) return x+y;
t[x].s+=t[y].s;
t[x].l=merge(t[x].l,t[y].l);
t[x].r=merge(t[x].r,t[y].r);
return x;
}
void modify(int &d,int l,int r,int x,int y)
{
if (!d) d=newnode();
t[d].s+=y;
if (l==r) return;
int mid=(l+r)/2;
if (x<=mid) modify(t[d].l,l,mid,x,y);
else modify(t[d].r,mid+1,r,x,y);
}
int query(int d,int l,int r,int x,int y)
{
if (!d||x<=l&&r<=y) return t[d].s;
int mid=(l+r)/2,ans=0;
if (x<=mid) ans+=query(t[d].l,l,mid,x,y);
if (y>mid) ans+=query(t[d].r,mid+1,r,x,y);
return ans;
}
void solve(int x)
{
for (int i=t2.last[x];i;i=t2.e[i].next)
{
if (t2.e[i].to==t2.fa[x]) continue;
solve(t2.e[i].to);
rt[x]=merge(rt[x],rt[t2.e[i].to]);
}
for (int i=0;i<vec[x].size();i++)
{
int y=vec[x][i];
if (y>0) modify(rt[x],1,tot,in[y],1),modify(rt[x],1,tot,out[y],-1);
else modify(rt[x],1,tot,in[-y],-2),modify(rt[x],1,tot,out[-y],2);
}
if (x==1) return;
int lca=t1.get_lca(x,t2.fa[x]);
ans[x]=query(rt[x],1,tot,1,in[x])+query(rt[x],1,tot,1,in[t2.fa[x]])-2*query(rt[x],1,tot,1,in[lca]);
}
void clear()
{
sz=tot=t1.cnt=t2.cnt=0;
for (int i=1;i<=n;i++) t1.last[i]=t2.last[i]=rt[i]=0,vec[i].clear();
}
int main()
{
int T=read();
while (T--)
{
n=read();
clear();
for (int i=1;iint x=read(),y=read();
t2.addedge(x,y);
}
for (int i=1;iint x=read(),y=read();
t1.addedge(x,y);
}
t1.dfs1(1);t1.dfs2(1,1);t1.dfs3(1);
t2.dfs1(1);t2.dfs2(1,1);
for (int i=1;i*2-1;i+=2)
{
int x=t1.e[i].to,y=t1.e[i+1].to,lca=t2.get_lca(x,y);
if (t1.dep[x]y]) std::swap(x,y);
vec[x].push_back(x);vec[y].push_back(x);vec[lca].push_back(-x);
}
solve(1);
for (int i=1;i*2-1;i+=2)
{
int x=t2.e[i].to,y=t2.e[i+1].to;
if (t2.dep[x]y]) std::swap(x,y);
printf("%d ",ans[x]);
}
puts("");
}
return 0;
}