牛客算法周周练15-D树上求和【dfs序上建线段树】【模运算的问题】

题目链接

此算法暴露了一个问题:模运算求答案的错误示范

{//选看,与题目做法无关

要求求出ans(mod p)
你有一个算法能够实现求出来ans2;
但是千万不要在算ans
2的过程中使用模数p。然后对结果除以二,这是不行的。

举个例子:
求出来了res = 7p+16 = ans2;
然后res/=2 = 3*p + (8+p/2) = ans;
这就出问题了,因为答案是8+p/2,而不是8;

为了解决这个问题我们要同步模数:
计算ans2时候使用p=2p作为模数;

如此算得模数为2p的结果:res=3*(2p)+16+p;
然后对res/=2 得到ans = 8+p/2,得到正确答案了;

再深入一点探究一下:
这种情况下只有在模数的系数是不被2整除的情况才出现的,可以拓展为:
可以求出来ansk情况下
计算过程中模数为p
res = (a
k+1)* p+ * b k; // a,b是任意常数
ans = b;(错误)
计算过程中模数为kp
res = akp+(b*k+p);
ans = res/k %p = (b+p/k)%p;(正确)

}

题目做法如下:

首先是题意,维护子树的所有结点的平方和。展开一下就知道:
有结点数a b c;
res=a^ 2+b^ 2+c^2;
如果按照题目操作,对所有结点加上y,有:
res’=(a+y)^ 2 +(b+y)^ 2 + (c+y)^ 2 = res+2y(a+b+c)+3y^ 2;
所以对就是加上y的操作可以当成对
res+=2y
sum+cnt*y^2;
其次就是dfs序上建线段树的做法了;

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
//
using namespace std;
const int INF = 0x3f3f3f3f;//1.06e9大小
const long long mod1 = 1e9 + 7;
const int mod2 = 998244353;
const int mod3 = 1e9;
const double PI = 3.14159265;
const double eps =1e-8;
typedef unsigned long long ULL;
typedef long long LL;
#define ms(x, n) memset(x,n,sizeof(x))
#define debug printf("***debug***\n")
#define pii pair
#define fi first
#define se second
/*
*/
const int MAXN=1e5+10;
const LL mod0=2*23333;
struct node
{
	LL l,r;
	LL sum;//和
	LL res;//平方和;
	LL lazy;

}g[8*MAXN];
vector<int >tr[MAXN];
vector<LL >pp[MAXN];
int dx[2*MAXN];//dfs序
int cc=0;
LL v[MAXN];
void dfs(int k,int last)
{
	dx[++cc]=k;
	pp[k].push_back(cc);
	for(int i=0;i<tr[k].size();++i)
	{
		if(tr[k][i]!=last)dfs(tr[k][i],k);
	}
	dx[++cc]=k;
	pp[k].push_back(cc);
}
void build(int k,LL l,LL r)
{
	g[k].l=l;
	g[k].r=r;
	if(l==r)
	{
		g[k].sum=v[dx[l]]%mod0;
		g[k].res=v[dx[l]]*v[dx[l]]%mod0;

		return ;
	}
	int mid=(l+r)/2;
	build(k<<1,l,mid);
	build(k<<1|1,mid+1,r);
	g[k].sum=g[k<<1].sum+g[k<<1|1].sum;
	g[k].sum%=mod0;
	g[k].res=g[k<<1].res+g[k<<1|1].res;
	g[k].res%=mod0;

}
void fun(int k,LL y)
{
	g[k].lazy+=y;
	g[k].lazy%=mod0;

	g[k].res+=y*g[k].sum*2;
	g[k].res+=y*y%mod0*(g[k].r-g[k].l+1)%mod0;
	g[k].res%=mod0;

	g[k].sum+=(g[k].r-g[k].l+1)%mod0*y%mod0;
	g[k].sum%=mod0;
}
void down (int k)
{
	if(!g[k].lazy)return ;
	LL y=g[k].lazy;
	fun(k<<1,g[k].lazy);
	fun(k<<1|1,g[k].lazy);
	g[k].lazy=0;
}
void add(int k,LL l,LL r,LL y)
{
	if(l<=g[k].l&&g[k].r<=r)
	{//当前区间完全在待求区域;
		fun(k,y);
		return ;
	}
	down(k);
	int mid=(g[k].l+g[k].r)/2;
	if(l<=mid)add(k<<1,l,r,y);
	if(mid+1<=r)add(k<<1|1,l,r,y);
	g[k].sum=g[k<<1].sum+g[k<<1|1].sum;
	g[k].sum%=mod0;
	g[k].res=g[k<<1].res+g[k<<1|1].res;
	g[k].res%=mod0;

}
LL que(int k,LL l,LL r)
{
	if(l<=g[k].l&&g[k].r<=r)
	{
		return g[k].res;
	}
	LL mid=(g[k].l+g[k].r)/2;
	LL ans=0;
	down(k);
	if(l<=mid)ans+=que(k<<1,l,r);
	if(mid+1<=r)ans+=que(k<<1|1,l,r);
	return ans%=mod0;
}
int main()
{
	LL n,q;
	cin>>n>>q;
	for(int i=1;i<=n;++i)
	{
		scanf("%lld",&v[i]);
	}
	for(int i=1;i<n;++i)
	{
		int u,v;
		scanf("%d %d",&u,&v);
		tr[u].push_back(v);
		tr[v].push_back(u);
	}
	dfs(1,1);
	build(1,1,2*n);
	//for(int i=1;i<=n;++i)printf("%d %d\n",pp[i][0],pp[i][1]);
	//for(int i=1;i<=2*n;++i)printf("%d ",dx[i]);printf("\n");
	for(int i=0;i<q;++i)
	{
		int flag=0;
		scanf("%d",&flag);
		if(flag-1)
		{
			int k;
			scanf("%d",&k);
			printf("%lld\n",que(1,pp[k][0],pp[k][1])/2%(mod0/2) );
		}
		else
		{
			int k;
			LL y;
			scanf("%d %lld\n",&k,&y);
			add(1,pp[k][0],pp[k][1],y%mod0);
		}
	}
	return 0;
}



你可能感兴趣的:(牛客题解)