题目链接
可以发现,每一个点的答案可以分为两个部分:
第一个部分的答案非常好求:
假设遍历到一个点 x x x,设这个点的颜色是 c [ x ] c[x] c[x],那么包含分治中心到点 x x x这段子路径的路径一定包含 c [ x ] c[x] c[x],如果从分治中心遍历到点 x x x的过程中第一次遍历到 c [ x ] c[x] c[x],那么在分治中心的子树中,以分治中心为端点向下的路径中,有 s z [ x ] sz[x] sz[x]个路径包含 c [ x ] c[x] c[x],其中 s z [ x ] sz[x] sz[x]为点 x x x为根节点的子树的大小。点 x x x的子树的每一个点,都在 c [ x ] c[x] c[x]颜色上为分治中心提供了 1 1 1的贡献。
这样不断向下dfs的过程中,不断记录哪些颜色已经遍历过,然后统计答案贡献。
第二个部分的答案稍微繁琐一点:
经过分治中心的情况,可以分为:
不过这两种情况如果细细想来会发现,第二种情况可以通过上面的节点的第一种情况求得,所以这里对于每一个分治中心只考虑第一种情况。
由于我们只考虑分治中心点的子树(也就是当前的分治块),所以我们可以先通过一次dfs统计出整个分治块的点数、包含的各个颜色的数量。
我们依旧遍历到点 x x x,但是这次我们时刻更新的是我们当前遍历的点 x x x的答案(很难把这个点的贡献算到其他的所有点的答案上去,所以我们反过来思考)。
现在考虑点 x x x可以获得的答案贡献。我们可以先处理出其他子树的点数和各个颜色的数量。从分治中心到其他子树上的点可以得到的颜色,从点 x x x跑到分治中心在跑到那些点上,也可以得到那些颜色。设其他子树上第 i i i个颜色有 c n t [ i ] cnt[i] cnt[i]个,那么对于其他子树上每一个存在的颜色,都可以给点 x x x提供 c n t [ i ] cnt[i] cnt[i]的贡献。
然后考虑点 x x x到分治中心这条路径上的颜色,在这条路径上出现过意味着其他子树的所有点到点 x x x的路径都会包含这个颜色,所以这个颜色会直接给点 x x x提供其他子树的总大小的贡献。有一些颜色已经在其他子树上出现过,此时之前已经算过其他子树上这个颜色给点 x x x带来的贡献 c n t [ i ] cnt[i] cnt[i]了,所以在加上其他子树的总大小的时候,要把之前加过的 c n t [ i ] cnt[i] cnt[i]减去。
剩下的就是纯纯的点分治了。
#include
using namespace std;
typedef long long LL;
const int N = 100005;
LL n, root, cnt_sum, siz, en;
LL front[N], c[N], mxp[N], sz[N], sum[N], vis[N];
LL cnt[N], ccnt[N], pcnt[N], mk[N], tag[N];
//mxp[i]:第i个点的最大子树大小
//sz[i]:以第i个点为根节点的子树的大小
//sum[i]:第i个点的答案
//vis[i]:第i个点是否已经搜过
//cnt[i]:分治中心的所有子树的颜色信息
//ccnt[i]:用来备份cnt数组
//pcnt[i]:用于统计当前子树的颜色信息,然后从cnt中抠掉
//mk[i]:标记当前搜索的子树中有第i个颜色,用于优化时间复杂度
//tag[i]:当前所遍历的点已经算过了第i个颜色的答案
vector<int> col, col2;
//col:存放当前子树有什么颜色,后续遍历颜色时只遍历这里面的颜色以节省时间
//col2:col的备份
struct Edge {
LL v, next;
}e[N * 4];
void addEdge(int u, int v) {
e[++en] = {v, front[u]};
front[u] = en;
}
//获取分治块的重心
void get_root(int u, int f, int total) {
mxp[u] = 0;
sz[u] = 1;
for (int i = front[u]; i; i = e[i].next) {
int v = e[i].v;
if (vis[v] or v == f) continue;
get_root(v, u, total);
sz[u] += sz[v];
mxp[u] = max(mxp[u], sz[v]);
}
mxp[u] = max(mxp[u], total - sz[u]);
if (!root or mxp[u] < mxp[root]) {
root = u;
}
}
//获取当前分治块的各个点的子树大小
void get_sz(int u, int f) {
sz[u] = 1;
for (int i = front[u]; i; i = e[i].next) {
int v = e[i].v;
if (v == f or vis[v]) continue;
get_sz(v, u);
sz[u] += sz[v];
}
}
//dfs获取颜色信息,CNT是传进来的数组
void dfs(int u, int f, LL* CNT) {
//之前没有遇到过c[u]这个颜色
if (!mk[c[u]]) {
mk[c[u]] = 1;
col.push_back(c[u]);
}
//c[u]这个颜色没有记过贡献
if (tag[c[u]] == 0) CNT[c[u]] += sz[u];
//标记这个颜色已经在整个子树记录过贡献了
//之后不用再统计这个颜色了
++tag[c[u]];
for (int i = front[u]; i; i = e[i].next) {
int v = e[i].v;
if (v == f or vis[v]) continue;
dfs(v, u, CNT);
}
--tag[c[u]];//递归结束,撤销标记
}
void fix(int u, int f, LL lst) {
LL tmp = lst;
//第一次遇到这个颜色
//如果其他子树中有c[u],那么cnt[c[u]]为正,把原来加过的减掉
//其他子树中没有c[u]时,会有cnt[c[u]] = 0
if (tag[c[u]] == 0) tmp += siz - cnt[c[u]];
++tag[c[u]];
sum[u] += (cnt_sum + tmp); //更新这个点的贡献,cnt_sum是其他子树的那些颜色的总贡献
for (int i = front[u]; i; i = e[i].next) {
int v = e[i].v;
if (v == f or vis[v]) continue;
fix(v, u, tmp);
}
--tag[c[u]];
}
void calc(int u) {
cnt_sum = 0;
get_sz(u, 0);
col.clear();
dfs(u, 0, cnt);//考虑点u为端点,直接向下dfs获得答案
for (int x: col) {
mk[x] = 0;
}//清空标记数组
for (int x: col) {
cnt_sum += cnt[x];
ccnt[x] = cnt[x];
}
col2 = col;
sum[u] += cnt_sum;
LL tmp = cnt_sum;
for (int i = front[u]; i; i = e[i].next) {
int v = e[i].v;
if (vis[v]) continue;
tag[c[u]] = 1;
col.clear();
dfs(v, u, pcnt);//获取当前子树的颜色信息
tag[c[u]] = 0;//清空标记
for (int x: col) mk[x] = 0;//清空标记
cnt[c[u]] -= sz[v];//减去点u的颜色贡献
cnt_sum -= sz[v];//减去点u的颜色贡献
for (int x: col) {
cnt[x] -= pcnt[x];//减去这个子树内的颜色贡献
cnt_sum -= pcnt[x];//减去这个子树内的颜色贡献
}
siz = sz[u] - sz[v];
fix(v, u, 0);//统计分治中心到点u的路径的颜色答案
//恢复信息
cnt[c[u]] += sz[v];
cnt_sum = tmp;
for (int x: col) {
cnt[x] = ccnt[x];
pcnt[x] = 0;
}
}
//清空数组
for (int x: col2) cnt[x] = 0;
col2.clear();
//标记为已搜索
vis[u] = 1;
}
void get_ans(int u) {
calc(u);
for (int i = front[u]; i; i = e[i].next) {
int v = e[i].v;
if (vis[v]) continue;
root = 0;
get_root(v, 0, sz[v]);//找重心
get_ans(root);//分治搜索
}
}
void main2() {
cin >> n;
en = 0;
for (int i = 1; i <= n; ++i) {
front[i] = mk[i] = vis[i] = tag[i] = sum[i] = 0;
cin >> c[i];
}
for (int i = 1; i < n; ++i) {
int x, y;
cin >> x >> y;
addEdge(x, y);
addEdge(y, x);
}
mxp[0] = n;
root = 0;
get_root(1, 0, n);
get_ans(root);
for (int i = 1; i <= n; ++i) {
cout << sum[i] << '\n';
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
LL _ = 1;
// cin >> _;
while (_--) main2();
return 0;
}