给定一棵有 $n$ 个结点的树,点带权,求所有路径的异或值之和(路径两端可以是同一个结点,即单个结点也算一条路径;$n \leq 10^5$)。
https://codeforces.com/problemset/problem/766/E
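若求的是路径权值和,直接树形 $dp$ 即可;现在变成异或,就按位考虑。异或值之和可以按二进制位拆开:设第 $i$ 位为 1 的路径有 $c_i$ 条,则答案为 $\sum_i c_i \cdot 2^i$。树形 $dp$ 时维护子树内各结点到子树根的路径异或值在每一位上为 0 / 为 1 的结点数,合并子树时即可统计答案并完成转移($a_i \le 10^6 < 2^{20}$,只需考虑低 20 位)。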
// Standard headers used by this program (the original include line was
// missing its header name, which made the file fail to compile).
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
const int maxn = 1e5 + 5;          // node ids are 1..n with n <= 1e5
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
vector<int> G[maxn];               // adjacency lists of the tree
// a[u]: weight of node u.
// dp[u][i][b]: number of nodes w in the (already merged part of the) subtree
// of u such that bit i of xor(path u..w) equals b.
int a[maxn], dp[maxn][20][2];
int n; ll ans;
void dfs(int u, int f){
for(int i = 0; i < 20; ++i){
++dp[u][i][(a[u] >> i) & 1];
}
for(auto &v : G[u]){
if(v == f) continue;
dfs(v, u);
for(int i = 0; i < 20; ++i){
ans += (1ll << i) * dp[v][i][0] * dp[u][i][1];
ans += (1ll << i) * dp[v][i][1] * dp[u][i][0];
}
for(int i = 0; i < 20; ++i){
int t = (a[u] >> i) & 1;
dp[u][i][0] += dp[v][i][t];
dp[u][i][1] += dp[v][i][t ^ 1];
}
}
}
// Reads n, the node weights and the n - 1 edges, then prints the sum of the
// xor values of all paths, single-node paths included.
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin >> n;
    for(int v = 1; v <= n; ++v) cin >> a[v];
    for(int e = 1; e <= n - 1; ++e){
        int x, y;
        cin >> x >> y;
        G[x].pb(y);
        G[y].pb(x);
    }
    dfs(1, 0);                                  // paths with >= 2 nodes
    for(int v = 1; v <= n; ++v) ans += a[v];    // single-node paths
    cout << ans << endl;
    return 0;
}
下面这份是为了练习点分治而写的(纯属脑抽):
// Standard headers used by this program (the original include line was
// missing its header name, which made the file fail to compile).
#include <algorithm>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
const int maxn = 1e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
vector<int> G[maxn];               // adjacency lists of the tree
// a: node weights; siz: subtree sizes from the most recent getRt pass;
// vis: set for the centroids currently on the dfz recursion stack;
// dis[1..tot]: path xors collected by dfs; mp[i][b]: how many stored path
// xors have bit i equal to b.
int a[maxn], siz[maxn], vis[maxn], dis[maxn], mp[20][2];
// tn: size of the component searched by getRt; rmn / rt: smallest largest-part
// size found so far and the node achieving it; tot: number of entries in dis.
int n, tn, rmn, rt, tot;
ll ans;
// Computes siz[] over the part reachable from u without crossing f or any
// vis[] node, and records in (rt, rmn) the node whose largest remaining piece
// (a child subtree, or the tn - siz[u] nodes on the other side) is smallest.
void getRt(int u, int f){
    siz[u] = 1;
    int heaviest = 0;
    for(int v : G[u]){
        if(v != f && !vis[v]){
            getRt(v, u);
            siz[u] += siz[v];
            if(siz[v] > heaviest) heaviest = siz[v];
        }
    }
    if(tn - siz[u] > heaviest) heaviest = tn - siz[u];
    if(heaviest < rmn){
        rmn = heaviest;
        rt = u;
    }
}
// Adds val to the counter mp[bit][b] for each of the 20 low bits of x,
// where b is the value of that bit in x.
void add(int x, int val){
    for(int bit = 19; bit >= 0; --bit){
        const int b = (x >> bit) & 1;
        mp[bit][b] += val;
    }
}
// Stores val (the running xor supplied by the caller) for node u in
// dis[++tot]. It then visits every neighbour other than f that is not marked
// vis, extending the xor by that neighbour's weight.
void dfs(int u, int f, int val){
    tot += 1;
    dis[tot] = val;
    for(const int &nxt : G[u]){
        if(nxt == f) continue;
        if(vis[nxt]) continue;
        dfs(nxt, u, val ^ a[nxt]);
    }
}
// For centroid u, adds to ans the xor value of every path with at least two
// nodes that passes through u within the current component. mp collects
// per-bit counts of the path xors (u's weight excluded) from neighbour
// subtrees already processed, plus a 0 entry for u itself. mp is cleared
// before returning.
void cal(int u){
    add(0, 1);  // u alone as one endpoint of a path
    for(int v : G[u]){
        if(vis[v]) continue;
        tot = 0;
        dfs(v, u, a[v]);
        // Pair each new path, extended through u, with every stored one.
        // Bit j of the joined xor is 1 iff the stored bit is the opposite
        // of the new path's bit j.
        for(int k = 1; k <= tot; ++k){
            const int joined = dis[k] ^ a[u];
            for(int j = 0; j < 20; ++j){
                const int opposite = ((joined >> j) & 1) ^ 1;
                ans += (1ll << j) * mp[j][opposite];
            }
        }
        // Publish this subtree's paths only afterwards, so that two
        // endpoints inside the same subtree are never paired through u.
        for(int k = 1; k <= tot; ++k) add(dis[k], 1);
    }
    memset(mp, 0, sizeof mp);
}
// Number of nodes reachable from u without crossing f or any vis[] node,
// i.e. the size of the part that hangs off f through u.
static int compSize(int u, int f){
    int cnt = 1;
    for(auto &v : G[u]){
        if(v == f || vis[v]) continue;
        cnt += compSize(v, u);
    }
    return cnt;
}
// Centroid decomposition step. Treats u as the centroid of its component,
// counts every path through it via cal(u), then recurses into each remaining
// part using that part's own centroid. vis[u] stays set only while this call
// is active, so the vis[] nodes are exactly the centroids on the call stack.
// Those nodes are what separate the parts.
void dfz(int u){
    vis[u] = 1;
    cal(u);
    for(auto &v : G[u]){
        if(vis[v]) continue;
        // siz[v] cannot be trusted here. It was filled by the getRt pass that
        // picked u, and that pass was rooted elsewhere. For the neighbour on
        // the root's side, siz[v] therefore also counts nodes beyond u.
        // Recount the part exactly so getRt's "tn - siz[x]" term is correct.
        tn = compSize(v, u), rmn = inf, getRt(v, u);
        dfz(rt);
    }
    vis[u] = 0;
}
// Reads the tree, runs centroid decomposition from the centroid of the whole
// tree, and prints the xor-sum of all paths, single-node paths included.
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin >> n;
    for(int v = 1; v <= n; ++v) cin >> a[v];
    for(int e = 1; e <= n - 1; ++e){
        int x, y;
        cin >> x >> y;
        G[x].pb(y);
        G[y].pb(x);
    }
    tn = n;
    rmn = inf;
    getRt(1, 0);
    dfz(rt);
    for(int v = 1; v <= n; ++v) ans += a[v];   // single-node paths
    cout << ans << endl;
    return 0;
}
下面这份是为了练习换根 $dp$ 而写的(纯属脑抽):
// Standard headers used by this program (the original include line was
// missing its header name, which made the file fail to compile).
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
const int maxn = 1e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
vector<int> G[maxn];               // adjacency lists of the tree
// dp[u][i][b]: nodes w in u's subtree (u included) whose bit i of
// xor(path u..w) equals b.
// fp[u][i][b]: the same count over the nodes outside u's subtree, plus u itself.
int a[maxn], dp[maxn][20][2], fp[maxn][20][2];
int n; ll ans;
// Fills dp[u] for the subtree of u (f is u's parent). It starts from u alone,
// then folds in each child's counts, re-expressed relative to u by xoring in a[u].
void dfs1(int u, int f){
    for(int bit = 0; bit < 20; ++bit){
        const int own = (a[u] >> bit) & 1;
        dp[u][bit][own] += 1;
    }
    for(int v : G[u]){
        if(v == f) continue;
        dfs1(v, u);
        for(int bit = 0; bit < 20; ++bit){
            const int flip = (a[u] >> bit) & 1;
            // Bit of xor(u..w) is 0 iff bit of xor(v..w) equals a[u]'s bit.
            dp[u][bit][0] += dp[v][bit][flip];
            dp[u][bit][1] += dp[v][bit][!flip];
        }
    }
}
void dfs2(int u, int f){
for(auto &v : G[u]){
if(v == f) continue;
for(int i = 0; i < 20; ++i){
int tu = (a[u] >> i) & 1, tv = (a[v] >> i) & 1;
fp[v][i][0] += fp[u][i][tv];
fp[v][i][1] += fp[u][i][tv ^ 1];
int tp[2] = {};
tp[0] = dp[u][i][0] - dp[v][i][tu] - (tu ^ 1);
tp[1] = dp[u][i][1] - dp[v][i][tu ^ 1] - tu;
fp[v][i][0] += tp[tv] + (tv ^ 1);
fp[v][i][1] += tp[tv ^ 1] + tv;
}
dfs2(v, u);
}
}
int main(){
ios::sync_with_stdio(0); cin.tie(0);
cin >> n;
for(int i = 1; i <= n; ++i) cin >> a[i];
for(int i = 1; i < n; ++i){
int u, v; cin >> u >> v;
G[u].pb(v), G[v].pb(u);
}
dfs1(1, 0);
for(int i = 0; i < 20; ++i){
++fp[1][i][(a[1] >> i) & 1];
}
dfs2(1, 0);
for(int i = 1; i <= n; ++i){
for(int j = 0; j < 20; ++j){
ans += (1ll << j) * (dp[i][j][1] + fp[i][j][1]);
}
}
ans >>= 1;
cout << ans << endl;
return 0;
}