树的操作(换根树剖)

描述
XXX和 YYY在愉快地刷题。有一道题是这样的:给你一棵 n 个节点的有根树,每个节点有 一个权植。你要支持两种操作:查询以某棵树为根的子树的权值和,给以某个节点为根的整 棵子树的所有点的权值都加上一个值。机智的 XXX 很开心地用LLL教授 讲过的某些东西水水水水过了这道题。

但是可怕的出题人又增加了一种操作:将根节点改为第 u 号节点。于是XXX和YYY 就不会 做了。按照一惯的逻辑,这个问题被交给了你。

注意:初始时,根节点为 1 号节点,所有节点的权值都为 0。

输入
第一行有三个正整数 n, m,表示节点数和操作数。

接下来 n – 1 行每行两个正整数 p, q,表示 p 和 q 节点间有一条边。

接下来 m 行,每行首先有一个正整数 type,描述操作类型。

如果 type=1,接下来有一个正整数 u,表示询问以 u 为根的子树的权值和。

如果 type=2,接下来有两个正整数 u 和 v,表示把以 u 为根的子树的所有节点的权值都加上 v。

如果 type=3,接下来有一个正整数 u,表示把根节点改为第 u 号节点。

输出
对于每一个 1 号操作,输出一行一个非负整数表示答案。

**样例输入 **
3 4
1 2
2 3
2 2 2
1 3
3 3
1 3
样例输出
2
4
提示
【Hint】

对于 30%的数据,1 ≤ n, m ≤ 1,000。
另有 50%的数据不含有 3 号操作。
对于 100%的数据,1 ≤ n, m ≤ 200,000。
对于 100%的数据,1 ≤ u ≤ n, 1 ≤ v ≤ 500,000, 1 ≤ k ≤ n。

如果没有操作 3 3 3就是一道普通的树链剖分,但加入了换根的操作,怎么办?

我们需要考虑换根对维护子树和信息的影响,那么果断分类讨论一波:

我们设当前询问的子树的根为 u u u,当前的根节点为 r o o t root root

第一种情况: l c a lca lca( u u u, r o o t root root)!= u u u,那么我们发现在以 r o o t root root为根时 u u u原来的子树就是现在的子树,直接线段树区间查询。

第二种情况: u u u== r o o t root root,那么此时 u u u就是根节点,直接返回线段树根节点的 m a x n maxn maxn

第三种情况: r o o t root root u u u原来的子树内,即 l c a lca lca( r o o t root root, u u u)== u u u&& u u u!= r o o t root root,怎么做呢?
我们找到 u u u r o o t root root-> u u u这条链上 u u u的儿子 s s s,那么子树 s s s的补集就是我们查询的区间。

没有其他情况了。

然而这道题OJ上交会爆栈,因此写了个手工栈。

代码如下:

#include
#define N 200005
#define lc (p<<1)
#define rc (p<<1|1)
#define mid (T[p].l+T[p].r>>1)
using namespace std;
inline long long read(){
	long long ans=0,w=1;
	char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
	while(isdigit(ch))ans=(ans<<3)+(ans<<1)+ch-'0',ch=getchar();
	return ans*w;
}
struct Node{int l,r;long long sum,lz;}T[N<<2];
struct node{int v,next;}e[N<<1];
int dep[N],top[N],first[N],fa[N],siz[N],hson[N],num[N],n,m,root,cnt=0,tot=0;
inline void add(int u,int v){e[++cnt].v=v,e[cnt].next=first[u],first[u]=cnt;}
inline void dfs1(int p){
	siz[p]=1,hson[p]=0;
	for(int i=first[p];i;i=e[i].next){
		int v=e[i].v;
		if(v==fa[p])continue;
		dep[v]=dep[p]+1,fa[v]=p,dfs1(v),siz[p]+=siz[v];
		if(siz[v]>siz[hson[p]])hson[p]=v;
	}
}
inline void dfs2(int p,int tp){
	top[p]=tp,num[p]=++tot;
	if(hson[p])dfs2(hson[p],tp);
	for(int i=first[p];i;i=e[i].next){
		int v=e[i].v;
		if(v!=hson[p]&&v!=fa[p])dfs2(v,v);
	}
}
inline void pushup(int p){T[p].sum=T[lc].sum+T[rc].sum;}
inline void pushnow(int p,long long v){T[p].sum+=v*(T[p].r-T[p].l+1),T[p].lz+=v;}
inline void pushdown(int p){
	if(T[p].lz==0)return;
	pushnow(lc,T[p].lz),pushnow(rc,T[p].lz),T[p].lz=0;
}
inline void build(int p,int l,int r){
	T[p].l=l,T[p].r=r,T[p].sum=T[p].lz=0;
	if(l==r)return;
	build(lc,l,mid);
	build(rc,mid+1,r);
}
inline void update(int p,int ql,int qr,long long v){
	if(ql>T[p].r||T[p].l>qr)return;
	if(ql<=T[p].l&&T[p].r<=qr){pushnow(p,v);return;}
	pushdown(p);
	if(qr<=mid)update(lc,ql,qr,v);
	else if(ql>mid)update(rc,ql,qr,v);
	else update(lc,ql,mid,v),update(rc,mid+1,qr,v);
	pushup(p);
}
inline long long query(int p,int ql,int qr){
	if(ql>T[p].r||T[p].l>qr)return 0;
	if(ql<=T[p].l&&T[p].r<=qr)return T[p].sum;
	pushdown(p);
	if(qr<=mid)return query(lc,ql,qr);
	if(ql>mid)return query(rc,ql,qr);
	return query(lc,ql,mid)+query(rc,mid+1,qr);
}
inline int lca(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		x=fa[top[x]];
	}
	return dep[x]<dep[y]?x:y;
}
inline int fid(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		if(fa[top[x]]==y)return top[x];
		x=fa[top[x]];
	}
	if(dep[x]<dep[y])swap(x,y);
	return hson[y];
}
int main(){
	int size=40<<20;
	__asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));
	root=1;
	n=read(),m=read();
	for(int i=1;i<n;++i){
		int u=read(),v=read();
		add(u,v),add(v,u);
	}
	dfs1(1);
	dfs2(1,1);
	build(1,1,n);
	while(m--){
		int op=read(),u=read();
		if(op==3){root=u;continue;}
		if(op==2){
			long long v=read();
			if(u==root){update(1,1,n,v);continue;}
			int t=lca(u,root);
			if(u!=t){update(1,num[u],num[u]+siz[u]-1,v);continue;}
			int s=fid(u,root);update(1,1,n,v),update(1,num[s],num[s]+siz[s]-1,-v);
		}
		if(op==1){
			if(u==root){printf("%lld\n",T[1].sum);continue;}
			int t=lca(u,root);
			if(u!=t){printf("%lld\n",query(1,num[u],num[u]+siz[u]-1));continue;}
			int s=fid(u,root);printf("%lld\n",T[1].sum-query(1,num[s],num[s]+siz[s]-1));
		}
	}
	exit(0);
	return 0;
}

你可能感兴趣的:(#,线段树,#,小技巧,#,树链剖分)