NOIP2023模拟13联测34 abstract

题目大意

有一棵有 n n n个点的树,每个点有一个点权 a i a_i ai f ( i , j ) f(i,j) f(i,j) g ( i , j ) g(i,j) g(i,j)分别表示 i i i j j j的路径上的权值与和权值或,求

∑ i = 1 n ∑ j = i n f ( i , j ) g ( i , j ) \sum\limits_{i=1}^n\sum\limits_{j=i}^nf(i,j)^{g(i,j)} i=1nj=inf(i,j)g(i,j)

输出答案模 111121 111121 111121后的值。

为了方便计算,定义 0 0 = 0 0^0=0 00=0

n ≤ 1 0 5 , a i ≤ 1 0 9 n\leq 10^5,a_i\leq 10^9 n105,ai109,给定的树的叶子(度为 1 1 1的点)不超过 100 100 100个。


题解

不难发现,对于树上从 i i i j j j的链,设 k k k为链上的一个点,则本质不同的 f ( i , k ) f(i,k) f(i,k)不超过 log ⁡ A \log A logA个,其中 A A A a i a_i ai的值域。 g ( i , k ) g(i,k) g(i,k)同理。

f ( i , j ) , g ( i , j ) f(i,j),g(i,j) f(i,j),g(i,j)为一个数对,于是对于树上从 i i i j j j的链,设 k k k为链上的一个点,则本质不同的数对有 O ( log ⁡ A ) O(\log A) O(logA)个。

记叶子个数为 k k k,我们对每个点,用 O ( k log ⁡ A ) O(k\log A) O(klogA)维护这个点到每个子树节点的数对(到每个叶子节点的路径上有最多 O ( log ⁡ A ) O(\log A) O(logA)个本质不同的数对),总时间复杂度为 O ( n k log ⁡ A ) O(nk\log A) O(nklogA)

对于挂在一个点上的链的贡献,我们可以将其看作两个叶子节点在上方相交,最多会有 O ( k 2 ) O(k^2) O(k2)次合并,每次合并都是 O ( log ⁡ 3 A ) O(\log^3 A) O(log3A)的(要枚举这条链的两边,总共有 O ( log ⁡ 2 A ) O(\log^2 A) O(log2A)次,还要用快速幂来求对答案的贡献,是 O ( log ⁡ A ) O(\log A) O(logA)的),所以这部分的时间复杂度为 O ( k 2 log ⁡ 3 A ) O(k^2\log^3A) O(k2log3A)

我们考虑如何省下快速幂的 log ⁡ A \log A logA

可以用光速幂优化,用 O ( p p ) O(p\sqrt p) O(pp )来预处理(其中 p p p为模数,即 111121 111121 111121,是一个质数),那么就可以 O ( 1 ) O(1) O(1) f ( i , j ) g ( i , j ) f(i,j)^{g(i,j)} f(i,j)g(i,j)

总时间复杂度为 O ( n k log ⁡ A + k 2 log ⁡ 2 A + p p ) O(nk\log A+k^2\log^2A+p\sqrt p) O(nklogA+k2log2A+pp )

code

#include
using namespace std;
const int N=100000,siz=350;
const long long mod=111121;
int n,vl[N+5];
long long ans=0,pw[mod+5][siz+5];
vector<int>g[N+5];
map<pair<int,int>,int>mp[N+5];
void init(){
	for(int i=1;i<mod;i++){
		pw[i][0]=1;
		for(int j=1;j<=siz;j++) pw[i][j]=pw[i][j-1]*i%mod;
	}
}
long long mi(long long x,long long y){
	x%=mod;y%=mod-1;
	return pw[pw[x][siz]][y/siz]*pw[x][y%siz]%mod;
}
void dfs(int u,int fa){
	mp[u][{vl[u],vl[u]}]=1;
	ans=(ans+mi(vl[u],vl[u]))%mod;
	for(int v:g[u]){
		if(v==fa) continue;
		dfs(v,u);
		for(auto a:mp[u]){
			for(auto b:mp[v]){
				ans=(ans+1ll*a.second*b.second%mod
					*mi(a.first.first&b.first.first,
					a.first.second|b.first.second)%mod)%mod;
			}
		}
		for(auto b:mp[v]){
			mp[u][{vl[u]&b.first.first,vl[u]|b.first.second}]+=b.second;
		}
	}
}
int main()
{
//	freopen("abstract.in","r",stdin);
//	freopen("abstract.out","w",stdout);
	init();
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%d",&vl[i]);
	}
	for(int i=1,x,y;i<n;i++){
		scanf("%d%d",&x,&y);
		g[x].push_back(y);g[y].push_back(x);
	}
	dfs(1,0);
	printf("%lld",ans);
	return 0;
}

你可能感兴趣的:(题解,好题,题解,c++)