另有 dsu on tree \texttt{dsu on tree} dsu on tree的解法。
我写的是虚树的做法。对于每种颜色建虚树,然后在虚树上求一次子树size,然后每个虚树上的点对原树上对应点做初步贡献。
最后再 dfs \texttt{dfs} dfs一遍求答案。
存答案我用的是结构体,存最大值和编号之和。可以简单合并。
考虑证明上面的做法对于一种颜色,在初步贡献和最后的一遍 dfs \texttt{dfs} dfs不会重复计算。
因为虚树上的点要么是关键点的 l c a lca lca,要么就是关键点。也就是说虚树点一定满足要么自己为关键点,要么有多个子树都有关键点。
那么虚树点的 s i z siz siz一定大于每一个子树的 s i z siz siz(原树上的子树)。(这里的 s i z siz siz指的是子树内当前颜色出现次数 也就是 子树内关键点的数量)
本质上就是虚树上不会出现只有一个儿子的非关键点。
那么合并一定不会有问题。
时间复杂度 O ( n log n ) O(n\log n) O(nlogn)
#include
using namespace std;
template<class T>inline void read(T &res) {
char ch; while(!isdigit(ch=getchar()));
for(res=ch-'0';isdigit(ch=getchar());res=res*10+ch-'0');
}
#define pb push_back
#define pii pair
typedef long long LL;
const int MAXN = 100005;
vector<int>e[MAXN], g[MAXN], vec[MAXN];
int n, col[MAXN], dfn[MAXN], tmr, dep[MAXN], fa[MAXN], sz[MAXN], son[MAXN], top[MAXN];
void dfs1(int u, int ff) {
dep[u] = dep[fa[u] = ff] + (sz[u] = 1);
for(auto v : e[u])
if(v != ff) {
dfs1(v, u); sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void dfs2(int u, int tp) {
top[u] = tp; dfn[u] = ++tmr;
if(son[u]) dfs2(son[u], tp);
for(auto v : e[u])
if(v != fa[u] && v != son[u]) dfs2(v, v);
}
inline int Lca(int u, int v) {
while(top[u] != top[v]) {
if(dep[top[u]] > dep[top[v]]) u = fa[top[u]];
else v = fa[top[v]];
}
return dep[u] > dep[v] ? v : u;
}
inline bool cmp(int i, int j) { return dfn[i] < dfn[j]; }
bool flg[MAXN];
int siz[MAXN], stk[MAXN], indx;
void ins(int x) {
if(x == stk[indx]) return;
if(!indx) { stk[++indx] = x; return; }
int lca = Lca(x, stk[indx]);
if(lca == stk[indx]) { stk[++indx] = x; return; }
while(indx>1 && dfn[stk[indx-1]] >= dfn[lca]) g[stk[indx-1]].pb(stk[indx]), --indx;
if(lca != stk[indx]) g[lca].pb(stk[indx]), stk[indx] = lca;
stk[++indx] = x;
}
struct node {
int mx; LL sum;
node(int mx=0, LL sum=0):mx(mx), sum(sum){}
inline node operator +(const node &o)const {
return mx > o.mx ? *this : mx < o.mx ? o : node(mx, sum + o.sum);
}
}f[MAXN];
void dfs(int u, int clr) {
siz[u] = flg[u];
for(auto v : g[u]) dfs(v, clr), siz[u] += siz[v];
f[u] = f[u] + node(siz[u], clr); g[u].clear();
}
void getans(int u, int ff) {
for(auto v : e[u]) if(v != ff)
getans(v, u), f[u] = f[u] + f[v];
}
int main () {
read(n);
for(int i = 1; i <= n; ++i) read(col[i]), vec[col[i]].pb(i);
for(int i = 1, x, y; i < n; ++i) read(x), read(y), e[x].pb(y), e[y].pb(x);
dfs1(1, 0), dfs2(1, 1);
for(int i = 1; i <= n; ++i) if(vec[i].size()) {
sort(vec[i].begin(), vec[i].end(), cmp);
indx = 0;
for(auto x : vec[i]) flg[x] = 1, ins(x);
while(indx > 1) g[stk[indx-1]].pb(stk[indx]), --indx;
dfs(stk[1], i);
for(auto x : vec[i]) flg[x] = 0;
}
getans(1, 0);
for(int i = 1; i <= n; ++i) printf("%lld ", f[i].sum);
}