【bzoj1036】[ZJOI2008]树的统计Count 树链剖分+线段树

Description

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 III. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。 对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

Sample Input

4

1 2

2 3

4 1

4 2 1 3

12

QMAX 3 4

QMAX 3 3

QMAX 3 2

QMAX 2 3

QSUM 3 4

QSUM 2 1

CHANGE 1 5

QMAX 3 4

CHANGE 3 6

QMAX 3 4

QMAX 2 4

QSUM 3 4

Sample Output

4

1

2

2

10

6

5

6

5

16

HINT

Source

树的分治

树链剖分裸题。noip前作死学新东西,简直233

我们需要不知道多少个数组:

deep[u]:u点深度
sz[u]u的子树大小
son[u]:u的重儿子(重儿子定义:儿子vsz[v]最大的儿子)
fa[u]:u的爹
inseg[u]:树中的点u在线段树中的标号。
intr[u]:线段树的点u在树中的标号。
top[u]u所在重链的顶端节点。

树中非叶子节点向自己的重儿子连一条边,轻儿子分别以自己为重链的顶端,分别下拉重链。这样就把书分成好多链。

然后dfs给各个链上的点编号,这样能保证树中同一条链上的点编号连续,一个点和他父亲编号连续,这样就可以映射到线段树(或其他东西…)里区间维护。

贴代码:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int size=1000010;
const int INF=233333333;
int head[size],nxt[size],to[size],tot=0;

void build(int f,int t)
{
    to[++tot]=t;
    nxt[tot]=head[f];
    head[f]=tot;
}

int deep[size],sz[size],son[size],fa[size];
int inseg[size],intr[size],top[size];

int num[size];
void dfs_1(int u,int f)
{
    fa[u]=f;
    deep[u]=deep[f]+1;
    sz[u]=1;
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==f) continue;
        dfs_1(v,u);
        sz[u]+=sz[v];
        if(!son[u]||sz[v]>sz[son[u]]) son[u]=v;
    }
}

int totp=0;
void dfs_2(int u,int topu)
{
    top[u]=topu;
    inseg[u]=++totp;
    intr[totp]=u;
    if(!son[u]) return ;
    dfs_2(son[u],topu);
    for(int i=head[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==son[u]||v==fa[u]) continue;
        dfs_2(v,v);
    }
}

struct segment{
    int l,r;
    int sum,maxx;
}tree[size];

void update(int p)
{
    tree[p].sum=tree[p<<1].sum+tree[p<<1|1].sum;
    tree[p].maxx=max(tree[p<<1].maxx,tree[p<<1|1].maxx);
}
void build(int p,int l,int r)
{
    tree[p].l=l;    tree[p].r=r;
    if(l==r)
    {
        tree[p].sum=tree[p].maxx=num[intr[l]];
        return ;
    }
    int mid=(l+r)>>1;
    build(p<<1,l,mid); 
    build(p<<1|1,mid+1,r);
    update(p); 
}

void change(int p,int x,int d)
{
    if(tree[p].l==tree[p].r)
    {
        tree[p].sum=tree[p].maxx=d;
        return ;
    }
    int mid=(tree[p].l+tree[p].r)>>1;
    if(x<=mid) change(p<<1,x,d); 
    else change(p<<1|1,x,d);
    update(p); 
}

int ask_sum(int p,int l,int r)
{
    if(l<=tree[p].l&&tree[p].r<=r)
    {
        return tree[p].sum;
    }
    int mid=(tree[p].l+tree[p].r)>>1;
    int ans=0;
    if(l<=mid) ans+=ask_sum(p<<1,l,r); 
    if(mid<r) ans+=ask_sum(p<<1|1,l,r); 
    return ans; 
}
int ask_max(int p,int l,int r)
{
    if(l<=tree[p].l&&tree[p].r<=r)
    {
        return tree[p].maxx;
    }
    int mid=(tree[p].l+tree[p].r)>>1;
    int ans=-INF;
    if(l<=mid) ans=max(ans,ask_max(p<<1,l,r)); 
    if(mid<r) ans=max(ans,ask_max(p<<1|1,l,r)); 
    return ans; 
}


int find_sum(int x,int y)
{
    int fx=top[x],fy=top[y];
    int ans=0;
    while(fx!=fy)
    {
        if(deep[fx]<deep[fy]) swap(x,y),swap(fx,fy);
        ans+=ask_sum(1,inseg[fx],inseg[x]);
        x=fa[fx]; fx=top[x];
    }
    if(deep[x]>deep[y]) swap(x,y);
    ans+=ask_sum(1,inseg[x],inseg[y]);
    return ans;
}

int find_max(int x,int y)
{
    int fx=top[x],fy=top[y];
    int ans=-INF;
    while(fx!=fy)
    {
        if(deep[fx]<deep[fy]) swap(x,y),swap(fx,fy);
        ans=max(ans,ask_max(1,inseg[fx],inseg[x]));
        x=fa[fx]; fx=top[x];
    }
    if(deep[x]>deep[y]) swap(x,y);
    ans=max(ans,ask_max(1,inseg[x],inseg[y]));
    return ans;
}

char s[233];
int main()
{

    int n;
    scanf("%d",&n);
    for(int i=1;i<=n-1;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        build(a,b);
        build(b,a);
    }
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&num[i]);
    }
    dfs_1(1,0); dfs_2(1,1);

    build(1,1,n);

    int q;
    scanf("%d",&q);
    while(q--)
    {
        int x,y;
        scanf("%s%d%d",s,&x,&y);
        if(s[0]=='C') change(1,inseg[x],y),num[x]=y;
        else if(s[1]=='M') printf("%d\n",find_max(x,y));
        else printf("%d\n",find_sum(x,y));
    }
    return 0;
}


你可能感兴趣的:(ZJOI)