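// Problem: https://nanti.jisuanke.com/t/39277
// Every valid path (one whose edge weights XOR to 0) contributes the product of
// the vertex counts on the far sides of its two endpoints.
// Once centroid decomposition fixes a root, the size of each vertex's side
// relative to that root is needed (derived from the sizes of the tree rooted at
// vertex 1, adjusted for the direction in which the path is walked).
// Paths are then combined according to whether their two endpoints lie in the
// same subtree of the root or in different ones.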
#include
#include
#include
// Loop shorthand: iterates int i over the closed range [a, b].
#define rep(i, a, b) for(int i = (a); i <= (b); i++)
// Shorthand for push_back.
#define pb push_back
// Integer type (at least 64 bits) used for edge weights, XOR values and sums.
#define ll long long
using namespace std;
// Brings gp_hash_table (GNU policy-based data structures) into scope.
using namespace __gnu_pbds;
// Capacity for per-vertex arrays (maximum vertex count plus slack).
constexpr int N = 100000 + 1000;
// Modulus for the final answer.
constexpr int mod = 1000000007;
// n: number of vertices (read in main); k is declared but unused in the code shown.
int n, k;
// Adjacency-list edge record: target vertex, index of the next edge leaving the
// same source (-1 ends the list), and the edge weight.
// Field order matters: ae() builds records with aggregate initialisation.
struct node {
    int v;
    int nxt;
    ll w;
};
// Each undirected edge is stored twice, hence 2 * N slots.
node edge[2 * N];
int tot,head[N];
void ae(int u,int v,ll w) {
edge[++tot] = node{v,head[u],w};
head[u] = tot;
}
void init(int n) {
tot = 0;
rep(i, 1, n) head[i] = -1;
}
// Per-vertex scratch state for the centroid decomposition.
//   siz[u]  : subtree size of u in the most recent GetRoot traversal.
//   wt[u]   : largest piece left if u were removed (centroid criterion);
//             wt[0] is set to a large sentinel in main so Root == 0 never wins.
//   Root    : best centroid candidate found by the current GetRoot search.
//   Tsiz    : component size assumed by the current GetRoot search.
//   fa[u]   : parent of u when the tree is rooted at vertex 1.
//   siz1[u] : subtree size of u when the tree is rooted at vertex 1.
int siz[N], Root, wt[N], Tsiz, fa[N], siz1[N];
// vis[u] becomes true once u has been processed as a centroid.
bool vis[N];
// Running answer, kept reduced modulo `mod`.
ll ans;
// Maps from a path's XOR value (ll, as produced by dfs) to the summed endpoint
// weights modulo `mod`: num for the subtree currently explored, Tnum for the
// earlier subtrees of the same centroid. The mapped type must be 64-bit because
// calc() multiplies two values below `mod` before reducing.
gp_hash_table<ll, ll> num, Tnum;
// Centroid search over the not-yet-removed component reachable from u (entered
// from f). Fills siz[] with subtree sizes relative to the search start and
// wt[] with the largest piece each vertex would leave behind (using Tsiz as
// the component size). Root ends up as the first vertex in post-order with the
// smallest wt. When flag is set, also records fa[u] = f.
void GetRoot(int u, int f, bool flag) {
    if (flag) fa[u] = f;
    siz[u] = 1;
    int heaviest = 0;  // largest child subtree seen so far
    for (int e = head[u]; e != -1; e = edge[e].nxt) {
        const int to = edge[e].v;
        if (to == f || vis[to]) continue;
        GetRoot(to, u, flag);
        siz[u] += siz[to];
        heaviest = max(heaviest, siz[to]);
    }
    // The piece "above" u is everything in the component outside its subtree.
    wt[u] = max(heaviest, Tsiz - siz[u]);
    if (wt[u] < wt[Root]) Root = u;
}
// Walks the active component from u (entered from f) and adds into num[x] the
// weight of every path endpoint whose XOR distance from the centroid is x.
// An endpoint's weight is the number of vertices beyond it in the tree rooted
// at vertex 1:
//   * fa[u] == f (walking away from vertex 1): u's own subtree, siz1[u];
//   * otherwise (walking towards vertex 1): everything outside f's subtree,
//     n - siz1[f].
void dfs(int u, int f, ll dis) {
    const int side = (fa[u] == f) ? siz1[u] : n - siz1[f];
    auto& slot = num[dis];
    slot = (slot + side) % mod;
    for (int e = head[u]; e != -1; e = edge[e].nxt) {
        const int to = edge[e].v;
        if (to != f && !vis[to]) dfs(to, u, dis ^ edge[e].w);
    }
}
void calc(int u) {
Tnum.clear();
for(int i = head[u]; ~i ; i = edge[i].nxt) {
int v = edge[i].v;
ll w = edge[i].w;
if(vis[v]) continue;
num.clear();
dfs(v,u,w);
for(auto x:num) //不同子树之间的路径合并
ans = (ans + num[x.first]*Tnum[x.first]%mod)%mod;
if(fa[v] == u) //以u为起点的路径单独计算
ans = (ans + num[0]*(n-siz1[v])%mod)%mod;
else
ans = (ans + num[0]*siz1[u]%mod)%mod;
for(auto x:num) Tnum[x.first] = (Tnum[x.first]+x.second)%mod;
}
}
// Centroid decomposition driver: counts paths through centroid u, removes u,
// then recurses into a centroid of every remaining neighbouring component.
// NOTE: siz[] still holds the sizes from the GetRoot search that selected u.
// For the neighbour that was u's parent in that search, siz[to] also counts
// u's side, so Tsiz may over-estimate that component. This only affects which
// centroid is picked: Root is always a vertex of the component, because
// wt[0] is the large sentinel set in main.
void divide(int u) {
    calc(u);
    vis[u] = true;
    for (int e = head[u]; e != -1; e = edge[e].nxt) {
        const int to = edge[e].v;
        if (vis[to]) continue;
        Root = 0;
        Tsiz = siz[to];
        GetRoot(to, 0, false);
        divide(Root);
    }
}
int main() {
freopen("a.txt","r",stdin);
ios::sync_with_stdio(0);
cin>>n;
init(n);
rep(i, 2, n) {
int u,v;
ll w;
cin>>u>>w;
v = i;
ae(u,v,w);
ae(v,u,w);
}
rep(i, 1, n) vis[i] = 0;
wt[0] = 1e9,Tsiz = n,GetRoot(1,0,1),ans = 0;
rep(i, 1, n) siz1[i] = siz[i];
divide(Root);
cout<