1036: [ZJOI2008]树的统计Count树链剖分

  虽然是一道水题，但在我的不懈努力下还是交了八十几发；后来有一天突然重写了一遍，竟然一次就 AC（1A）了。

#include<iostream>

#include<cstdio>

#include<cstring>

#include<map>

#include<vector>

#include<stdlib.h>

using namespace std;



// Capacity of every per-node array; node ids are used directly as indices,
// so they are assumed to be < maxn (presumably 1..n) — TODO confirm input bounds.
const int maxn = 33333;

// Edge pool for the adjacency lists: e[i].to is the edge's target and
// e[i].next is the index of the next edge leaving the same source
// (-1 terminates a list). main() stores each undirected edge twice
// (a->b and b->a), hence the maxn*2 capacity.
struct Node

{

    int next; int to;

}e[maxn*2];

// len: number of edges stored in e so far (index of the next free slot).
// z:   DFS position counter; after dfs() it equals the number of visited
//      nodes and is used as the upper bound of the segment-tree range.
int len, z;

// Argument triples (l, r, rt) for recursing into the left/right child of the
// segment-tree node rt that covers [l, r]. They expand to the names `l`, `r`,
// `mid` and `rt`, so they only work where locals with those names exist.
#define lson l,mid,rt<<1

#define rson mid+1,r,rt<<1|1



// Arrays indexed by node id:
//   head   - index of the first edge in e leaving the node, -1 if none
//   top    - head node of the heavy chain containing the node
//   son    - heavy child (child with the largest subtree), 0 for a leaf
//   father - parent node (never assigned for the root, so it stays 0)
//   size   - number of nodes in the node's subtree
//   deep   - depth, with the root at depth 0
//   pos    - 1-based position in the heavy-first DFS order (segment-tree leaf)
//   val    - the node's initial value as read from input
// vis is indexed by DFS position instead and maps a position back to its node.
// NOTE(review): with `using namespace std;` in C++17, an unqualified use of
// the global `size` can be ambiguous with std::size; qualify it as ::size.
int head[maxn], top[maxn], son[maxn], father[maxn], size[maxn], deep[maxn], pos[maxn], val[maxn], vis[maxn];

// Segment tree over DFS positions, stored heap-style (children of rt are
// rt*2 and rt*2+1) with 4x capacity: per-node range sum and range maximum.
int sum[maxn<<2],Max[maxn<<2];



void add(int from, int to)

{

    e[len].to = to;

    e[len].next = head[from];

    head[from] = len++;

}



void init(int x)

{

    size[x] = 1; son[x] = 0;

    for (int i = head[x]; ~i; i = e[i].next){

        int cc = e[i].to;

        if (cc == father[x]) continue;

        deep[cc] = deep[x] + 1; father[cc] = x;

        init(cc);

        size[x] += size[cc];

        if (size[son[x]] < size[cc]) son[x] = cc;

    }

}



// Second heavy-light pass: number nodes in DFS order, visiting the heavy
// child first so that each heavy chain occupies a contiguous range of
// positions. Records pos[x] (x's position), vis[position] = x, and
// top[x] = tp (the head of x's chain).
void dfs(int x, int tp)
{
    top[x] = tp;
    z += 1;
    pos[x] = z;
    vis[z] = x;
    // The heavy child extends the current chain.
    if (son[x] != 0) dfs(son[x], tp);
    // Each light child starts a new chain headed by itself.
    for (int edge = head[x]; ~edge; edge = e[edge].next){
        const int child = e[edge].to;
        if (child != father[x] && child != son[x]) dfs(child, child);
    }
}



// Recompute segment-tree node rt's sum and max from its children rt*2 and rt*2+1.
void up(int rt)
{
    const int lc = rt << 1;
    const int rc = lc | 1;
    sum[rt] = sum[lc] + sum[rc];
    Max[rt] = Max[lc] < Max[rc] ? Max[rc] : Max[lc];
}

// Build the segment tree for DFS positions [l, r] rooted at node rt.
// The leaf for position p holds the value of node vis[p].
void build(int l, int r, int rt)
{
    if (l != r){
        int mid = (l + r) >> 1;   // name required by the lson/rson macros
        build(lson);
        build(rson);
        up(rt);
        return;
    }
    Max[rt] = sum[rt] = val[vis[l]];
}



// Point assignment: set DFS position `key` to `ans` inside the subtree of
// segment-tree node rt (covering [l, r]), then refresh the aggregates of
// every node from the leaf's parent up to and including rt.
void update(int key, int ans, int l, int r, int rt)
{
    // Descend to the leaf that covers `key`.
    int node = rt;
    while (l != r){
        const int half = (l + r) >> 1;
        if (key <= half){
            r = half;
            node = node << 1;
        }
        else{
            l = half + 1;
            node = node << 1 | 1;
        }
    }
    Max[node] = sum[node] = ans;
    // Walk back up the parent chain, recomputing bottom-up until rt is done.
    while (node != rt){
        node >>= 1;
        up(node);
    }
}



// Maximum over DFS positions [L, R] within node rt's range [l, r].
// Only children that overlap [L, R] are visited; -1000000000 is the
// neutral starting value.
int ask_max(int L, int R, int l, int r, int rt)
{
    if (l >= L && R >= r) return Max[rt];   // fully covered
    int mid = (l + r) >> 1;
    int best = -1000000000;
    if (L <= mid){
        const int left = ask_max(L, R, lson);
        if (left > best) best = left;
    }
    if (mid < R){
        const int right = ask_max(L, R, rson);
        if (right > best) best = right;
    }
    return best;
}



// Sum over DFS positions [L, R] within node rt's range [l, r]; a child that
// does not overlap [L, R] contributes 0.
int ask_sum(int L, int R, int l, int r, int rt)
{
    if (l >= L && R >= r) return sum[rt];   // fully covered
    int mid = (l + r) >> 1;
    const int left = (L <= mid) ? ask_sum(L, R, lson) : 0;
    const int right = (mid < R) ? ask_sum(L, R, rson) : 0;
    return left + right;
}



// Print the maximum node value on the tree path between x and y (both
// inclusive), using the heavy-light decomposition from init()/dfs().
void gao_max(int x, int y)
{
    int best = -1000000000;
    while (top[x] != top[y]){
        // Lift the endpoint whose chain head is deeper.
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        const int chainHead = top[x];
        best = max(best, ask_max(pos[chainHead], pos[x], 1, z, 1));
        x = father[chainHead];
    }
    // Same chain now; the shallower endpoint has the smaller DFS position.
    if (deep[x] > deep[y]) swap(x, y);
    best = max(best, ask_max(pos[x], pos[y], 1, z, 1));
    printf("%d\n", best);
}



// Print the sum of node values on the tree path between x and y (both
// inclusive), using the heavy-light decomposition from init()/dfs().
void gao_sum(int x, int y)
{
    int total = 0;
    for (; top[x] != top[y]; x = father[top[x]]){
        // Make x the endpoint whose chain head is deeper, then take its chain segment.
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        total += ask_sum(pos[top[x]], pos[x], 1, z, 1);
    }
    // Same chain now; the shallower endpoint has the smaller DFS position.
    if (deep[x] > deep[y]) swap(x, y);
    total += ask_sum(pos[x], pos[y], 1, z, 1);
    printf("%d\n", total);
}





int main()

{

    int n;

    int a, b, c;

    int q;

    char str[1000];

    while (cin >> n){

        len = z = 0;

        memset(head, -1, sizeof(head));

        for (int i = 0; i < n - 1; i++){

            cin >> a >> b; add(a, b); add(b, a);

        }

        for (int i = 1; i <= n; i++){

            scanf("%d", &val[i]);

        }

        init(1); dfs(1, 1); build(1, z, 1);

        cin >> q;

        while (q--){

            scanf("%s%d%d", str, &a, &b);

            if (strcmp(str, "QMAX") == 0){

                gao_max(a, b);

            }

            else if (strcmp(str, "QSUM") == 0){

                gao_sum(a, b);

            }

            else update(pos[a], b, 1, z, 1);

        }

    }

    return 0;

}

 

你可能感兴趣的:(count)