树统计 (dfs序+线段树)

时间限制: 1 Sec  内存限制: 128 MB

题目描述

然而,这一切宛如一度揉过的复写纸,无不同原来有着少许然而却是无可挽回的差异。—— 村上春树

关于树的算法有一大堆,样样都是毒瘤。

比如说 2019 CSP-S 的树论题,如果擅长树形数据结构马上想到正解,但是 3edc2wsx1qaz 并不擅长,就只好骗分了。

3edc2wsx1qaz 当时数组开小了,惨遭 RE,3edc2wsx1qaz 一想起这事,不禁夙夜忧叹,辗转反侧。

现在他又遇到一道毒瘤的树上问题了,他下定决心:这次一定要写出正解!

题目是这样的:

有一颗有n个点的树,每条边有一个权值ai。树的根节点为1号节点。定义一对点对(u,v)的距离dist(u,v)为在u到v的简单路径上的所有边的边权的异或。

你需要进行q次操作,操作分为两种:
1.将x点与它父亲所连的边的边权异或w。
2.询问以节点y为根的子树中所有点对的距离之和,答案对998244353取模
也就是说,对于每次 2 操作,设以节点y为根的子树的节点集合为subtrss(y), 你需要求出以下式子的值:
树统计 (dfs序+线段树)_第1张图片

 

输入

为了方便你获取部分分,我们会告诉你测试点编号。

第一行输入三个正整数n,q,r(2≤n≤10^5,2≤q≤10^5,1≤r≤50),表示树的节点数,操作数,该测试点编号。

接下来n-1行每行三个正整数u,v,w,表示有一条连接u,v,权值为w的边。(1≤u≤n,1≤v≤n,0≤w<2^10)

接下来q行,每行开头输入一个数opt(opt=1 或opt=2 ),表示操作类型。

若opt=1,则再输入两个数x,w(1
若opt=2 ,则再输入一个数y,表示一次询问,你需要输出以节点y为根的子树中所有点对的距离之和。

输出

输出若干行,对于每次 2 操作,输出一个正整数,表示答案。

样例输入 Copy

8 8 0
2 1 0
3 1 0
4 3 0
5 2 1
6 5 1
7 5 0
8 1 0
1 4 0
2 7
1 3 0
2 5
1 5 1
1 4 0
1 5 0
2 1

样例输出 Copy

0
4
14

提示

样例解释:
由于这组数据为样例,所以r=0。
保证测试数据中1≤r≤50
树统计 (dfs序+线段树)_第2张图片

树统计 (dfs序+线段树)_第3张图片

 

对于一颗子树内的任意两点(x,y)之间距离(如题所述的距离)为dis(1,x)^dis(1,y)

那么我们知道统计子树按位拆开后对应二进制为1的个数即可。

对于更新操作,修改边权为w,枚举二进制位,当且仅当w对应的二进制为1时必发生当前点及子树对应位0和1个数的互换,用线段树维护即可。

 

/**/
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

typedef long long LL;
using namespace std;

const long long mod = 998244353;
const int maxn = 1e5 + 5;

int n, q, r, tot, cnt;
int head[maxn], dfn[maxn], sz[maxn], son[maxn], top[maxn], f[maxn], id[maxn], w[maxn];
int tr[11][maxn << 2], lzy[11][maxn << 2];

struct node
{
	int v, w, next;
}a[maxn << 1];

void dfs(int x, int pre){
	f[x] = pre;
	sz[x] = 1;
	for (int i = head[x]; i != -1; i = a[i].next){
		int v = a[i].v;
		if(v == pre) continue;
		w[v] = w[x] ^ a[i].w;
		dfs(v, x);
		sz[x] += sz[v];
		if(sz[son[x]] < sz[v]) son[x] = v;
	}
}

void dfs1(int x, int topf){
	top[x] = topf;
	dfn[x] = ++cnt;
	id[cnt] = x;
	if(son[x]) dfs1(son[x], topf);
	for (int i = head[x]; i != -1; i = a[i].next){
		int v = a[i].v;
		if(v == f[x] || v == son[x]) continue;
		dfs1(v, v);
	}
}

void up(int rt){
	for (int i = 0; i < 10; i++){
		tr[i][rt] = tr[i][rt << 1] + tr[i][rt << 1 | 1];
	}
}

void up(int rt, int i){
	tr[i][rt] = tr[i][rt << 1] + tr[i][rt << 1 | 1];
}

void down(int rt, int l, int r){
	int mid = (l + r) >> 1;
	for (int i = 0; i < 10; i++){
		if(lzy[i][rt]){
			lzy[i][rt << 1] ^= 1;
			lzy[i][rt << 1 | 1] ^= 1;
			tr[i][rt << 1] = (mid - l + 1) - tr[i][rt << 1];
			tr[i][rt << 1 | 1] = (r - mid) - tr[i][rt << 1 | 1];
			lzy[i][rt] = 0;
		}
	}
}

void build(int rt, int l, int r){
	if(l == r){
		for (int i = 0; i < 10; i++){
			if(1 << i & w[id[l]]) tr[i][rt] = 1;
			else tr[i][rt] = 0;
		}
		return ;
	}
	int mid = (l + r) >> 1;
	build(rt << 1, l, mid);
	build(rt << 1 | 1, mid + 1, r);
	up(rt);
}

void update(int rt, int l, int r, int L, int R, int i){
	if(L <= l && r <= R){
		tr[i][rt] = r - l + 1 - tr[i][rt];
		lzy[i][rt] ^= 1;
		return ;
	}
	down(rt, l, r);
	int mid = (l + r) >> 1;
	if(mid >= L) update(rt << 1, l, mid, L, R, i);
	if(mid < R) update(rt << 1 | 1, mid + 1, r, L, R, i);
	up(rt, i);
}

int query(int rt, int l, int r, int L, int R, int i){
	if(L <= l && r <= R) return tr[i][rt];
	down(rt, l, r);
	int mid = (l + r) >> 1, ans = 0;
	if(mid >= L) ans += query(rt << 1, l, mid, L, R, i);
	if(mid < R) ans += query(rt << 1 | 1, mid + 1, r, L, R, i);
	return ans;
}

void modify(int x, int W){
	for (int i = 0; i < 10; i++){
		if(W >> i & 1) update(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, i);
	}
	w[x] ^= W;
}

LL sum(int x){
	LL ans = 0;
	for (int i = 0; i < 10; i++){
		int num = query(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, i);
		ans = (ans + 1LL * num * (sz[x] - num) % mod * (1 << i) % mod);
	}
	return ans;
}

int main()
{
	//freopen("in.txt", "r", stdin);
	//freopen("out.txt", "w", stdout);

	memset(head, -1, sizeof(head));
	scanf("%d %d %d", &n, &q, &r);
	for (int i = 1, u, v, w; i < n; i++){
		scanf("%d %d %d", &u, &v, &w);
		a[tot] = node{v, w, head[u]}, head[u] = tot++;
		a[tot] = node{u, w, head[v]}, head[v] = tot++;
	}
	dfs(1, 0);
	dfs1(1, 1);
	build(1, 1, n);
	for (int i = 1, op, x, y, w; i <= q; i++){
		scanf("%d", &op);
		if(op == 1){
			scanf("%d %d", &x, &w);
			modify(x, w);
		}else{
			scanf("%d", &y);
			printf("%lld\n", (sum(y) << 1) % mod);
		}
	}

	return 0;
}
/**/

 

你可能感兴趣的:(树链剖分)