2020CCPC长春F题——dsu on tree+二进制拆分

题目链接

https://codeforces.com/gym/102832/problem/F

题意:

给你一颗n个节点的树,让你求出

\sum_{i=1}^{n}\sum_{j=i+1}^{n} [a_{i}\oplus a_{j}==a_{lca(i,j)}](i\oplus j)

的值。

题解:

对于这种子树查询的问题,常用的方法也就那几种,要么就是树链剖分后用数据结构维护,要么就是dsu on tree。

不难发现这个题用dsu on tree比较好写,即考虑以每个节点为根(即lca)的子树,然后枚举不同儿子之间对答案的贡献。

dsu on tree常见的套路是处理子树内的查询,而对于这种子树间的查询就需要对dsu on tree的思想有比较深的理解。

关于dsu on tree的思想以及处理子树内的查询操作可以参考我的另一篇博客:dsu on tree(树上启发式合并)算法总结+习题。

理解之后我们再来看这个题。

前面已经给出了这个题的大概思路,下面就需要完善细节。

假设我们当前子树的根节点为u。

那么我们就只需要找出不在u的同一个儿子节点的子树中的两个点i和j(即需要在u的不同儿子节点形成的子树中),且满足a[i]^a[j]=a[u],对答案的贡献就是i^j.

因为在dsu on tree我们最后只有重儿子的数据会保存下来,因此我们可以每次都把重儿子作为初始的第一个子树,只有一个子树是不会对答案产生贡献的,因为我们每次处理的是子树间的答案。

然后不断的枚举u的其他儿子与之前儿子之间产生的贡献,计算答案就行了。

但是这样还可能会重复,因为在枚举u的当前儿子的子树节点时,会计算在它之前且同属于u的同一个儿子的子树节点他们之间的贡献。

对于这一部分我们可以在枚举u的儿子节点时先dfs计算(u的儿子节点的子树内所有节点)和(之前已经枚举过的u的儿子节点(以及重儿子)的子树内所有节点)产生的贡献,

然后再dfs把u的当前儿子节点对后面儿子节点的子树节点的影响加上。

主要思路以及想出来了,下面就是怎么维护儿子节点之间的影响了。

第一个直接的想法是暴力,用vector v[a[i]]数组来维护值为a[i]的所有节点下标。

然后在计算答案,以及计算影响时直接暴力就行了。

这里用vector维护交上去是不会超时,但是用set或者map都会超时(可能因为每次插入和删除都是一段一段的)。

可能因为符合的点比较少,所以vector没被卡,实际上可以构造极端数据把这种假的暴力做法卡掉。

现在我们考虑优化,因为是异或运算,只有当前位一个为0一个为1才会对答案产生贡献,因此我们可以把每个数a[i]都拆成一个20位的二进制数。

然后我们用一个数组mp[a[i]][i][0/1]维护值为a[i],当前位i为0/1的数的个数。

在维护影响时我们维护这个mp数组即可,计算答案是,直接统计mp[a[u]^a[lca]]产生的贡献。

这样时间复杂度大概是O(nlognlogn)

 

vector暴力的代码

#include 
#define PI atan(1.0)*4
#define rp(i,s,t) for (register int i = (s); i <= (t); i++)
#define RP(i,t,s) for (register int i = (t); i >= (s); i--)
#define ll long long
#define ull unsigned long long
#define mst(a,b) memset(a,b,sizeof(a))
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pii pair
#define pll pair
#define pil pair
#define m_p make_pair
#define p_b push_back
#define debug puts("ac")
#define INF 0x3f3f3f3f
#define LINF 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
	int s=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		s=s*10+ch-'0';
		ch=getchar();
	}
	return s*f;
}
inline ll lread(){
	ll s=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		s=s*10+ch-'0';
		ch=getchar();
	}
	return s*f;
}
const int N = 1e5+7;
const int M = 1e6+1e5+7;
const int maxm=2e5+7;
int hson[N],sz[N],flag;
int mp[M][21][2];
int a[N],n;
ll ans=0;
vector  v[M];
set s[M];
struct edge {
	int head[N], to[maxm], nex[maxm], tot;
	void init() {
		tot = 0;
		rp(i,0,n) head[i] = -1;
	}
	void addedge(int u, int v) {
		to[tot] = v;
		nex[tot] = head[u];
		head[u] = tot++;
	}
}E;
void dfs1(int u,int f){
	sz[u]=1;
	for(int i=E.head[u];~i;i=E.nex[i]){
		int v=E.to[i];
		if(v==f) continue;
		dfs1(v,u);
		sz[u]+=sz[v];
		if(sz[hson[u]]>i)&1]+=k;
	if(k==1){
		v[a[u]].push_back(u);
		// s[a[u]].insert(u);
	}
	else{
		vector::iterator it=v[a[u]].begin();
		for(;it!=v[a[u]].end();it++){
			if(*it==u){
				v[a[u]].erase(it);
				break;
			}
		}
		// s[a[u]].erase(u);
	}
	for(int i=E.head[u];~i;i=E.nex[i]){
		int v=E.to[i];
		if(flag!=v&&v!=fa)
			add(v,u,k);
	}
}
void calc(int u,int f,int lca,int k){
	ll val=(a[u]^a[lca]);
	if(k==1){
		// rp(i,0,20) ans+=mp[val][i][!((u>>i)&1)]*(1ll<>i)&1]+=k;
		if(k==1){
			v[a[u]].push_back(u);
			// s[a[u]].insert(u);
		}
		else{
			vector::iterator it=v[a[u]].begin();
			for(;it!=v[a[u]].end();it++){
				if(*it==u){
					v[a[u]].erase(it);
					break;
				}
			}
			// s[a[u]].erase(u);
		}
	}
	for(int i=E.head[u];~i;i=E.nex[i]){
		int v=E.to[i];
		if(flag==v||v==f) continue;
		calc(v,u,lca,k);
		if(u==lca) add(v,u,k);
	}
}
void dfs(int u,int f,int keep){
	for(int i=E.head[u];~i;i=E.nex[i]){
		int v=E.to[i];
		if(v==f||v==hson[u]) continue;
		dfs(v,u,0);
	}
	if(hson[u]) dfs(hson[u],u,1),flag=hson[u];
	calc(u,f,u,1);
	if(hson[u]) flag=0;
	if(!keep) calc(u,f,u,-1);
}
int main(){
	n=read();
	E.init();
	rp(i,1,n) a[i]=read();
	rp(i,1,n-1){
		int u=read(),v=read();
		E.addedge(u,v);
		E.addedge(v,u);
	}
	dfs1(1,0);
	dfs(1,0,1);
	cout<

二进制优化的代码

#include 
#define PI atan(1.0)*4
#define rp(i,s,t) for (register int i = (s); i <= (t); i++)
#define RP(i,t,s) for (register int i = (t); i >= (s); i--)
#define ll long long
#define ull unsigned long long
#define mst(a,b) memset(a,b,sizeof(a))
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pii pair
#define pll pair
#define pil pair
#define m_p make_pair
#define p_b push_back
#define debug puts("ac")
#define INF 0x3f3f3f3f
#define LINF 0x3f3f3f3f3f3f3f3f
using namespace std;
inline int read(){
	int s=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		s=s*10+ch-'0';
		ch=getchar();
	}
	return s*f;
}
inline ll lread(){
	ll s=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		s=s*10+ch-'0';
		ch=getchar();
	}
	return s*f;
}
const int N = 1e5+7;
const int M = 1e6+1e5+7;
const int maxm=2e5+7;
int hson[N],sz[N],flag;
int mp[M][21][2];
int a[N],n;
ll ans=0;
vector  v[M];
set s[M];
struct edge {
	int head[N], to[maxm], nex[maxm], tot;
	void init() {
		tot = 0;
		rp(i,0,n) head[i] = -1;
	}
	void addedge(int u, int v) {
		to[tot] = v;
		nex[tot] = head[u];
		head[u] = tot++;
	}
}E;
void dfs1(int u,int f){
	sz[u]=1;
	for(int i=E.head[u];~i;i=E.nex[i]){
		int v=E.to[i];
		if(v==f) continue;
		dfs1(v,u);
		sz[u]+=sz[v];
		if(sz[hson[u]]>i)&1]+=k;
	// if(k==1){
		// v[a[u]].push_back(u);
		// s[a[u]].insert(u);
	// }
	// else{
	// 	vector::iterator it=v[a[u]].begin();
	// 	for(;it!=v[a[u]].end();it++){
	// 		if(*it==u){
	// 			v[a[u]].erase(it);
	// 			break;
	// 		}
	// 	}
		// s[a[u]].erase(u);
	// }
	for(int i=E.head[u];~i;i=E.nex[i]){
		int v=E.to[i];
		if(flag!=v&&v!=fa)
			add(v,u,k);
	}
}
void calc(int u,int f,int lca,int k){
	ll val=(a[u]^a[lca]);
	if(k==1){
		rp(i,0,20) ans+=mp[val][i][!((u>>i)&1)]*(1ll<>i)&1]+=k;
		// if(k==1){
			// v[a[u]].push_back(u);
			// s[a[u]].insert(u);
		// }
		// else{
		// 	vector::iterator it=v[a[u]].begin();
		// 	for(;it!=v[a[u]].end();it++){
		// 		if(*it==u){
		// 			v[a[u]].erase(it);
		// 			break;
		// 		}
		// 	}
			// s[a[u]].erase(u);
		// }
	}
	for(int i=E.head[u];~i;i=E.nex[i]){
		int v=E.to[i];
		if(flag==v||v==f) continue;
		calc(v,u,lca,k);
		if(u==lca) add(v,u,k);
	}
}
void dfs(int u,int f,int keep){
	for(int i=E.head[u];~i;i=E.nex[i]){
		int v=E.to[i];
		if(v==f||v==hson[u]) continue;
		dfs(v,u,0);
	}
	if(hson[u]) dfs(hson[u],u,1),flag=hson[u];
	calc(u,f,u,1);
	if(hson[u]) flag=0;
	if(!keep) calc(u,f,u,-1);
}
int main(){
	n=read();
	E.init();
	rp(i,1,n) a[i]=read();
	rp(i,1,n-1){
		int u=read(),v=read();
		E.addedge(u,v);
		E.addedge(v,u);
	}
	dfs1(1,0);
	dfs(1,0,1);
	cout<

 

你可能感兴趣的:(dsu,on,tree,dsu,on,tree,进制拆分)