树链剖分学习笔记

树链剖分主要解决在树上某条路径上或某棵子树的sum与最值

入门树链剖分,最重要的概念是重儿子,用son[]记录。son[i]代表的是以i为根,节点最多子树根的编号。通过son,我们将树中的边分为两种,轻边和重边。重边是每一个点与它重儿子的连边。将链续的重边串起来,构成了一条条重链。对于每个点,它一定在某条重链上,特殊的,一个单独的点也可以是一条重链。对于一条重链上的点,他们的dfs序是连续的,对与一颗子树上的点,他们的dfs序也是连续的,于是我们将树上的点转化成一个区间,在区间用线段树上求解或修改。

树链剖分的核心便是如何将树剖分成若干链。

在此之前,先了解以下七个参数的意义。

fa[x] x的父亲节点编号

dep[x] x的深度

size[x] 以x为根子树的节点,用来求son

seg[x] 以son为基础的dfs2序,即其在线段树上的编号

rev[x] 用来将线段树上的编号转化为原编号。

son[x] 记录重儿子  

top[x] 记录x所在重链dep最小的点

树链剖分需要的前置知识有DFS+LCA+线段树

我们在两次dfs中求出上述七个参数

dfs1:求出dep,f,size,son。
inline void dfs1(int u,int f){
    int i,v;
    size[u]=1;
    fa[u]=f;
    dep[u]=dep[f]+1;
    for(i=fir[u];v=to[i],i;i=nex[i]){
        if(v!=f){
            dfs1(v,u);
            size[u]+=size[v];
            if(size[v]>size[son[u]])//更新重儿子
                son[u]=v;
        }
    }
}

dfs2:求出rev,seg,top。

inline void dfs2(int u,int f){
    int i,v;
    if(son[u]){//优先遍历重儿子。
        seg[son[u]]=++tim;
        rev[tim]=son[u];
        top[son[u]]=top[u];//重儿子的top,就是u的top。
        dfs2(son[u],u);
    }
    for(i=fir[u];v=to[i],i;i=nex[i])
        if(!top[v]){//访问轻边
            seg[v]=++tim;
            rev[tim]=v;
            top[v]=v;//轻边单独开了一条链,top是本身
            dfs2(v,u);
        }
}

两遍dfs就把整棵树划分为若干条链,剩下的就交给线段树解决了。

首先是建树

inline void build(int k,int l,int r){
    if(l==r){
        sum[k]=ma[k]=w[l];//w是每个点的全值
        return;
    }
    int mid=l+r>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
    ma[k]=max(ma[k<<1],ma[k<<1|1]);
    sum[k]=sum[k<<1]+sum[k<<1|1];
}

线段树的查询和修改类似。

inline void change(int k,int l,int r,int val,int pos){
//pos是当前要修改的的位置,val是改变后的值
if(l>pos||r<pos) return; if(l==r&&l==pos){ ma[k]=sum[k]=val; return; } int mid=l+r>>1; change(k<<1,l,mid,val,pos); change(k<<1|1,mid+1,r,val,pos); sum[k]=sum[k<<1]+sum[k<<1|1]; ma[k]=max(ma[k<<1],ma[k<<1|1]); } inline void query(int k,int x,int y,int l,int r){
//l~r为需要修改的区间
if(x>r||y<l) return; if(x>=l&&y<=r){ SUM+=sum[k]; MAX=max(ma[k],MAX); return; } int mid=x+y>>1; query(k<<1,x,mid,l,r); query(k<<1|1,mid+1,y,l,r); }

最后,我们只需要知道只需要知道哪些点对这条路径有贡献,统计他们的贡献即可。

inline void ask(int x,int y){

inline void ask(int x,int y){  
int fx=top[x],fy=top[y]; while(fx!=fy){//如果他们不在同一重链上 if(dep[fx]<dep[fy]) swap(x,y),swap(fx,fy);//选取深度大的那一条, query(1,1,tim,seg[fx],seg[x]);//注意要将原编号转化为dfs序编号 x=fa[x],fx=top[x]; }
  //如果他们在一条链上了,再统计x~y路径的贡献
if(dep[x]>dep[y]) swap(x,y);//保证x的编号小等于y query(1,1,tim,seg[x],seg[y]); }

下面附上一道模板题

https://www.lydsy.com/JudgeOnline/problem.php?id=1036

树链剖分学习笔记_第1张图片

#include
#include
#include
#define max(x,y) (x>y?x:y)
#define N 100000
using namespace std;
int n,m,tot,tim,SUM,MAX;
int fir[N],to[N],nex[N];
int seg[N],rev[N],size[N],son[N],dep[N],top[N],fa[N];
int sum[N],ma[N],w[N];
inline void r(int &x){
    bool sign=1;
    x=0;
    char ch=getchar();
    while(ch<'0'||ch>'9') ch=getchar();
    if(ch=='-') sign=0,ch=getchar();
    while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    x=sign?x:-x;
}
inline void add(int x,int y){
    to[++tot]=y,nex[tot]=fir[x],fir[x]=tot;
    to[++tot]=x,nex[tot]=fir[y],fir[y]=tot;
}
inline void dfs1(int u,int f){
    int i,v;
    size[u]=1;
    fa[u]=f;
    dep[u]=dep[f]+1;
    for(i=fir[u];v=to[i],i;i=nex[i]){
        if(v!=f){
            dfs1(v,u);
            size[u]+=size[v];
            if(size[v]>size[son[u]])
                son[u]=v;
        }
    }
}
inline void dfs2(int u,int f){
    int i,v;
    if(son[u]){
        seg[son[u]]=++tim;
        rev[tim]=son[u];
        top[son[u]]=top[u];
        dfs2(son[u],u);
    }
    for(i=fir[u];v=to[i],i;i=nex[i])
        if(!top[v]){
            seg[v]=++tim;
            rev[tim]=v;
            top[v]=v;
            dfs2(v,u);
        }
}
inline void build(int k,int l,int r){
    if(l==r){
        sum[k]=ma[k]=w[l];
        return;
    }
    int mid=l+r>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
    ma[k]=max(ma[k<<1],ma[k<<1|1]);
    sum[k]=sum[k<<1]+sum[k<<1|1];
}
inline void change(int k,int l,int r,int val,int pos){
    if(l>pos||r<pos)
        return;
    if(l==r&&l==pos){
        ma[k]=sum[k]=val;
        return;
    }
    int mid=l+r>>1;
    change(k<<1,l,mid,val,pos);
    change(k<<1|1,mid+1,r,val,pos);
    sum[k]=sum[k<<1]+sum[k<<1|1];
    ma[k]=max(ma[k<<1],ma[k<<1|1]);
}
inline void query(int k,int x,int y,int l,int r){
    if(x>r||y<l)
        return;
    if(x>=l&&y<=r){
        SUM+=sum[k];
        MAX=max(ma[k],MAX);
        return;
    }
    int mid=x+y>>1;
    query(k<<1,x,mid,l,r);
    query(k<<1|1,mid+1,y,l,r);
}
inline void ask(int x,int y){
    int fx=top[x],fy=top[y];
    while(fx!=fy){
        if(dep[fx]<dep[fy]) swap(x,y),swap(fx,fy);
        query(1,1,tim,seg[fx],seg[x]);
        x=fa[x],fx=top[x];
    }
    if(dep[x]>dep[y]) swap(x,y);
    query(1,1,tim,seg[x],seg[y]);
}
int main()
{
    int i,j,x,y;
    char op[10];
    r(n);
    for(i=1;i){
        r(x),r(y);
        add(x,y);
    }
    for(i=1;i<=n;i++)
        r(w[i]);
    tim=seg[1]=top[1]=rev[1]=1;
    dfs1(1,0);
    dfs2(1,0);
    build(1,1,tim);
    r(m);
    for(i=1;i<=m;i++){
        scanf("%s",op);
        r(x),r(y);
        SUM=0;
        MAX=-N;
        switch(op[1]){
            case 'M':{
                ask(x,y);
                printf("%d\n",MAX);
                break;
            }
            case 'S':{
                ask(x,y);
                printf("%d\n",SUM);
                break;
            }
            case 'H':{
                change(1,1,tim,y,seg[x]);
                break;
            }
        }
    }
}
View Code

2019-09-04

你可能感兴趣的:(树链剖分学习笔记)