树链剖分 讲解+模板+习题

今天我们来讲一下树链剖分

树链剖分是什么?

树链剖分是一种用来维护树上路径信息的在线方法,可以处理在线。

通常通过一种方法,将一棵树剖分成若干条链,然后通过数据结构(线段树,BIT等)去维护。

我们通常所说的树链剖分,基本都是轻重链剖分。

下面我们介绍一下这一种剖分。

学习树链剖分的基础知识有lca,dfs序,线段树等。

首先,我们来明确即可变量。

s i z e i size_i sizei表示以 i i i为子树的大小,包括 i i i

h e a v y i heavy_i heavyi表示 i i i的所有儿子 j j j s i z e j size_j sizej最大的一个

连接 i i i h e a v y i heavy_i heavyi的边称为重边( h e a v y heavy heavy e d g e s edges edges),其余为轻边( l i g h t light light e d g e s edges edges)

当好多条重边首尾相连,形成了一条更大的重边时,我们称这个重边的集合叫重链( h e a v y heavy heavy p a t h path path

因为不可能有两条重链相交(根据重链定义可知),所以整棵树被划分成了若干条互不相交的重链。

举个例子

树链剖分 讲解+模板+习题_第1张图片

红色的边是重边,黑色是轻边。1-2-4-8-16构成了一条重链。

其中所有的重链都不相交

再给一些剖分的性质。

  1. 每个点只属于一个重链。
  2. 不可能有两条重链相邻。
  3. 从一个点开始,重复"重链顶端–>跳一条轻边"这个过程,必定能到根
  4. 按上述方法跳,必定能在 O ( l o g n ) O(logn) O(logn)步到达顶点。

给一下第4条的证明。

如果一条边是重链,那他一次肯定能跳到重链顶端,这样很快,

而没经过一条轻边,根据重链的定义,轻边所在子树 s i z e size size< 1 2 \frac{1}{2} 21 所有子树 s i z e size size ,则没经过一条轻边,该点得子树大小必定 ∗ 2 *2 2甚至更多,则只会有 O ( l o g n ) O(logn) O(logn)次。

那我们怎么通过代码处理轻重边剖分?

一种经典的处理方法如下。


f a t h e r x father_x fatherx表示 x x x的父亲节点
s i z e x size_x sizex表示 x x x的子树大小
d e e p x deep_x deepx表示 x x x的深度
h e a v y x heavy_x heavyx表示 x x x的重儿子
t o p x top_x topx 表示 x x x的重链顶点

处理以上几个数组通常使用两遍dfs处理。

第一次构建前4个数组

第二次本质是对重链进行标号与整理

每次遍历到某点时,先递归它的重儿子,

在递归他的轻儿子,它的重链顶端就是自己

顺便处理 t o p x top_x topx

两遍dfs代码:

void dfs1(int rt){
	size[rt]=1;
	for (int i=0;i<e[rt].size();i++){
		int to=e[rt][i] ;
		if (to==fa[rt]) continue ; 
		fa[to]=rt;dep[to]=dep[rt]+1 ;//信息 
		dfs1(to) ;
		size[rt]+=size[to] ;
		if (!hson[rt] || size[to]>size[hson[rt]]) hson[rt]=to ;//处理重儿子 
	}
} 
void dfs2(int rt,int t){
	top[rt]=t ;
	if (!hson[rt]) return ;
	dfs2(hson[rt],t) ;//先递归重儿子 
	for (int i=0;i<e[rt].size();i++){
		int to=e[rt][i] ;
		if (to==fa[rt] || to==hson[rt]) continue ;
		dfs2(to,to) ;//轻边顶端为自己 
	}
}

之后,我们用线段树维护每条重链的信息(在实际操作中,轻边也算作重链)

在此同时,我们需要记录每个节点的时间戳 d f n x dfn_x dfnx,同样最好记录每个 d f n x dfn_x dfnx对应的节点,我们用 s e q x seq_x seqx表示。

这两个操作也非常简单,只需在dfs2初始时加这两句话:

	dfn[rt]=++tot ;
	seq[tot]=rt ;

因为我们是先处理重儿子的,所以重链的节点肯定会被放在线段树的同一个区间中,这样方便线段树操作。

之后上面的剖分过程大体就结束了。

树链剖分能够暴力求出两两点的 l c a lca lca

为何说暴力,因为它不像倍增一样要枚举 2 j 2^j 2j步,他就是每次跳,而且时间复杂度有保障!

我们来讲一讲这个过程。

先上一张图

树链剖分 讲解+模板+习题_第2张图片

如果要查询9和11的lca,我们手动模拟一下:

1.首先11的dep较大,11跳至1( f a [ t o p [ 11 ] ] fa[top[11]] fa[top[11]]
因为 t o p [ 9 ] = t o p [ 11 ] , top[9]=top[11], top[9]=top[11]所以结束循环,判定深度小的位 l c a lca lca即可

void lca(int x,int y){
	int fx=top[x],fy=top[y] ;
	while(fx!=fy){
		if (dep[x]<dep[y]){ //跳x 
			swap(x,y) ;
			swap(fx,fy) ;
		}
		x=fa[fx];
		fx=top[x] ;
	}
	if (dep[x]>dep[y]) swap(x,y) ;
	return y ;
} 

了解了树剖解救lca的过程,基本你已经掌握了树剖了,只差来几道例题,我们不妨讲解几道例题,更深入了解树剖一下。

我们拿BZOJ 1036 树的统计 举个例子。

它让你动态干三件事:

  1. 把某个点的权值改成t
  2. 询问x到y的路径上的节点权值的最大值
  3. 询问x到y的路径上的所以节点权值的和

这是一个树链剖分的裸题。

他让我们动态维护两两点的路径信息。

假设我们剖分写好了,线段树也建好了,我们该怎么求出x到y的路径上的节点权值的最大值和总和呢?

举个例子。

树链剖分 讲解+模板+习题_第3张图片

同样是9和11的例子。

11跳到1,我们已经维护2~11的最大值和和,为3和6

之后维护1到9的最大值和和,为5和10

MAX=5 SUM=16

int linkmax(int x,int y){ //链上最大值
	int fx=top[x],fy=top[y],ans=-inf ;
	while(fx!=fy){
		if (dep[fx]<dep[fy]){
			swap(x,y) ;
			swap(fx,fy) ;
		} 
		ans=max(ans,qmax(1,dfn[fx],dfn[x])) ;
		x=fa[fx] ;
		fx=top[x] ;
	}
	if (dep[x]>dep[y]) swap(x,y) ;//在同一条重链上 
	ans=max(ans,qmax(1,dfn[x],dfn[y])) ;
	return ans ;
}
int linksum(int x,int y){ //链上和
	int fx=top[x],fy=top[y],ans=0 ;
	while(fx!=fy){
		if (dep[fx]<dep[fy]){
			swap(x,y) ;
			swap(fx,fy) ;
		}
		ans+=qsum(1,dfn[fx],dfn[x]) ;
		x=fa[fx] ;
		fx=top[x] ;
	}
	if (dep[x]>dep[y]) swap(x,y) ;
	ans+=qsum(1,dfn[x],dfn[y]) ;
	return ans ;
}

建议自己手动模拟一下,对于理解算法有很大作用

代码:

#include 
using namespace std ;
const int N = 30010 ;
const int inf = (1<<30) ;
#define rep(i,a,b) for (int (i)=(a);(i)<=(b);(i)++)
#define REP(i,a,b) for (int (i)=(a);(i)>=(b);(i)--)
#define ls ((rt)<<1)
#define rs ((rt)<<1|1)
typedef long long ll ;

vector <int> e[N] ;
int a[N],size[N],fa[N],dep[N],hson[N],top[N],dfn[N],seq[N] ;
int n,Q,tot,x,y,u,t ;
char op[20] ;

struct node{int l,r,Max,sum;}tr[N<<2];

void dfs1(int rt){
	size[rt]=1;
	for (int i=0;i<e[rt].size();i++){
		int to=e[rt][i] ;
		if (to==fa[rt]) continue ; 
		fa[to]=rt;dep[to]=dep[rt]+1 ;
		dfs1(to) ;
		size[rt]+=size[to] ;
		if (!hson[rt] || size[to]>size[hson[rt]]) hson[rt]=to ;
	}
} 
void dfs2(int rt,int t){
	top[rt]=t ;
	dfn[rt]=++tot ;
	seq[tot]=rt ;
	if (!hson[rt]) return ;
	dfs2(hson[rt],t) ;
	for (int i=0;i<e[rt].size();i++){
		int to=e[rt][i] ;
		if (to==fa[rt] || to==hson[rt]) continue ;
		dfs2(to,to) ;
	}
}
inline void pushup(int rt){
	tr[rt].Max=max(tr[ls].Max,tr[rs].Max) ;
	tr[rt].sum=tr[ls].sum+tr[rs].sum ;
} 
void build(int rt,int l,int r){
	tr[rt].l=l;tr[rt].r=r ;
	if (l==r){
		tr[rt].Max=tr[rt].sum=a[seq[l]] ;
		return ; 
	}
	int mid=(l+r)>>1;
	build(ls,l,mid) ;
	build(rs,mid+1,r) ;
	pushup(rt) ;
}
void modify(int rt,int pos){
	if (tr[rt].l==tr[rt].r){
		tr[rt].Max=tr[rt].sum=a[seq[tr[rt].l]] ;
		return ;
	}
	int mid=(tr[rt].l+tr[rt].r)>>1 ;
	if (pos<=mid) modify(ls,pos) ;
	else modify(rs,pos) ;
	pushup(rt) ;
}
int qmax(int rt,int l,int r){
	if (l<=tr[rt].l && tr[rt].r<=r) return tr[rt].Max ;
	int mid=(tr[rt].l+tr[rt].r)>>1 ;
	if (r<=mid) return qmax(ls,l,r) ;
	if (l>mid) return qmax(rs,l,r) ;
	return max(qmax(ls,l,r),qmax(rs,l,r)) ;
} 
int qsum(int rt,int l,int r){
	if (l<=tr[rt].l && tr[rt].r<=r) return tr[rt].sum ;
	int mid=(tr[rt].l+tr[rt].r)>>1 ;
	if (r<=mid) return qsum(ls,l,r) ;
	if (l>mid) return qsum(rs,l,r) ;
	return qsum(ls,l,r)+qsum(rs,l,r) ;
}
int linkmax(int x,int y){
	int fx=top[x],fy=top[y],ans=-inf ;
	while(fx!=fy){
		if (dep[fx]<dep[fy]){
			swap(x,y) ;
			swap(fx,fy) ;
		} 
		ans=max(ans,qmax(1,dfn[fx],dfn[x])) ;
		x=fa[fx] ;
		fx=top[x] ;
	}
	if (dep[x]>dep[y]) swap(x,y) ;//在同一条重链上 
	ans=max(ans,qmax(1,dfn[x],dfn[y])) ;
	return ans ;
}
int linksum(int x,int y){
	int fx=top[x],fy=top[y],ans=0 ;
	while(fx!=fy){
		if (dep[fx]<dep[fy]){
			swap(x,y) ;
			swap(fx,fy) ;
		}
		ans+=qsum(1,dfn[fx],dfn[x]) ;
		x=fa[fx] ;
		fx=top[x] ;
	}
	if (dep[x]>dep[y]) swap(x,y) ;
	ans+=qsum(1,dfn[x],dfn[y]) ;
	return ans ;
}
int main(){
	scanf("%d",&n) ;
	for (int i=1;i<n;i++){
		scanf("%d%d",&x,&y) ;
		e[x].push_back(y) ; 
		e[y].push_back(x) ;
	}
	for (int i=1;i<=n;i++) scanf("%d",&a[i]) ;
	fa[1]=0;dep[1]=1; 
	dfs1(1) ;
	dfs2(1,1) ;
	build(1,1,n) ;
	scanf("%d",&Q) ;
	while(Q--){
		scanf("%s%d%d",op,&u,&t) ;
		if (op[0]=='C') a[u]=t,modify(1,dfn[u]) ;
		else if (op[1]=='M') printf("%d\n",linkmax(u,t)) ;
		else if (op[1]=='S') printf("%d\n",linksum(u,t)) ;
	}
}

一道双倍经验题:【模板】树链剖分

// luogu-judger-enable-o2
#include 
using namespace std;
const int N = 100010 ;
#define int long long
struct edge{
    int to,next ;
}e[N<<1];
int head[N],f[N],dep[N],size[N],son[N],rk[N],top[N],dfn[N];
int a[N];
//f[i]:i的父亲,dep[i]:i的深度,size[i]:i的子树大小,son[i]:重儿子 ,rk[i]:i的dfs值,与dfn相反
//top[i]:i所在链的顶端,dfn[i]:dfs序,时间戳 
int n,m,rt,tot,cnt;
int p,r ;
inline void add(int x,int y){
    e[++cnt].to=y;
    e[cnt].next=head[x] ;
    head[x]=cnt ;
}
void dfs1(int rt,int fa,int depth){ //主要处理深度,父亲和儿子 
    f[rt]=fa;dep[rt]=depth;size[rt]=1;//一些初始化 
    for (int i=head[rt];i;i=e[i].next){
        int to=e[i].to ;
        if (to==fa) continue ;//保证不是父亲 
        dfs1(to,rt,depth+1) ;
        size[rt]+=size[to] ;//rt的大小+子树的大小 
        if (size[son[rt]]<size[to]) son[rt]=to ;//改变重儿子 
    }
    return ;
}
void dfs2(int rt,int t){ //主要处理链,dfs序 
    top[rt]=t;dfn[rt]=++cnt;rk[cnt]=rt;//初始化
    if (!son[rt]) return ;//该点没有重儿子 
    dfs2(son[rt],t) ;//rt的重儿子也是和rt一样处于以t为顶端的重链 
    for (int i=head[rt];i;i=e[i].next){
        int to=e[i].to ;
        if (to!=f[rt] && to!=son[rt]) dfs2(to,to) ;//一个点位于轻链底端,那么它的top必然是它本身
    }
    return ;
}
struct seg{ //线段树 
    int ls,rs,lazy,l,r;
    int sum ;
}tree[N<<1];
inline void pushup(int rt){
    tree[rt].sum=(tree[tree[rt].ls].sum+tree[tree[rt].rs].sum+
    tree[rt].lazy*(tree[rt].r-tree[rt].l+1))%p;
    return ;
}
void build(int ll,int rr,int rt){ //create
    if (ll==rr){
        tree[rt].l=tree[rt].r=ll ;
        tree[rt].sum=a[rk[ll]] ;
        return ;
    }
    else {
        int mid=(ll+rr)>>1;
        tree[rt].ls=cnt++ ;
        tree[rt].rs=cnt++ ;
        build(ll,mid,tree[rt].ls) ;
        build(mid+1,rr,tree[rt].rs) ;
        tree[rt].l=tree[tree[rt].ls].l ;
        tree[rt].r=tree[tree[rt].rs].r ;
        pushup(rt) ;
    }
    return ;
}
void update(int l,int r,int rt,int c){ //l~r +c 
    if (l<=tree[rt].l && tree[rt].r<=r) {
        tree[rt].sum=(tree[rt].sum+c*(tree[rt].r-tree[rt].l+1))%p ;
        tree[rt].lazy=(tree[rt].lazy+c)%p ;
    }
    else {
        int mid=(tree[rt].l+tree[rt].r)>>1 ;
        if (l<=mid) update(l,r,tree[rt].ls,c) ;
        if (mid<r) update(l,r,tree[rt].rs,c) ;
        pushup(rt) ;
    }
    return ;
}
int query(int l,int r,int rt){
    if (l<=tree[rt].l && tree[rt].r<=r) return tree[rt].sum ;
    int tot=(tree[rt].lazy*(min(r,tree[rt].r)-max(l,tree[rt].l)+1)%p)%p ;//初始值
    int mid=(tree[rt].l+tree[rt].r)>>1 ;
    if (l<=mid) tot=(tot+query(l,r,tree[rt].ls))%p ;
    if (mid<r)  tot=(tot+query(l,r,tree[rt].rs))%p ;
    return tot%p ; 
} 
inline int sum(int x,int y){
    int ans=0;
    int fx=top[x],fy=top[y] ;
    while (fx!=fy){
        if (dep[fx]>=dep[fy]) {
            ans=(ans+query(dfn[fx],dfn[x],rt))%p ;
            x=f[fx],fx=top[x] ;
        }
        else {
            ans=(ans+query(dfn[fy],dfn[y],rt))%p ;
            y=f[fy],fy=top[y] ;
        }
    } 
    if (dfn[x]<=dfn[y]) ans=(ans+query(dfn[x],dfn[y],rt))%p ;
    else ans=(ans+query(dfn[y],dfn[x],rt))%p ;
    return ans%p ;
}
inline void UPDATE(int x,int y,int c){
    int fx=top[x],fy=top[y];
    while(fx!=fy)
    {
        if(dep[fx]>=dep[fy]){
            update(dfn[fx],dfn[x],rt,c) ;
            x=f[fx],fx=top[x];
        }
        else {
        	update(dfn[fy],dfn[y],rt,c) ;
            y=f[fy],fy=top[y];
        }
    }
    if (dfn[x]<=dfn[y]) update(dfn[x],dfn[y],rt,c) ;
    else update(dfn[y],dfn[x],rt,c) ;
    return ;
}
main(){
    scanf("%lld%lld%lld%lld",&n,&m,&r,&p) ;
    for (int i=1;i<=n;i++) scanf("%lld",&a[i]) ;
 	for (int i=1;i<n;i++){
        int x,y ;
        scanf("%lld%lld",&x,&y) ;
        add(x,y);
        add(y,x) ;
    }
    cnt=0 ;
    dfs1(r,0,1) ;
    dfs2(r,r) ;
    cnt=0;
    rt=cnt++ ;
    build(1,n,rt);
//	return 0 ; 
    for (int i=1;i<=m;i++){
     //   cout<
        int x,y,op ;
        int z ;
        scanf("%lld",&op);
        if (op==1){
            scanf("%lld%lld%lld",&x,&y,&z) ;
            UPDATE(x,y,z) ; 
        }
        else if (op==2){
            scanf("%lld%lld",&x,&y) ;
            printf("%lld\n",sum(x,y)) ;
        }
        else if (op==3){
            scanf("%lld%lld",&x,&z) ;
            update(dfn[x],dfn[x]+size[x]-1,rt,z) ; 
        }
        else {
            scanf("%lld",&x) ;
            printf("%lld\n",query(dfn[x],dfn[x]+size[x]-1,rt)) ;
        }
    }
}

再来一道例题。

BZOJ 4196 NOI2015 软件包管理器

这个问题是动态删链,动态加链的过程。

这题其实比上题还简单,直接维护线段树即可。

#include 
using namespace std ;
const int N = 100010 ;
const int inf = (1<<30) ;
#define rep(i,a,b) for (int (i)=(a);(i)<=(b);(i)++)
#define REP(i,a,b) for (int (i)=(a);(i)>=(b);(i)--)
#define ls ((rt)<<1)
#define rs ((rt)<<1|1)

typedef long long ll ;

vector <int> e[N] ;
int size[N],fa[N],dep[N],top[N],hson[N],dfn[N] ;
int n,m,tot,x ;
char op[20] ;
struct node{
	int l,r,sum,lazy ;
}tr[N<<2];

void dfs1(int rt){
	size[rt]=1 ;
	for (int i=0;i<e[rt].size();i++){
		int to=e[rt][i] ;
		if (to==fa[rt]) continue ;
		fa[to]=rt;dep[to]=dep[rt]+1; 
		dfs1(to) ;
		size[rt]+=size[to] ;
		if (!hson[rt] || size[to]>size[hson[rt]]) hson[rt]=to ;
	}
}

void dfs2(int rt,int t){
	top[rt]=t;
	dfn[rt]=++tot;
	if (!hson[rt]) return ;
	dfs2(hson[rt],t) ;
	for (int i=0;i<e[rt].size();i++){
		int to=e[rt][i] ;
		if (to==fa[rt] || to==hson[rt]) continue ;
		dfs2(to,to) ;
	}
}
inline void pushup(int rt){
	tr[rt].sum=tr[ls].sum+tr[rs].sum ;
}
void build(int rt,int l,int r){
	tr[rt].l=l;tr[rt].r=r,tr[rt].lazy=-1;
	if (l==r) return ;
	int mid=(l+r)>>1 ;
	build(ls,l,mid) ;
	build(rs,mid+1,r) ;
	pushup(rt) ;
}
inline void pushdown(int rt){ 
	if (tr[rt].lazy==-1) return ;
	tr[ls].sum=tr[rt].lazy*(tr[ls].r-tr[ls].l+1) ;
	tr[rs].sum=tr[rt].lazy*(tr[rs].r-tr[rs].l+1) ;
	tr[ls].lazy=tr[rt].lazy ;
	tr[rs].lazy=tr[rt].lazy ;
	tr[rt].lazy=-1 ;
}
void modify(int rt,int l,int r,int c){  
	if (l<=tr[rt].l && tr[rt].r<=r) {
		tr[rt].sum=c*(tr[rt].r-tr[rt].l+1) ;
		tr[rt].lazy=c ;
		return ;
	}
	pushdown(rt) ; 
	int mid=(tr[rt].l+tr[rt].r)>>1 ;
	if (l<=mid) modify(ls,l,r,c) ;
	if (r>mid) modify(rs,l,r,c) ;
	pushup(rt) ;
}
int query(int rt,int l,int r){
	if (l<=tr[rt].l && tr[rt].r<=r) return tr[rt].sum ;
	pushdown(rt) ;
	int res=0,mid=(tr[rt].l+tr[rt].r)>>1 ;
	if (l<=mid) res+=query(ls,l,r) ;
	if (r>mid) res+=query(rs,l,r) ;
	return res ;
}
int sum(int x){
	int ans=0,fx=top[x] ;
	while (fx!=1){
		ans+=dfn[x]-dfn[fx]-query(1,dfn[fx],dfn[x])+1 ;
		modify(1,dfn[fx],dfn[x],1) ;
		x=fa[fx] ;
		fx=top[x] ;
	}
	ans+=dfn[x]-dfn[1]-query(1,dfn[1],dfn[x])+1 ;
	modify(1,dfn[1],dfn[x],1) ;
	return ans ;
}
int main(){
	scanf("%d",&n) ;
	for (int i=2;i<=n;i++) {
		scanf("%d",&x) ;
		x++ ; 
		e[x].push_back(i) ;
		e[i].push_back(x) ;
	}
	fa[1]=0;dep[1]=1 ;
	dfs1(1) ;
	dfs2(1,1) ;
	build(1,1,n) ;
	scanf("%d",&m) ;
	for (int i=1;i<=m;i++){
		scanf("%s%d",&op,&x) ;
		x++ ;
		if (op[0]=='i') printf("%d\n",sum(x)) ;
		else {
			printf("%d\n",query(1,dfn[x],dfn[x]+size[x]-1)) ;
			modify(1,dfn[x],dfn[x]+size[x]-1,0) ;
		}
	}
}

都理解了, 来几道习题练练

Aragorn’s Story

[HAOI2015]树上操作

月下“毛景树”

蒟蒻第一次写关于算法的博客,有问题或建议请及时提出,在评论区中发表,博主将及时更改,谢谢阅读!

你可能感兴趣的:(树链剖分,最近公共祖先LCA,算法总结)