[省选联考 2020 A 卷] 树

题目描述
给定一棵 n n n 个结点的有根树 T T T,结点从 1 1 1 开始编号,根结点为 1 1 1 号结点,每个结点有一个正整数权值 v i v_i vi

x x x 号结点的子树内(包含 x x x 自身)的所有结点编号为 c 1 , c 2 , … , c k c_1,c_2,\dots,c_k c1,c2,,ck,定义 x x x 的价值为:

v a l ( x ) = ( v c 1 + d ( c 1 , x ) ) ⊕ ( v c 2 + d ( c 2 , x ) ) ⊕ ⋯ ⊕ ( v c k + d ( c k , x ) ) val(x)=(v_{c_1}+d(c_1,x)) \oplus (v_{c_2}+d(c_2,x)) \oplus \cdots \oplus (v_{c_k}+d(c_k, x)) val(x)=(vc1+d(c1,x))(vc2+d(c2,x))(vck+d(ck,x))

其中 d ( x , y ) d(x,y) d(x,y) 表示树上 x x x 号结点与 y y y 号结点间唯一简单路径所包含的边数, d ( x , x ) = 0 d(x,x) = 0 d(x,x)=0 ⊕ \oplus 表示异或运算。

请你求出 ∑ i = 1 n v a l ( i ) \sum\limits_{i=1}^n val(i) i=1nval(i) 的结果。

输入格式
第一行一个正整数 n n n 表示树的大小。

第二行 n n n 个正整数表示 v i v_i vi

接下来一行 n − 1 n-1 n1 个正整数,依次表示 2 2 2 号结点到 n n n 号结点,每个结点的父亲编号 p i p_i pi

输出格式
仅一行一个整数表示答案。

输入输出样例

输入 #1
5
5 4 1 2 3
1 1 2 2
输出 #1
12

说明/提示
【样例解释 11】

v a l ( 1 ) = ( 5 + 0 ) ⊕ ( 4 + 1 ) ⊕ ( 1 + 1 ) ⊕ ( 2 + 2 ) ⊕ ( 3 + 2 ) = 3 val(1)=(5+0)\oplus(4+1)\oplus(1+1)\oplus(2+2)\oplus(3+2)=3 val(1)=(5+0)(4+1)(1+1)(2+2)(3+2)=3

v a l ( 2 ) = ( 4 + 0 ) ⊕ ( 2 + 1 ) ⊕ ( 3 + 1 ) = 3 val(2)=(4+0)\oplus(2+1)\oplus(3+1) = 3 val(2)=(4+0)(2+1)(3+1)=3

v a l ( 3 ) = ( 1 + 0 ) = 1 val(3)=(1+0)=1 val(3)=(1+0)=1

v a l ( 4 ) = ( 2 + 0 ) = 2 val(4)=(2+0)=2 val(4)=(2+0)=2

v a l ( 5 ) = ( 3 + 0 ) = 3 val(5)=(3+0)=3 val(5)=(3+0)=3

和为 12 12 12

10 % 10\% 10% 的数据: 1 ≤ n ≤ 2501 1\leq n\leq 2501 1n2501

40 % 40\% 40% 的数据: 1 ≤ n ≤ 152501 1\leq n\leq 152501 1n152501

另有 20 % 20\% 20% 的数据:所有 p i = i − 1   ( 2 ≤ i ≤ n ) p_i=i-1 \ (2\leq i\leq n) pi=i1 (2in)

另有 20 % 20\% 20% 的数据:所有 v i = 1   ( 1 ≤ i ≤ n ) v_i=1 \ (1\leq i\leq n) vi=1 (1in)

100 % 100\% 100% 的数据: 1 ≤ n , v i ≤ 525010 , 1 ≤ p i ≤ n 1\leq n,v_i \leq 525010,1\leq p_i\leq n 1n,vi525010,1pin

对每一个节点建维护异或和的 t r i e trie trie,然后从叶子到根合并 t r i e trie trie,当节点 x x x的所有子节点的 t r i e trie trie合并到 x x x后,将 x x x t r i e trie trie所有值加 1 1 1,然后将 v [ x ] v[x] v[x]插入节点 x x x t r i e trie trie

#include

#define si(a) scanf("%d",&a)
#define sl(a) scanf("%lld",&a)
#define sd(a) scanf("%lf",&a)
#define sc(a) scahf("%c",&a);
#define ss(a) scanf("%s",a)
#define pi(a) printf("%d\n",a)
#define pl(a) printf("%lld\n",a)
#define pc(a) putchar(a)
#define ms(a) memset(a,0,sizeof(a))
#define repi(i, a, b) for(register int i=a;i<=b;++i)
#define repd(i, a, b) for(register int i=a;i>=b;--i)
#define reps(s) for(register int i=head[s];i;i=Next[i])
#define ll long long
#define ull unsigned long long
#define vi vector
#define pii pair
#define mii unordered_map
#define msi unordered_map
#define lowbit(x) ((x)&(-(x)))
#define ce(i, r) i==r?'\n':' '
#define pb push_back
#define fi first
#define se second
#define all(x) x.begin(),x.end()
#define INF 0x3f3f3f3f
#define pr(x) cout<<#x<<": "<<x<<endl
using namespace std;

inline int qr() {
    int f = 0, fu = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-')fu = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        f = (f << 3) + (f << 1) + c - 48;
        c = getchar();
    }
    return f * fu;
}

const int N = 6e5 + 10;

struct Trie {
    const int H = 21;
    int ch[N * 25][2], w[N * 25], xorv[N * 25], rt[N];
    int tot = 0;

    inline int build() {
        ++tot;
        ch[tot][0] = ch[tot][1] = w[tot] = xorv[tot] = 0;
        return tot;
    }

    inline void update(int p) {
        w[p] = xorv[p] = 0;
        if (ch[p][0]) {
            w[p] += w[ch[p][0]];
            xorv[p] ^= xorv[ch[p][0]] << 1;
        }
        if (ch[p][1]) {
            w[p] += w[ch[p][1]];
            xorv[p] ^= (xorv[ch[p][1]] << 1) | w[ch[p][1]];
        }
        w[p] &= 1;
    }

    void insert(int &p, int d, int x) {
        if (!p)p = build();
        if (d > H) {
            w[p]++;
            return;
        }
        insert(ch[p][x & 1], d + 1, x >> 1);
        update(p);
    }

    void erase(int p, int d, int x) {
        if (d > H) {
            w[p]--;
            return;
        }
        erase(ch[p][x & 1], d + 1, x >> 1);
        update(p);
    }

    void addall(int p) {
        swap(ch[p][0], ch[p][1]);
        if (ch[p][0])addall(ch[p][0]);
        update(p);
    }

    int merge(int p, int q) {
        if (!p)return q;
        if (!q)return p;
        w[p] = w[p] + w[q] & 1, xorv[p] ^= xorv[q];
        ch[p][0] = merge(ch[p][0], ch[q][0]);
        ch[p][1] = merge(ch[p][1], ch[q][1]);
        return p;
    }

} tr;

int head[N], ver[N << 1], Next[N << 1], tot;
int n, v[N];
ll ans;

inline void add(int x, int y) {
    ver[++tot] = y;
    Next[tot] = head[x];
    head[x] = tot;
}

void dfs(int x, int f) {
    reps(x) {
        int y = ver[i];
        if (y == f)continue;
        dfs(y, x);
        tr.rt[x] = tr.merge(tr.rt[x], tr.rt[y]);

    }
    tr.addall(tr.rt[x]);
    tr.insert(tr.rt[x], 0, v[x]);
    ans += tr.xorv[tr.rt[x]];
}

int main() {
    n = qr();
    repi(i, 1, n)v[i] = qr();
    repi(i, 2, n) {
        int f = qr();
        add(i, f), add(f, i);
    }
    dfs(1, 0);
    pl(ans);
    return 0;
}

你可能感兴趣的:(trie树,ACM)