- 树链剖分
- 简介
- 实现
- 模板
- Luogu3384 模板树链剖分
- 应用
- BZOJ1036 ZJOI2008 树的统计
- BZOJ4034 HAOI2015 树上操作
- BZOJ2243 SDOI2011 染色
- BZOJ3531 SDOI2014 旅行
- BZOJ3626 LNOI2014 LCA
树链剖分适用于一些复杂的题目,可以较为充分获取树上的信息,将其转换为线性结构后可以很方便的使用线性数据结构进行处理。
那么,树链剖分试讲一棵树如何转化为一条链的呢?之前有前序、中序、后序遍历,而树链剖分与前三种方式相似但不相同。
首先,有两个重要的概念:
对于一个节点,他的子节点的子树中节点数最多的是他的重儿子,相对地,其他节点成为此节点的轻儿子。
对每个节点,我们遍历的时候首先到它的重儿子,在重儿子递归回溯之后再遍历轻儿子。通过这种方法,我们可以得到一条链,我们就可以使用各种线性数据结构来维护树上的的信息。而这些信息的更改与查询方式无非两种:
对于第一种方式处理办法十分简单,因为我们根据上述方法得到的链可以保证一个子树中的节点在链上是连续的,那么,其实我们的目标区间其实就是以子树的根为起点,以子树大小为长度的区间。
而对于后一种方式,我们发现其在链中并非连续的,那么我们考虑将其拆成数个连续的区间进行处理,具体方法如下:
具体实现方面,我们需要两次 DFS:
已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z
操作2: 格式: 2 x y 表示求树从 x 到 y 结点最短路径上所有节>点的值之和
操作3: 格式: 3 x z 表示将以 x 为根节点的子树内所有节点值都加上 z
操作4: 格式: 4 x 表示求以 x 为根节点的子树内所有节点值之和
#include
#include
typedef long long ll;
const int MAXN=1e5+5;
int N,M,R,P;
int na[MAXN];
struct E{int next,to;} e[MAXN<<1];int ecnt,G[MAXN];
void addEdge(int u,int v){e[++ecnt]=(E){G[u],v};G[u]=ecnt;}
void addEdge2(int u,int v){addEdge(u,v),addEdge(v,u);}
int son[MAXN],fa[MAXN],sz[MAXN],dpt[MAXN];
void dfs1(int u)
{
sz[u]=1;
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]) continue;
fa[v]=u,dpt[v]=dpt[u]+1;
dfs1(v);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
int w[MAXN],wcnt,tp[MAXN];
void dfs2(int u)
{
int v=son[u];if(!v) return;
tp[v]=tp[u],w[v]=++wcnt;
dfs2(v);
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
tp[v]=v,w[v]=++wcnt;
dfs2(v);
}
}
const int rt=1;
struct SEGTN{int l,r,lc,rc;int sum,tag;} t[MAXN<<1];int tcnt=rt;
void upd(int o)
{t[o].sum=((ll)t[t[o].lc].sum+(ll)t[t[o].rc].sum)%P;}
void pushdown(int o)
{
SEGTN &lc=t[t[o].lc],&rc=t[t[o].rc];int &tag=t[o].tag;
if(tag)
{
lc.sum=((ll)lc.sum+((ll)(lc.r-lc.l+1)*(ll)tag)%P)%P,lc.tag=((ll)lc.tag+(ll)tag)%P;
rc.sum=((ll)rc.sum+((ll)(rc.r-rc.l+1)*(ll)tag)%P)%P,rc.tag=((ll)rc.tag+(ll)tag)%P;
tag=0;
}
}
void buildTree(int o,int l,int r)
{
t[o].l=l,t[o].r=r;
if(l==r) {t[o].sum=na[l];return;}
int mid=(l+r)>>1;
t[o].lc=++tcnt;buildTree(t[o].lc,l,mid);
t[o].rc=++tcnt;buildTree(t[o].rc,mid+1,r);
upd(o);
}
void chDet(int o,int l,int r,int v)
{
if(l<=t[o].l&&t[o].r<=r)
{
t[o].sum=((ll)t[o].sum+((((ll)t[o].r-(ll)t[o].l+1LL)%P)*(ll)v)%P)%P;
t[o].tag=((ll)t[o].tag+(ll)v)%P;
return;
}
int mid=(t[o].l+t[o].r)>>1;
pushdown(o);
if(l<=mid) chDet(t[o].lc,l,r,v);
if(r>mid) chDet(t[o].rc,l,r,v);
upd(o);
}
int calSum(int o,int l,int r)
{
if(l<=t[o].l&&t[o].r<=r) return t[o].sum;
int mid=(t[o].l+t[o].r)>>1;
pushdown(o);
int res=0;
if(l<=mid) res=((ll)res+calSum(t[o].lc,l,r))%P;
if(r>mid) res=((ll)res+calSum(t[o].rc,l,r))%P;
return res;
}
void chainDet(int x,int y,int v)
{
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]y]]) std::swap(x,y);
chDet(rt,w[tp[x]],w[x],v);
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
chDet(rt,w[x],w[y],v);
}
int chainSum(int x,int y)
{
int res=0;
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]y]]) std::swap(x,y);
res=((ll)res+calSum(rt,w[tp[x]],w[x]))%P;
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
res=((ll)res+calSum(rt,w[x],w[y]))%P;
return res;
}
void stDet(int x,int v){chDet(rt,w[x],w[x]+sz[x]-1,v);}
int stSum(int x){return calSum(rt,w[x],w[x]+sz[x]-1);}
int tmp[MAXN];
int main()
{
int i;
scanf("%d%d%d%d",&N,&M,&R,&P);
for(i=1;i<=N;i++) scanf("%d",&tmp[i]);
int u,v;
for(i=1;i"%d%d",&u,&v),addEdge2(u,v);
dpt[R]=1;dfs1(R);
tp[R]=R,w[R]=++wcnt;dfs2(R);
for(i=1;i<=N;i++) na[w[i]]=tmp[i];
buildTree(rt,1,N);
while(M--)
{
int opt,x,y,z;
scanf("%d%d",&opt,&x);
if(opt==1) scanf("%d%d",&y,&z),chainDet(x,y,z);
else if(opt==2) scanf("%d",&y),printf("%d\n",chainSum(x,y));
else if(opt==3) scanf("%d",&y),stDet(x,y);
else if(opt==4) printf("%d\n",stSum(x));
}
return 0;
}
一棵树上有 n 个节点,编号分别为 1 到 n ,每个节点都有一个权值 w 。我们将以下面的形式来要求你对这棵树完成一些操作:
- CHANGE u t:把结点 u 的权值改为 t
- QMAX u v:询问从点 u 到点 v 的路径上的节点的最大权值
- QSUM u v:询问从点 u 到点 v 的路径上的节点的权值和
注意:从点 u 到点 v 的路径上的节点包括 u 和 v 本身1≤n≤30000 , 0≤q≤200000 ;每个节点的权值 w 在 -30000 到 30000 之间。
/**************************************************************
Problem: 1036
User: zhangche0526
Language: C++
Result: Accepted
Time:2880 ms
Memory:6844 kb
****************************************************************/
#include
#include
typedef long long ll;
const int MAXN=3e4+5,INF=~0U>>1;
int N,Q;
int na[MAXN];
struct E{int next,to;} e[MAXN<<1];int ecnt,G[MAXN];
void addEdge(int u,int v){e[++ecnt]=(E){G[u],v};G[u]=ecnt;}
void addEdge2(int u,int v){addEdge(u,v),addEdge(v,u);}
int son[MAXN],fa[MAXN],sz[MAXN],dpt[MAXN];
void dfs1(int u)
{
sz[u]=1;
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]) continue;
fa[v]=u,dpt[v]=dpt[u]+1;
dfs1(v);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
int w[MAXN],wcnt,tp[MAXN];
void dfs2(int u)
{
int v=son[u];if(!v) return;
tp[v]=tp[u],w[v]=++wcnt;
dfs2(v);
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
tp[v]=v,w[v]=++wcnt;
dfs2(v);
}
}
const int rt=1;
struct SEGTN{int l,r,lc,rc;int sum,mx;} t[MAXN<<2];int tcnt=rt;
void upd(int o)
{
t[o].sum=t[t[o].lc].sum+t[t[o].rc].sum;
t[o].mx=std::max(t[t[o].lc].mx,t[t[o].rc].mx);
}
void buildTree(int o,int l,int r)
{
t[o].l=l,t[o].r=r;
if(l==r) {t[o].sum=t[o].mx=na[l];return;}
int mid=(l+r)>>1;
t[o].lc=++tcnt;buildTree(t[o].lc,l,mid);
t[o].rc=++tcnt;buildTree(t[o].rc,mid+1,r);
upd(o);
}
void change(int o,int p,int v)
{
if(t[o].l==t[o].r&&t[o].l==p)
{
t[o].sum=t[o].mx=v;
return;
}
int mid=(t[o].l+t[o].r)>>1;
if(p<=mid) change(t[o].lc,p,v);
else change(t[o].rc,p,v);
upd(o);
}
int calSum(int o,int l,int r)
{
if(l<=t[o].l&&t[o].r<=r) return t[o].sum;
int mid=(t[o].l+t[o].r)>>1;
int res=0;
if(l<=mid) res+=calSum(t[o].lc,l,r);
if(r>mid) res+=calSum(t[o].rc,l,r);
return res;
}
int calMx(int o,int l,int r)
{
if(l<=t[o].l&&t[o].r<=r) return t[o].mx;
int mid=(t[o].l+t[o].r)>>1;
int lmx=-INF,rmx=-INF;
if(l<=mid) lmx=calMx(t[o].lc,l,r);
if(r>mid) rmx=calMx(t[o].rc,l,r);
return std::max(lmx,rmx);
}
int chainSum(int x,int y)
{
int res=0;
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]std::swap(x,y);
res+=calSum(rt,w[tp[x]],w[x]);
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
res+=calSum(rt,w[x],w[y]);
return res;
}
int chainMx(int x,int y)
{
int res=-INF;
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]std::swap(x,y);
res=std::max(res,calMx(rt,w[tp[x]],w[x]));
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
res=std::max(res,calMx(rt,w[x],w[y]));
return res;
}
char opt[10];
int main()
{
int i;
scanf("%d",&N);
int u,v;
for(i=1;iscanf("%d%d",&u,&v),addEdge2(u,v);
dpt[1]=1;dfs1(1);
tp[1]=1,w[1]=++wcnt;dfs2(1);
for(i=1;i<=N;i++) scanf("%d",&na[w[i]]);
buildTree(rt,1,N);
scanf("%d",&Q);
while(Q--)
{
scanf("%s%d%d",opt,&u,&v);
if(opt[0]=='C') change(rt,w[u],v);
else
{
if(opt[1]=='S') printf("%d\n",chainSum(u,v));
else printf("%d\n",chainMx(u,v));
}
}
return 0;
}
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个操作,分为三种:
- 把某个节点 x 的点权增加 a ;
- 把某个节点 x 为根的子树中所有点的点权都增加 a ;
- 询问某个节点 x 到根的路径中所有点的点权和。
N,M≤100000 ,且所有输入数据的绝对值都不会超过 106 。
#include
#include
typedef long long ll;
const int MAXN=1e5+5;
int N,M;
int na[MAXN];
struct E{int next,to;} e[MAXN<<1];int ecnt,G[MAXN];
void addEdge(int u,int v){e[++ecnt]=(E){G[u],v};G[u]=ecnt;}
void addEdge2(int u,int v){addEdge(u,v),addEdge(v,u);}
int son[MAXN],fa[MAXN],sz[MAXN],dpt[MAXN];
void dfs1(int u)
{
sz[u]=1;
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]) continue;
fa[v]=u,dpt[v]=dpt[u]+1;
dfs1(v);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
int w[MAXN],wcnt,tp[MAXN];
void dfs2(int u)
{
int v=son[u];if(!v) return;
tp[v]=tp[u],w[v]=++wcnt;
dfs2(v);
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
tp[v]=v,w[v]=++wcnt;
dfs2(v);
}
}
const int rt=1;
struct SEGTN{int l,r,lc,rc;ll sum,tag;} t[MAXN<<1];int tcnt=rt;
void upd(int o){t[o].sum=t[t[o].lc].sum+t[t[o].rc].sum;}
void pushdown(int o)
{
SEGTN &lc=t[t[o].lc],&rc=t[t[o].rc];ll &tag=t[o].tag;
if(tag)
{
lc.sum+=(lc.r-lc.l+1LL)*tag,lc.tag+=tag;
rc.sum+=(rc.r-rc.l+1LL)*tag,rc.tag+=tag;
tag=0;
}
}
void buildTree(int o,int l,int r)
{
t[o].l=l,t[o].r=r;
if(l==r) {t[o].sum=na[l];return;}
int mid=l+r>>1;
t[o].lc=++tcnt;buildTree(t[o].lc,l,mid);
t[o].rc=++tcnt;buildTree(t[o].rc,mid+1,r);
upd(o);
}
void chDet(int o,int l,int r,int v)
{
if(l<=t[o].l&&t[o].r<=r)
{
t[o].sum+=(t[o].r-t[o].l+1LL)*(ll)v;
t[o].tag+=(ll)v;
return;
}
int mid=t[o].l+t[o].r>>1;
pushdown(o);
if(l<=mid) chDet(t[o].lc,l,r,v);
if(r>mid) chDet(t[o].rc,l,r,v);
upd(o);
}
ll calSum(int o,int l,int r)
{
if(l<=t[o].l&&t[o].r<=r) return t[o].sum;
int mid=(t[o].l+t[o].r)>>1;
pushdown(o);
ll res=0;
if(l<=mid) res+=calSum(t[o].lc,l,r);
if(r>mid) res+=calSum(t[o].rc,l,r);
return res;
}
void ptDet(int x,int v){chDet(rt,w[x],w[x],v);}
void stDet(int x,int v){chDet(rt,w[x],w[x]+sz[x]-1,v);}
ll chainSum(int x,int y)
{
ll res=0;
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]std::swap(x,y);
res+=calSum(rt,w[tp[x]],w[x]);
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
res+=calSum(rt,w[x],w[y]);
return res;
}
int tmp[MAXN];
int main()
{
int i;
scanf("%d%d",&N,&M);
for(i=1;i<=N;i++) scanf("%d",&tmp[i]);
int u,v;
for(i=1;iscanf("%d%d",&u,&v),addEdge2(u,v);
dpt[1]=1;dfs1(1);
tp[1]=1,w[1]=++wcnt;dfs2(1);
for(i=1;i<=N;i++) na[w[i]]=tmp[i];
buildTree(rt,1,N);
while(M--)
{
int opt,x,y,z;
scanf("%d%d",&opt,&x);
if(opt==1) scanf("%d",&y),ptDet(x,y);
else if(opt==2) scanf("%d",&y),stDet(x,y);
else if(opt==3) printf("%lld\n",chainSum(x,1));
}
return 0;
}
给定一棵有 n 个节点的无根树和 m 个操作,操作有两类:
- 路径上所有点都染成颜色c;
- 询问路径上的颜色段数量(连续相同颜色被认为是同一段).
N≤105,M≤105,c∈[1,109]
很明显的树剖,我们先考虑链上的情况,可以对每个区间记录一下左右端点的颜色和区间的颜色种类数,然后神奇合并,用树剖搞一搞,注意合并即可。
#include
#include
typedef long long ll;
const int MAXN=1e5+5;
int N,M;
int na[MAXN];
struct E{int next,to;} e[MAXN<<1];int ecnt,G[MAXN];
void addEdge(int u,int v){e[++ecnt]=(E){G[u],v};G[u]=ecnt;}
void addEdge2(int u,int v){addEdge(u,v),addEdge(v,u);}
int son[MAXN],fa[MAXN],sz[MAXN],dpt[MAXN];
void dfs1(int u)
{
sz[u]=1;
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]) continue;
fa[v]=u,dpt[v]=dpt[u]+1;
dfs1(v);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
int w[MAXN],wcnt,tp[MAXN];
void dfs2(int u)
{
int v=son[u];if(!v) return;
tp[v]=tp[u],w[v]=++wcnt;
dfs2(v);
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
tp[v]=v,w[v]=++wcnt;
dfs2(v);
}
}
const int rt=1;
struct SEGTN
{
int l,r,lc,rc;
int sum,tag,clrL,clrR;
SEGTN(){tag=clrL=clrR=-1;}
} t[MAXN<<2];
int tcnt=rt;
int clrL,clrR;
void upd(int o)
{
t[o].sum=t[t[o].lc].sum+t[t[o].rc].sum;
if(t[t[o].lc].clrR==t[t[o].rc].clrL) t[o].sum--;
t[o].clrL=t[t[o].lc].clrL,t[o].clrR=t[t[o].rc].clrR;
}
void pushdown(int o)
{
SEGTN &lc=t[t[o].lc],&rc=t[t[o].rc];int &tag=t[o].tag;
if(~tag)
{
lc.sum=1,lc.clrL=lc.clrR=tag,lc.tag=tag;
rc.sum=1,rc.clrL=rc.clrR=tag,rc.tag=tag;
tag=-1;
}
}
void buildTree(int o,int l,int r)
{
t[o].l=l,t[o].r=r;
if(l==r)
{
t[o].sum=1,t[o].clrL=t[o].clrR=na[l];
return;
}
int mid=l+r>>1;
t[o].lc=++tcnt;buildTree(t[o].lc,l,mid);
t[o].rc=++tcnt;buildTree(t[o].rc,mid+1,r);
upd(o);
}
void chCh(int o,int l,int r,int v)
{
if(l<=t[o].l&&t[o].r<=r)
{
t[o].sum=1,t[o].clrL=t[o].clrR=v,t[o].tag=v;
return;
}
int mid=t[o].l+t[o].r>>1;
pushdown(o);
if(l<=mid) chCh(t[o].lc,l,r,v);
if(r>mid) chCh(t[o].rc,l,r,v);
upd(o);
}
int calSum(int o,int l,int r)
{
if(t[o].l==l) clrL=t[o].clrL;
if(t[o].r==r) clrR=t[o].clrR;
if(l<=t[o].l&&t[o].r<=r) return t[o].sum;
int mid=t[o].l+t[o].r>>1;
pushdown(o);
int res=0,resL=0,resR=0;
if(l<=mid) resL=calSum(t[o].lc,l,r);
if(r>mid) resR=calSum(t[o].rc,l,r);
res=resL+resR;
if(resL&&resR&&t[t[o].lc].clrR==t[t[o].rc].clrL)
res--;
return res;
}
void chainCh(int x,int y,int v)
{
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]std::swap(x,y);
chCh(rt,w[tp[x]],w[x],v);
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
chCh(rt,w[x],w[y],v);
}
int chainSum(int x,int y)
{
int res=0,clrX=-1,clrY=-1;clrL=clrR=-1;
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]std::swap(x,y),std::swap(clrX,clrY);
res+=calSum(rt,w[tp[x]],w[x]);
if(clrR==clrX) res--;
clrX=clrL;
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y),std::swap(clrX,clrY);
res+=calSum(rt,w[x],w[y]);
if(clrX==clrL) res--;
if(clrY==clrR) res--;
return res;
}
int tmp[MAXN];
int main()
{
int i;
scanf("%d%d",&N,&M);
for(i=1;i<=N;i++) scanf("%d",&tmp[i]);
int u,v;
for(i=1;iscanf("%d%d",&u,&v),addEdge2(u,v);
dpt[1]=1;dfs1(1);
tp[1]=1,w[1]=++wcnt;dfs2(1);
for(i=1;i<=N;i++) na[w[i]]=tmp[i];
buildTree(rt,1,N);
while(M--)
{
char opt[10];int x,y,z;
scanf("%s%d%d",opt,&x,&y);
if(opt[0]=='C') scanf("%d",&z),chainCh(x,y,z);
else printf("%d\n",chainSum(x,y));
}
return 0;
}
给出一个 n 个节点的带点权树,每个节点都属于某种类型,操作有更改单点点权或类型,查询两同种类型的点在树上的链中同类型的权值和或最大值。
n,q,c≤105
对每种节点建一棵线段树,注意由于数据范围很大,需要动态开点。
#include
#include
typedef long long ll;
const int MAXN=1e5+5,INF=~0U>>1;
int N,M;
int na[MAXN],typ[MAXN];
struct E{int next,to;} e[MAXN<<1];int ecnt,G[MAXN];
void addEdge(int u,int v){e[++ecnt]=(E){G[u],v};G[u]=ecnt;}
void addEdge2(int u,int v){addEdge(u,v),addEdge(v,u);}
int son[MAXN],fa[MAXN],sz[MAXN],dpt[MAXN];
void dfs1(int u)
{
sz[u]=1;
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]) continue;
fa[v]=u,dpt[v]=dpt[u]+1;
dfs1(v);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
int w[MAXN],wcnt,tp[MAXN];
void dfs2(int u)
{
int v=son[u];if(!v) return;
tp[v]=tp[u],w[v]=++wcnt;
dfs2(v);
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
tp[v]=v,w[v]=++wcnt;
dfs2(v);
}
}
int rt[MAXN];
struct SEGTN{int l,r,lc,rc;int sum,mx;} t[MAXN*40];int tcnt;
void upd(int o)
{
t[o].sum=t[t[o].lc].sum+t[t[o].rc].sum;
t[o].mx=std::max(t[t[o].lc].mx,t[t[o].rc].mx);
}
void chCh(int &o,int l,int r,int p,int v)
{
if(!o) o=++tcnt,t[o].l=l,t[o].r=r;
if(l==r){t[o].sum=t[o].mx=v;return;}
int mid=l+r>>1;
if(p<=mid) chCh(t[o].lc,l,mid,p,v);
else chCh(t[o].rc,mid+1,r,p,v);
upd(o);
}
int calSum(int o,int l,int r)
{
if(!o) return 0;
if(l<=t[o].l&&t[o].r<=r) return t[o].sum;
int mid=t[o].l+t[o].r>>1;
int res=0;
if(l<=mid) res+=calSum(t[o].lc,l,r);
if(r>mid) res+=calSum(t[o].rc,l,r);
return res;
}
int calMx(int o,int l,int r)
{
if(!o) return 0;
if(l<=t[o].l&&t[o].r<=r) return t[o].mx;
int mid=t[o].l+t[o].r>>1;
int lmx=-INF,rmx=-INF;
if(l<=mid) lmx=calMx(t[o].lc,l,r);
if(r>mid) rmx=calMx(t[o].rc,l,r);
return std::max(lmx,rmx);
}
void changeType(int p,int c)
{
chCh(rt[typ[p]],1,N,p,0);
typ[p]=c;
chCh(rt[typ[p]],1,N,p,na[p]);
}
int chainSum(int t,int x,int y)
{
int res=0;
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]std::swap(x,y);
res+=calSum(rt[t],w[tp[x]],w[x]);
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
res+=calSum(rt[t],w[x],w[y]);
return res;
}
int chainMx(int t,int x,int y)
{
int res=-INF;
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]std::swap(x,y);
res=std::max(res,calMx(rt[t],w[tp[x]],w[x]));
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
res=std::max(res,calMx(rt[t],w[x],w[y]));
return res;
}
int tmp[MAXN][2];
int main()
{
int i;
scanf("%d%d",&N,&M);
for(i=1;i<=N;i++) scanf("%d%d",&tmp[i][0],&tmp[i][1]);
int u,v;
for(i=1;iscanf("%d%d",&u,&v),addEdge2(u,v);
dpt[1]=1;dfs1(1);
tp[1]=1,w[1]=++wcnt;dfs2(1);
for(i=1;i<=N;i++) na[w[i]]=tmp[i][0],typ[w[i]]=tmp[i][1];
for(i=1;i<=N;i++)
chCh(rt[typ[w[i]]],1,N,w[i],na[w[i]]);
while(M--)
{
char opt[10];int x,y;
scanf("%s%d%d",opt,&x,&y);
if(opt[0]=='C')
{
if(opt[1]=='C') changeType(w[x],y);
else chCh(rt[typ[w[x]]],1,N,w[x],y),na[w[x]]=y;
}else
{
if(opt[1]=='S') printf("%d\n",chainSum(typ[w[x]],x,y));
else printf("%d\n",chainMx(typ[w[x]],x,y));
}
}
return 0;
}
给定一个 N 个节点的有根树,根的深度为 1 ,有 q 次询问,每次求出一个编号区间内的点与另一给出点的 LCA 的深度之和。
首先考虑一种暴力:每次将给出的点到根的链区间加一,然后其它点的与其 LCA 的深度就是它到根的链上区间和。
可以证明:将区间内的点到根的链区间加一,求给出点到根的链的区间和与上述暴力等价;那么可以依次加入 n 个点,每次将点到根的链区间加一,离线处理询问,对每个端点求它 右端点到根的链的区间和 与 左端点到根的链的区间和 之差即可。
#include
#include
#include
typedef long long ll;
const int MAXN=5e4+5,P=201314;
int N,M;
int na[MAXN];
struct E{int next,to;} e[MAXN<<1];int ecnt,G[MAXN];
void addEdge(int u,int v){e[++ecnt]=(E){G[u],v};G[u]=ecnt;}
void addEdge2(int u,int v){addEdge(u,v),addEdge(v,u);}
int son[MAXN],fa[MAXN],sz[MAXN],dpt[MAXN];
void dfs1(int u)
{
sz[u]=1;
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]) continue;
fa[v]=u,dpt[v]=dpt[u]+1;
dfs1(v);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
int w[MAXN],wcnt,tp[MAXN];
void dfs2(int u)
{
int v=son[u];if(!v) return;
tp[v]=tp[u],w[v]=++wcnt;
dfs2(v);
for(int i=G[u];i;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
tp[v]=v,w[v]=++wcnt;
dfs2(v);
}
}
const int rt=1;
struct SEGTN{int l,r,lc,rc;int sum,tag;} t[MAXN<<2];int tcnt=rt;
void upd(int o){t[o].sum=((ll)t[t[o].lc].sum+(ll)t[t[o].rc].sum)%P;}
void pushdown(int o)
{
SEGTN &lc=t[t[o].lc],&rc=t[t[o].rc];int &tag=t[o].tag;
if(tag)
{
lc.sum=((ll)lc.sum+((ll)(lc.r-lc.l+1)*(ll)tag)%P)%P,lc.tag=((ll)lc.tag+(ll)tag)%P;
rc.sum=((ll)rc.sum+((ll)(rc.r-rc.l+1)*(ll)tag)%P)%P,rc.tag=((ll)rc.tag+(ll)tag)%P;
tag=0;
}
}
void buildTree(int o,int l,int r)
{
t[o].l=l,t[o].r=r;
if(l==r) return;
int mid=(l+r)>>1;
t[o].lc=++tcnt;buildTree(t[o].lc,l,mid);
t[o].rc=++tcnt;buildTree(t[o].rc,mid+1,r);
upd(o);
}
void chDet(int o,int l,int r,int v)
{
if(l<=t[o].l&&t[o].r<=r)
{
t[o].sum=((ll)t[o].sum+((((ll)t[o].r-(ll)t[o].l+1LL)%P)*(ll)v)%P)%P;
t[o].tag=((ll)t[o].tag+(ll)v)%P;
return;
}
int mid=(t[o].l+t[o].r)>>1;
pushdown(o);
if(l<=mid) chDet(t[o].lc,l,r,v);
if(r>mid) chDet(t[o].rc,l,r,v);
upd(o);
}
int calSum(int o,int l,int r)
{
if(l<=t[o].l&&t[o].r<=r) return t[o].sum;
int mid=(t[o].l+t[o].r)>>1;
pushdown(o);
int res=0;
if(l<=mid) res=((ll)res+calSum(t[o].lc,l,r))%P;
if(r>mid) res=((ll)res+calSum(t[o].rc,l,r))%P;
return res;
}
void chainDet(int x,int y,int v)
{
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]y]]) std::swap(x,y);
chDet(rt,w[tp[x]],w[x],v);
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
chDet(rt,w[x],w[y],v);
}
int chainSum(int x,int y)
{
int res=0;
while(tp[x]!=tp[y])
{
if(dpt[tp[x]]y]]) std::swap(x,y);
res=((ll)res+calSum(rt,w[tp[x]],w[x]))%P;
x=fa[tp[x]];
}
if(w[x]>w[y]) std::swap(x,y);
res=((ll)res+calSum(rt,w[x],w[y]))%P;
return res;
}
struct PT{int id,p;bool isR;} p[MAXN<<1];int pcnt;
bool cmp(const PT &a,const PT &b){return a.pint high,low,z;} ans[MAXN];
int main()
{
int i;
scanf("%d%d",&N,&M);
int u,v;
for(i=2;i<=N;i++) scanf("%d",&v),addEdge2(i,v+1);
dpt[1]=1;dfs1(1);
tp[1]=1,w[1]=++wcnt;dfs2(1);
buildTree(rt,1,N);
for(i=1;i<=M;i++)
{
int l,r;scanf("%d%d%d",&l,&r,&ans[i].z);l++,r++,ans[i].z++;
p[++pcnt].id=i,p[pcnt].p=l-1,p[pcnt].isR=false;
p[++pcnt].id=i,p[pcnt].p=r,p[pcnt].isR=true;
}
std::sort(p+1,p+pcnt+1,cmp);
int now=0;
for(i=1;i<=pcnt;i++)
{
while(now1);
if(p[i].isR) ans[p[i].id].high=chainSum(rt,ans[p[i].id].z);
else ans[p[i].id].low=chainSum(rt,ans[p[i].id].z);
}
for(i=1;i<=M;i++) printf("%d\n",(ans[i].high-ans[i].low+P)%P);
return 0;
}