这是一道树上启发式的题目,也算是我做得树上启发式的第一题,这次比赛一下出现了两题树上启发式,都不会,让人无语。
看到xor,马上拆位,这个都是老套路了。
对于一个点u,把它和它的子树的字符串全部一次丢到一个trie里,因为已经拆了位,随便维护一下就可以求出答案。
但是我们知道这样暴力是很慢的,复杂度易得。
利用启发式可以把暴力的复杂度优化。
大概流程如下:
对于一个点x,先计算其轻儿子的答案,做完以后,全局的trie是清空的。
接着计算其重儿子的答案,注意此时重儿子的所有子树的节点的字符串是在trie里。
接着再把x的除重儿子子树外的所有节点加入trie,x也要加进trie里。
这样x的子树的答案就已经计算出来了。
如果x是其父亲节点的重儿子,保留当前的trie,否则清空trie。
我们可以分析一下复杂度:
一个点被加入trie的次数就是它到根的路径上的轻重链数的总和,有加入才有删除。
所以复杂度是: O(|S|∗log n)
Code:
#include
#include
#define ll long long
#define fo(i, x, y) for(int i = x; i <= y; i ++)
using namespace std;
const int N = 500005;
char s[N], S[N];
int a2[17], w;
int n, x, y, a[N][17], l[N], r[N];
int final[N], tot;
struct edge {
int to, next;
}e[N * 2];
int fa[N], bz[N], siz[N], son[N];
int next[N][26], tt = 1, root = 1;
ll ans[N], ans2[N], sm[N][2];
void link(int x, int y) {
e[++ tot].next = final[x], e[tot].to = y, final[x] = tot;
e[++ tot].next = final[y], e[tot].to = x, final[y] = tot;
}
void dg(int x) {
bz[x] = 1;
siz[x] = 1;
for(int i = final[x]; i; i = e[i].next) {
int y = e[i].to; if(bz[y]) continue;
fa[y] = x; dg(y); siz[x] += siz[y];
son[x] = siz[y] > siz[son[x]] ? y : son[x];
}
bz[x] = 0;
}
ll Insert(int x) {
int y = root; ll sum = 0;
fo(i, l[x], r[x]) {
int c = s[i] - 'a';
if(next[y][c] == 0) next[y][c] = ++ tt;
y = next[y][c];
sum += sm[y][!a[x][w]] * a2[w];
sm[y][a[x][w]] ++;
}
return sum;
}
ll dfs(int x) {
bz[x] = 1;
ll sum = Insert(x);
for(int i = final[x]; i; i = e[i].next) {
int y = e[i].to; if(bz[y]) continue;
sum += dfs(y);
}
bz[x] = 0;
return sum;
}
void dg2(int x) {
bz[x] = 1;
for(int i = final[x]; i; i = e[i].next) {
int y = e[i].to; if(bz[y] || y == son[x]) continue;
dg2(y);
}
if(son[x] != 0) dg2(son[x]), ans[x] += ans[son[x]];
for(int i = final[x]; i; i = e[i].next) {
int y = e[i].to; if(bz[y] || y == son[x]) continue;
ans[x] += dfs(y);
}
ans[x] += Insert(x);
if(x != son[fa[x]]) {
fo(i, 1, tt) memset(next[i], 0, sizeof(next[i])), sm[i][0] = sm[i][1] = 0;
tt = 1;
}
bz[x] = 0;
}
int main() {
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
scanf("%d", &n);
fo(i, 1, n) {
scanf("%d", &x);
fo(j, 0, 16) a[i][j] = x & 1, x /= 2;
}
fo(i, 1, n) {
scanf("%s", S + 1); int len = strlen(S + 1);
l[i] = r[i - 1] + 1; r[i] = r[i - 1] + len;
fo(j, l[i], r[i]) s[j] = S[j - l[i] + 1];
}
fo(i, 1, n - 1) {
scanf("%d %d", &x, &y);
link(x, y);
}
dg(1);
a2[0] = 1; fo(i, 1, 16) a2[i] = a2[i - 1] * 2;
for(w = 0; w <= 16; w ++) {
dg2(1);
fo(i, 1, n) ans2[i] += ans[i], ans[i] = 0;
}
fo(i, 1, n) printf("%lld\n", ans2[i]);
}