题目描述
给定一棵 n n n 个结点的有根树 T T T,结点从 1 1 1 开始编号,根结点为 1 1 1 号结点,每个结点有一个正整数权值 v i v_i vi。
设 x x x 号结点的子树内(包含 x x x 自身)的所有结点编号为 c 1 , c 2 , … , c k c_1,c_2,\dots,c_k c1,c2,…,ck,定义 x x x 的价值为:
v a l ( x ) = ( v c 1 + d ( c 1 , x ) ) ⊕ ( v c 2 + d ( c 2 , x ) ) ⊕ ⋯ ⊕ ( v c k + d ( c k , x ) ) val(x)=(v_{c_1}+d(c_1,x)) \oplus (v_{c_2}+d(c_2,x)) \oplus \cdots \oplus (v_{c_k}+d(c_k, x)) val(x)=(vc1+d(c1,x))⊕(vc2+d(c2,x))⊕⋯⊕(vck+d(ck,x))
其中 d ( x , y ) d(x,y) d(x,y) 表示树上 x x x 号结点与 y y y 号结点间唯一简单路径所包含的边数, d ( x , x ) = 0 d(x,x) = 0 d(x,x)=0。 ⊕ \oplus ⊕ 表示异或运算。
请你求出 ∑ i = 1 n v a l ( i ) \sum\limits_{i=1}^n val(i) i=1∑nval(i) 的结果。
输入格式
第一行一个正整数 n n n 表示树的大小。
第二行 n n n 个正整数表示 v i v_i vi。
接下来一行 n − 1 n-1 n−1 个正整数,依次表示 2 2 2 号结点到 n n n 号结点,每个结点的父亲编号 p i p_i pi。
输出格式
仅一行一个整数表示答案。
输入输出样例
输入 #1
5
5 4 1 2 3
1 1 2 2
输出 #1
12
说明/提示
【样例解释 11】
v a l ( 1 ) = ( 5 + 0 ) ⊕ ( 4 + 1 ) ⊕ ( 1 + 1 ) ⊕ ( 2 + 2 ) ⊕ ( 3 + 2 ) = 3 val(1)=(5+0)\oplus(4+1)\oplus(1+1)\oplus(2+2)\oplus(3+2)=3 val(1)=(5+0)⊕(4+1)⊕(1+1)⊕(2+2)⊕(3+2)=3。
v a l ( 2 ) = ( 4 + 0 ) ⊕ ( 2 + 1 ) ⊕ ( 3 + 1 ) = 3 val(2)=(4+0)\oplus(2+1)\oplus(3+1) = 3 val(2)=(4+0)⊕(2+1)⊕(3+1)=3。
v a l ( 3 ) = ( 1 + 0 ) = 1 val(3)=(1+0)=1 val(3)=(1+0)=1。
v a l ( 4 ) = ( 2 + 0 ) = 2 val(4)=(2+0)=2 val(4)=(2+0)=2。
v a l ( 5 ) = ( 3 + 0 ) = 3 val(5)=(3+0)=3 val(5)=(3+0)=3。
和为 12 12 12。
10 % 10\% 10% 的数据: 1 ≤ n ≤ 2501 1\leq n\leq 2501 1≤n≤2501。
40 % 40\% 40% 的数据: 1 ≤ n ≤ 152501 1\leq n\leq 152501 1≤n≤152501。
另有 20 % 20\% 20% 的数据:所有 p i = i − 1 ( 2 ≤ i ≤ n ) p_i=i-1 \ (2\leq i\leq n) pi=i−1 (2≤i≤n)。
另有 20 % 20\% 20% 的数据:所有 v i = 1 ( 1 ≤ i ≤ n ) v_i=1 \ (1\leq i\leq n) vi=1 (1≤i≤n)。
100 % 100\% 100% 的数据: 1 ≤ n , v i ≤ 525010 , 1 ≤ p i ≤ n 1\leq n,v_i \leq 525010,1\leq p_i\leq n 1≤n,vi≤525010,1≤pi≤n。
对每一个节点建维护异或和的 t r i e trie trie,然后从叶子到根合并 t r i e trie trie,当节点 x x x的所有子节点的 t r i e trie trie合并到 x x x后,将 x x x的 t r i e trie trie所有值加 1 1 1,然后将 v [ x ] v[x] v[x]插入节点 x x x的 t r i e trie trie。
#include
#define si(a) scanf("%d",&a)
#define sl(a) scanf("%lld",&a)
#define sd(a) scanf("%lf",&a)
#define sc(a) scahf("%c",&a);
#define ss(a) scanf("%s",a)
#define pi(a) printf("%d\n",a)
#define pl(a) printf("%lld\n",a)
#define pc(a) putchar(a)
#define ms(a) memset(a,0,sizeof(a))
#define repi(i, a, b) for(register int i=a;i<=b;++i)
#define repd(i, a, b) for(register int i=a;i>=b;--i)
#define reps(s) for(register int i=head[s];i;i=Next[i])
#define ll long long
#define ull unsigned long long
#define vi vector
#define pii pair
#define mii unordered_map
#define msi unordered_map
#define lowbit(x) ((x)&(-(x)))
#define ce(i, r) i==r?'\n':' '
#define pb push_back
#define fi first
#define se second
#define all(x) x.begin(),x.end()
#define INF 0x3f3f3f3f
#define pr(x) cout<<#x<<": "<<x<<endl
using namespace std;
inline int qr() {
int f = 0, fu = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-')fu = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
f = (f << 3) + (f << 1) + c - 48;
c = getchar();
}
return f * fu;
}
const int N = 6e5 + 10;
struct Trie {
const int H = 21;
int ch[N * 25][2], w[N * 25], xorv[N * 25], rt[N];
int tot = 0;
inline int build() {
++tot;
ch[tot][0] = ch[tot][1] = w[tot] = xorv[tot] = 0;
return tot;
}
inline void update(int p) {
w[p] = xorv[p] = 0;
if (ch[p][0]) {
w[p] += w[ch[p][0]];
xorv[p] ^= xorv[ch[p][0]] << 1;
}
if (ch[p][1]) {
w[p] += w[ch[p][1]];
xorv[p] ^= (xorv[ch[p][1]] << 1) | w[ch[p][1]];
}
w[p] &= 1;
}
void insert(int &p, int d, int x) {
if (!p)p = build();
if (d > H) {
w[p]++;
return;
}
insert(ch[p][x & 1], d + 1, x >> 1);
update(p);
}
void erase(int p, int d, int x) {
if (d > H) {
w[p]--;
return;
}
erase(ch[p][x & 1], d + 1, x >> 1);
update(p);
}
void addall(int p) {
swap(ch[p][0], ch[p][1]);
if (ch[p][0])addall(ch[p][0]);
update(p);
}
int merge(int p, int q) {
if (!p)return q;
if (!q)return p;
w[p] = w[p] + w[q] & 1, xorv[p] ^= xorv[q];
ch[p][0] = merge(ch[p][0], ch[q][0]);
ch[p][1] = merge(ch[p][1], ch[q][1]);
return p;
}
} tr;
int head[N], ver[N << 1], Next[N << 1], tot;
int n, v[N];
ll ans;
inline void add(int x, int y) {
ver[++tot] = y;
Next[tot] = head[x];
head[x] = tot;
}
void dfs(int x, int f) {
reps(x) {
int y = ver[i];
if (y == f)continue;
dfs(y, x);
tr.rt[x] = tr.merge(tr.rt[x], tr.rt[y]);
}
tr.addall(tr.rt[x]);
tr.insert(tr.rt[x], 0, v[x]);
ans += tr.xorv[tr.rt[x]];
}
int main() {
n = qr();
repi(i, 1, n)v[i] = qr();
repi(i, 2, n) {
int f = qr();
add(i, f), add(f, i);
}
dfs(1, 0);
pl(ans);
return 0;
}