BJOI2018求和

master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 k 次方和,而且每次的 k 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。 他把这个问题交给了 pupil,但 pupil 并不会这么复杂的操作,你能帮他解决吗?

输入
第一行包含一个正整数 n,表示树的节点数。

之后 n−1行每行两个空格隔开的正整数 i,j,表示树上的一条连接点 i 和点 j的边。

之后一行一个正整数 m,表示询问的数量。

之后每行三个空格隔开的正整数 i,j,k,表示询问从点 i 到点 j 的路径上所有节点深度的 k 次方和。由于这个结果可能非常大,输出其对 998244353 取模的结果。

树的节点从 1 开始标号,其中 1 号节点为树的根。

输出
对于每组数据输出一行一个正整数表示取模后的结果。

样例输入 [复制]
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
样例输出 [复制]
33
503245989

预处理1-n的1-50次方前缀和。最后用两端的值减去lca的值即可

#include
using namespace std;
#define int long long 
inline int read()
{
    int data=0;int w=1; char ch=0;
    ch=getchar();
    while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
    if(ch=='-') w=-1,ch=getchar();
    while(ch>='0' && ch<='9') data=(data<<3)+(data<<1)+ch-'0',ch=getchar();
    return data*w;
}
const int N=511010;
const int mod=998244353;
int sum[51][N];
struct node{
	int v,nxt;
}e[N<<1];
int fir[N],cnt=0;
inline void add(int u,int v){ e[++cnt]=(node){v,fir[u]};fir[u]=cnt;}
inline int quickpow(int a,int b){
	int c=1;
	while(b){
		if(b&1) c=c*a%mod;
		a=a*a%mod;
		b=b>>1;
	}
	return c;
}
int n,m;
int dep[N],f[N][25];
inline void dfs(int u,int fa){
    dep[u]=dep[fa]+1;
    for(int i=0;i<=19;i++){
    	f[u][i+1]=f[f[u][i]][i];
	}
	for(int i=fir[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==fa) continue;
		f[v][0]=u;
		dfs(v,u);
	}
}
inline int lca(int x,int y){
	if(dep[x]<dep[y]) swap(x,y);
	for(int i=20;i>=0;i--){
		if(dep[f[x][i]]>=dep[y]) x=f[x][i];
		if(x==y) return x;
	}
	for(int i=20;i>=0;i--){
		if(f[x][i]!=f[y][i]){
			x=f[x][i];y=f[y][i];
		}
	}
	return f[x][0];
}
signed main(){
	// int size=100<<20;//40M
    //__asm__ ("movl  %0, %%esp\n"::"r"((char*)malloc(size)+size));//调试用这个 
   // __asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));//提交用这个 

    //main函数代码 

	 n=read();   
	 for(int i=1;i<=50;i++){
     	for(int j=1;j<=n;j++){
     	   sum[i][j]=(sum[i][j-1]+quickpow(j,i))%mod;
		}
	 }
	 for(int i=1;i<n;i++){
	 	int u=read(),v=read();
	 	add(u,v);add(v,u);
	 }
	 dep[0]=-1;
	 dfs(1,0);
	 m=read();
	 for(int i=1;i<=m;i++){
	 	int x=read(),y=read(),k=read();
	 	int L=lca(x,y);
	 	int ans=(sum[k][dep[x]]+sum[k][dep[y]]-sum[k][dep[L]]-sum[k][max(dep[f[L][0]],0ll)]+mod+mod)%mod;
	    printf("%lld\n",ans);
	 }
	 exit(0);//必须用exit 
	return 0;
}

你可能感兴趣的:(lca)