题目大意
一棵以 \(1\) 为根的 \(n(2\leq n\leq 10^5)\) 的树,每个节点 \(i\) 有权值 \(a_{i}(1\leq a_{i}\leq 10^6)\) ,求 \(\sum_{i=1}^{n}\sum_{j=i+1}^{n}[a_{i}\oplus a_{j}=a_{lca(i,j)}](i\oplus j)\) 。
思路
考虑 \(dsu\space on\space tree\) ,因为 \(a_{i}>0\) ,所以能够产生贡献的节点 \((i,j)\) 一定分属 \(lca(i,j)\) 两侧,于是计算各个子树的贡献时,考虑到对于每个节点 \(x\) ,对其中一棵子树中的节点 \(i\) , 其他子树中的每一个 \(a_{j} = a_{i}\oplus a_{x}\) 的节点 \(j\) 就会对答案产生 \(i\oplus j\) 的贡献。所以我们可以用一个数组 \(f[a,b,c]\) 来记录当权值为 \(a\) 时,该权值的节点编号在二进制中第 \(b\) 为的数字为 \(c\) 的节点个数,然后我们就可以对 \(i\) 来按位枚举有多少 \(j\) 能够对答案在这一位上产生贡献来计算答案。我们先计算所有轻子树内的答案,然后去掉轻子树对 \(f\) 的贡献,之后计算重子数的答案,之后保留其对 \(f\) 的贡献,再遍历所有子树,计算 \(f\) 以及跨子树的节点的答案,最后全部加起来即可。复杂度 \(O(nlogn)\) 。
代码
#include
#include
#include
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair PII;
#define all(x) x.begin(),x.end()
//#define int LL
//#define lc p*2+1
//#define rc p*2+2
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL MOD = 1000000007;
const LL mod = 998244353;
const int maxn = 100010;
int N, A[maxn];
vectorG[maxn];
int vsize[maxn], hson[maxn], L[maxn], R[maxn], rnk[maxn], tot = 0;
LL tmp;
int f[1 << 20][20][2];
void add_edge(int from, int to)
{
G[from].push_back(to);
G[to].push_back(from);
}
void add(int v, int t)
{
for (int i = 19; i >= 0; i--)
f[A[v]][i][(v >> i) & 1] += t;
}
void dfs(int v,int p)
{
hson[v] = 0;
L[v] = ++tot;
rnk[tot] = v;
vsize[v] = 1;
for (int i = 0; i < G[v].size(); i++)
{
int to = G[v][i];
if (to == p)
continue;
dfs(to, v);
vsize[v] += vsize[to];
if (!hson[v] || vsize[to] > vsize[hson[v]])
hson[v] = to;
}
R[v] = tot;
}
void dsu(int v, int p)
{
for (int i = 0; i < G[v].size(); i++)
{
int to = G[v][i];
if (to == p || to == hson[v])
continue;
dsu(to, v);//单个子树内的贡献
for (int j = L[to]; j <= R[to]; j++)
add(rnk[j], -1);//清空计数信息
}
if (hson[v])
dsu(hson[v], v);
for (int i = 0; i < G[v].size(); i++)
{
int to = G[v][i];
if (to == p || to == hson[v])
continue;
for (int j = L[to]; j <= R[to]; j++)
{
int tar = A[rnk[j]] ^ A[v];
for (int i = 19; i >= 0; i--)
tmp += (1LL << i) * (LL)f[tar][i][((rnk[j] >> i) & 1) ^ 1];
}
for (int j = L[to]; j <= R[to]; j++)
add(rnk[j], 1);
}
add(v, 1);//加上自己的计数信息
}
void solve()
{
dfs(1, 0), dsu(1, 0);
cout << tmp << endl;
}
int main()
{
IOS;
cin >> N;
for (int i = 1; i <= N; i++)
cin >> A[i];
int u, v;
for (int i = 1; i < N; i++)
{
cin >> u >> v;
add_edge(u, v);
}
solve();
return 0;
}