树统计(虚树)

时间限制: 1 Sec  内存限制: 128 MB

题目描述

骗分过样例,暴力出奇迹。
关于树的算法有一大堆,样样都是毒瘤。
比如说 NOIP2018 提高组的 D2T3,如果会动态 DP 的做法那么就马上想到正解,但是 Tweetuzki 不会动态 DP,就只好骗分了。
可惜树题的码量也是超级大的。听说好多学长都会动态 DP,但是考场上调不出来,只好暴力分收场了。疯狂暗示
Tweetuzki 当时暴力写挂了,有 4 个点写成了死循环……于是分数白白少了 16 分。Tweetuzki 一想起这事,不禁夙夜忧叹,辗转反侧。
现在他又遇到一道毒瘤树上问题了,他下定决心:这次一定要把暴力分写满!
题目是这样的:
有一棵 n 个点的树,边有边权,每个点有颜色 ci。求所有颜色不同的点对的距离之和。由于答案可能很大,你只需要输出其对 998,244,353 取模的结果即可。
形式化地讲,记 u 号点和 v 号点在树上的距离为 dist(u,v),求:

树统计(虚树)_第1张图片

输入

输入文件将会遵循以下格式:
n type
c1 c2 ⋯ cn
u1 v1 w1
u2 v2 w2

un−1 vn−1 wn−1
第一行两个正整数 n,type(2≤n≤2×105,1≤type≤6),其中 n 表示点数,type为部分分类型,详见数据范围,type=0 表示样例数据。
第二行输入 n 个正整数 ci(1≤ci≤109),表示每个点的颜色。
接下来n−1 行,每行输入三个正整数 ui,vi,wi(1≤ui

输出

输出一行一个非负整数,表示答案对 998,244,353 取模的结果。

样例输入 Copy

4 0
1 2 3 3
1 2 5
2 3 4
3 4 7

样例输出 Copy

90

提示

满足条件的点对有 (1,2),(1,3),(1,4),(2,1),(2,3),(2,4),(3,1),(3,2),(4,1),(4,2),故答案为 5+9+16+5+4+11+9+4+16+11=90。

Subtask #1:n≤300, type=1。
Subtask #2:n≤2 000, type≤2。
Subtask #3:n≤10 000, type≤3。
Subtask #4:对于第 i (1≤i≤n) 号点,ci=i。type=4。
Subtask #5 :对于第 i(1≤i Subtask #6:无特殊性质,type≤6。

 

题目要求不同颜色顶点间的距离和,我们转化为所有顶点间的距离和-相同颜色点间的距离和

对于所有顶点间的距离和,我们跑一遍图,求出每条边左右的顶点对数即可求出每条边的贡献,最终得到所有边的贡献

将颜色相同的顶点分别建立一棵虚树,每一颗虚树类似上面跑一遍图即可

最终答案<<1即可

/**/
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

typedef long long LL;
using namespace std;

const long long mod = 998244353;
const int maxn = 200005;

int n, type, tot, cnt, top, len;
int c[maxn], b[maxn];
int head[maxn], sz[maxn], son[maxn], topf[maxn], f[maxn], dep[maxn], dfn[maxn];
LL ans, res, dis[maxn];
int e[maxn], s[maxn], dp[maxn];
bool vis[maxn];

vector v[maxn];
vector > g[maxn];

struct node
{
	int v, w, next;
}a[maxn << 1];

bool cmp(int x, int y){
	return dfn[x] < dfn[y];
}

void dfs(int x, int pre){
	sz[x] = 1;
	dep[x] = dep[pre] + 1;
	f[x] = pre;
	for (int i = head[x]; i != -1; i = a[i].next){
		int v = a[i].v;
		if(v == pre) continue;
		dis[v] = (dis[x] + a[i].w) % mod;
		dfs(v, x);
		ans = (ans + 1LL * sz[v] * (n - sz[v]) % mod * a[i].w % mod) % mod;
		sz[x] += sz[v];
		if(sz[son[x]] < sz[v]) son[x] = v;
	}
}

void dfs1(int x, int topfa){
	topf[x] = topfa;
	dfn[x] = ++cnt;
	if(!son[x]) return ;
	dfs1(son[x], topfa);
	for (int i = head[x]; i != -1; i = a[i].next){
		int v = a[i].v;
		if(topf[v]) continue;
		dfs1(v, v);
	}
}

int LCA(int x, int y){
	while(topf[x] != topf[y]){
		if(dep[topf[x]] < dep[topf[y]]) swap(x, y);
		x = f[topf[x]];
	}
	if(dep[x] > dep[y]) swap(x, y);
	return x;
}

void add_edge(int u, int v){
	if(u == n + 1) g[u].emplace_back(make_pair(v, 0));
	else g[u].emplace_back(make_pair(v, (dis[v] - dis[u] + mod) % mod));
}

void insert(int x){
	if(top <= 1){
		s[++top] = x;
		return ;
	}
	int lca = LCA(s[top], x);
	if(lca == s[top]){
		s[++top] = x;
		return ;
	}
	while(top > 1 && dfn[lca] <= dfn[s[top - 1]]){
		add_edge(s[top - 1], s[top]);
		top--;
	}
	if(lca != s[top]) add_edge(lca, s[top]), s[top] = lca;
	s[++top] = x;
}

void dfs2(int u){
	dp[u] = vis[u];
	for (auto x : g[u]){
		int v = x.first;
		LL w = x.second;
		dfs2(v);
		dp[u] += dp[v];
		res = (res + 1LL * dp[v] * (len - dp[v]) % mod * w % mod) % mod;
	}
	g[u].clear();
}

int main()
{
	//freopen("in.txt", "r", stdin);
	//freopen("out.txt", "w", stdout);

	memset(head, -1, sizeof(head));
	scanf("%d %d", &n, &type);
	for (int i = 1; i <= n; i++) scanf("%d", &c[i]), b[i] = c[i];
	sort(b + 1, b + 1 + n);
	int num = unique(b + 1, b + 1 + n) - b - 1;
	for (int i = 1; i <= n; i++) c[i] = lower_bound(b + 1, b + 1 + num, c[i]) - b;
	for (int i = 1; i <= n; i++) v[c[i]].emplace_back(i);
	for (int i = 1, u, v, w; i < n; i++){
		scanf("%d %d %d", &u, &v, &w);
		a[tot] = node{v, w, head[u]}, head[u] = tot++;
		a[tot] = node{u, w, head[v]}, head[v] = tot++;
	}
	dfs(1, 0);
	dfs1(1, 1);
	for (int i = 1; i <= num; i++){
		if(v[i].empty()) continue;
		len = v[i].size();
		for (int j = 0; j < len; j++) e[j + 1] = v[i][j], vis[e[j + 1]] = true;
		sort(e + 1, e + 1 + len, cmp);
		s[top = 1] = n + 1;
		for (int j = 1; j <= len; j++) insert(e[j]);
		while(top > 1) add_edge(s[top - 1], s[top]), top--;
		res = 0;
		dfs2(n + 1);
		for (int j = 1; j <= len; j++) vis[e[j]] = false;
		ans = (ans - res + mod) % mod;
	}
	printf("%lld\n", (ans << 1) % mod);

	return 0;
}
/*
8 3
1 2 3 1 3 3 1 2
1 2 1
2 4 2
2 5 2
5 6 3
5 7 3
1 3 4
3 8 4
*/

 

你可能感兴趣的:(虚树)