传送门:https://www.luogu.org/problemnew/show/P2023
这个题目的区间更新有加法和乘法。
所以比裸的线段树难一点点吧,也就仅仅是一点点。
既然存在两个操作,所以我们就要维护两个tag,一个加法一个乘法。
但是pushdown的时候这两个tag怎么pushdown呢?
乘法的优先级显然比加法高,所以我们在mul更新的时候要先pushdown,这是一个要点。第二个要点就是,Pushdown的时候,对于乘法tag,我们可以直接乘上父节点的tag,但是对于加法的,我们要怎么办呢? 我们要先把该结点的tag乘上乘法tag,然后再加上父节点的tag。为什么要这样做呢?因为乘法优先级高的嘛,这样就完事了。
看不懂的话,下面我就来推导一下
假设父节点是ax+b
我们要pushdown左儿子。
我们要乘一个k然后加上c
就变成了k(ax+b)+c
变成了kax+kb+c
变成了(ka)x+(kb+c)
右边的kb+c就变成了add[rt<<1]*mul[rt]+add[rt]。ka就是mul[rt<<1]*mul[rt]。
下面是每次都比分块慢的线段树代码:
(分块写这种区间更新的,不如线段树方便,我就没写分块的代码。)
#include
using namespace std;
typedef long long ll;
const int maxn = 1e6+7;
ll a[maxn];
ll sum[maxn<<2],add[maxn<<2],mul[maxn<<2];
ll n,p;
void pushup(int rt)
{
sum[rt] = (sum[rt<<1]+sum[rt<<1|1])%p;
}
void pushdown(int rt,int ln,int rn)
{
if(add[rt] || mul[rt]!=1)
{
sum[rt<<1] = (sum[rt<<1]*mul[rt]+add[rt]*ln)%p;
sum[rt<<1|1] = (sum[rt<<1|1]*mul[rt]+add[rt]*rn)%p;
add[rt<<1] = (add[rt<<1]*mul[rt]+add[rt])%p;
add[rt<<1|1] = (add[rt<<1|1]*mul[rt]+add[rt])%p;
mul[rt<<1] = (mul[rt<<1]*mul[rt])%p;
mul[rt<<1|1] = (mul[rt<<1|1]*mul[rt])%p;
add[rt] = 0;
mul[rt] = 1;
}
}
void build(int rt,int l,int r)
{
mul[rt] = 1;
if(l==r)
{
sum[rt] = a[l]%p;
return;
}
int mid = (l+r)/2;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void Add(int x,int y,int l,int r,int rt,int v)
{
if(x<=l && y>=r)
{
sum[rt] = (sum[rt]+(r-l+1)*v)%p;
add[rt] = (add[rt]+v)%p;
return;
}
int mid = (l+r)/2;
pushdown(rt,mid-l+1,r-mid);
if(x<=mid)
{
Add(x,y,l,mid,rt<<1,v);
}
if(y>mid)
{
Add(x,y,mid+1,r,rt<<1|1,v);
}
pushup(rt);
}
void Mul(int x,int y,int l,int r,int rt,int v)
{
int mid = (l+r)/2;
pushdown(rt,mid-l+1,r-mid);
if(x<=l && y>=r)
{
sum[rt] = (sum[rt]*v)%p;
mul[rt] = (mul[rt]*v)%p;
return;
}
if(x<=mid)
{
Mul(x,y,l,mid,rt<<1,v);
}
if(y>mid)
{
Mul(x,y,mid+1,r,rt<<1|1,v);
}
pushup(rt);
}
ll query(int x,int y,int l,int r,int rt)
{
if(x<=l && y>=r)
{
return sum[rt];
}
int mid = (l+r)/2;
pushdown(rt,mid-l+1,r-mid);
ll ans = 0;
if(x<=mid)
{
ans = (ans+query(x,y,l,mid,rt<<1))%p;
}
if(y>mid)
{
ans = (ans+query(x,y,mid+1,r,rt<<1|1))%p;
}
return ans;
}
int main()
{
scanf("%lld%lld",&n,&p);
for(int i=1;i<=n;i++)
{
scanf("%lld",a+i);
}
build(1,1,n);
int m;
scanf("%d",&m);
for(int i=0;i