bzoj1036-树链剖分模板


   剖分后的树有如下性质:
    性质1:如果(v,u)为轻边,则siz[u] * 2< siz[v];
    性质2:从根到某一点的路径上轻链、重链的个数都不大于logn。


总之两个dfs

dfs1(int x,int f)   //f是x的父亲

枚举和x相邻的点的时候注意不等于f 才可以递归

  要维护的东西:

    dep[x] x节点深度

    siz[x] 以x节点为根的子树的大小

    son[x] 以x节点为根的子树中重儿子的编号

    fa[x] x节点的父亲

dfs2(int x,int tp)

   为x节点编号(线段树中的下标),并且求出每个点的top[x]

   这个时候优先向x的重儿子也就是son[x]向下递归,之后再对x的其他儿子递归,但是tp的值变成了其他儿子本身,不再是top[x]

建线段树

    注:如果题目中给的是边的权值,就变成深度较深的孩子的点权(之前对于一条边要看是否需要根据深度交换两个端点)

    和普通的线段树的建树,修改,询问都一样

最关键的部分:

     如何在不同的链之间移动?

     记f1 =top[u],f2 = top[v]。
     当f1<> f2时:不妨设dep[f1] >=dep[f2],那么就更新u到f1的父边的权值(logn),并使u = fa[f1]。
     当f1 =f2时:u与v在同一条重链上,若u与v不是同一点,就更新u到v路径上的边的权值(logn),否则修改完成;
     重复上述过程,直到修改完成。

inline int find(int va, int vb)
{
     int f1 = top[va], f2 = top[vb], tmp = 0;
     while (f1 != f2)
     {
           if (dep[f1] < dep[f2])
           { swap(f1, f2); swap(va, vb); }
           tmp = max(tmp, maxi(1, 1, z, w[f1], w[va]));
           va = fa[f1]; f1 = top[va];
     }
     if (va == vb) return tmp;
     if (dep[va] > dep[vb]) swap(va, vb);
     return max(tmp, maxi(1, 1, z, w[son[va]], w[vb]));  // maxi 是线段树的query
}


一开始没有看到负数,不开心

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
using namespace std;
#define MAXN 30010
int head[MAXN];
struct node{
	int y,next;
}edge[2*MAXN];
int id[MAXN],son[MAXN],dep[MAXN],fa[MAXN],siz[MAXN],a[MAXN];
int l,x,y,n,q,tot;
char s[10];
int pre[MAXN],top[MAXN];
void add(int x,int y)
{
	l++;
	edge[l].y=y;
	edge[l].next=head[x];
	head[x]=l;
}
void dfs1(int x,int f)
{
	int y;
	fa[x]=f;
	son[x]=0;//the heaviest son
	siz[x]=1;
	for (int i=head[x];i!=-1;i=edge[i].next)
		if (edge[i].y!=f)
		{
			y=edge[i].y;
			dep[y]=dep[x]+1;
			dfs1(y,x);
			siz[x]+=siz[y];
			if (siz[son[x]]<siz[y])
				son[x]=y;
		}
}
void dfs2(int x,int tp)
{
	int y;
	top[x]=tp;
	id[x]=++tot;// tot is the total number
	pre[id[x]]=x;
	if (son[x]) dfs2(son[x],tp);
	for (int i=head[x];i!=-1;i=edge[i].next)
	 if (edge[i].y!=fa[x] && edge[i].y!=son[x])
	 {
		 y=edge[i].y;
		 dfs2(y,y);
	 }
}
struct point{
	int l,r,sum,max;
}tr[4*MAXN];
void updata(int p)
{
	tr[p].sum=tr[p<<1].sum+tr[p<<1|1].sum;
	tr[p].max=max(tr[p<<1].max,tr[p<<1|1].max);
}
void build(int p,int l,int r)
{
	tr[p].l=l;tr[p].r=r;
	if (l==r) {tr[p].sum=tr[p].max=a[pre[l]]; return ;}// find the backforward l in a[] to build the tree
	int mid=(l+r) >> 1;
	build(p<<1, l, mid);build(p<<1|1,mid+1,r);
	updata(p);
}
void change(int p,int x,int y)
{
	if (tr[p].l==x && tr[p].r==x)
	{
		tr[p].sum=tr[p].max=y;
		return ;
	}
	int mid=(tr[p].l+tr[p].r) >> 1;
	if (x<=mid) change(p<<1, x, y);
	if (x>mid) change(p<<1|1, x, y);
    updata(p);
}
int ask_max(int p,int l,int r)
{
	if (tr[p].l==l && tr[p].r==r)
	    return tr[p].max;
	int mid=(tr[p].l+tr[p].r) >> 1;
	if (r<=mid) return ask_max(p<<1, l, r);
	if (l>mid) return ask_max(p<<1|1, l, r);
	if (l<=mid && r>mid)
	{
		int s1=ask_max(p<<1, l, mid);
		int s2=ask_max(p<<1|1, mid+1, r);
		return max(s1,s2);
	}
}
int ask_sum(int p,int l,int r)
{
	if (tr[p].l==l && tr[p].r==r)
		return tr[p].sum;
	int mid=(tr[p].l+tr[p].r) >> 1;
	if (r<=mid) return ask_sum(p<<1, l, r);
	if (l>mid) return ask_sum(p<<1|1, l, r);
	if (l<=mid && r>mid)
	{
		int s1=ask_sum(p<<1, l, mid);
		int s2=ask_sum(p<<1|1, mid+1, r);
		return s1+s2;
	}
}
int find_max(int x, int y)
{
	int f1=top[x],f2=top[y],tmp=-0x3f3f3f;//有负数
	while (f1!=f2)
	{
		if (dep[f1]<dep[f2])
		{swap(f1,f2);swap(x,y);}
		tmp=max(tmp, ask_max(1,id[f1],id[x]));
		x=fa[f1];f1=top[x];
	}
	if (x==y) return max(tmp,ask_max(1,id[x],id[x]));
	if (dep[x]>dep[y]) swap(x,y);
	return max(tmp, ask_max(1, id[x], id[y]));
}
int find_sum(int x,int y)
{
	int f1=top[x], f2=top[y], tmp=0;
	while (f1!=f2)
	{
		if (dep[f1]<dep[f2])
		{ swap(f1,f2);swap(x,y);}
		tmp+=ask_sum(1,id[f1],id[x]);
		x=fa[f1];f1=top[x];
	}
	if (x==y) return tmp+ask_sum(1,id[x],id[y]);
	if (dep[x]>dep[y]) swap(x,y);
	return tmp+ask_sum(1,id[x], id[y]);
}
int main()
{
	scanf("%d", &n);
	memset(head,-1,sizeof(head));
	for (int i=1;i<n;i++)
	{
		scanf("%d%d", &x, &y);
		add(x,y);
		add(y,x);
	}
	for (int i=1;i<=n;i++) scanf("%d", &a[i]);
	dfs1(1,0);
	dfs2(1,1);
	build(1,1,n);
	scanf("%d", &q);
	while (q--)
	{
		scanf("%s%d%d", s, &x, &y);
		if (s[0]=='C')  change(1,id[x],y); // use id[x] !!
        if (s[0]=='Q' && s[1]=='M') printf("%d\n", find_max(x,y));//x and y are two points
		if (s[0]=='Q' && s[1]=='S') printf("%d\n", find_sum(x,y));//if there is X types of queries, we need to write X parts for the queries
	}
}


你可能感兴趣的:(bzoj1036-树链剖分模板)