LibreOJ #3059.「HNOI2019」序列 单调栈

题意

给一个长度为 n n n的序列a,每次询问若修改某一个位置,要求找一个长度同样为 n n n的单调不降序列b,最小化 ∑ i = 1 n ( a i − b i ) 2 \sum_{i=1}^n(a_i-b_i)^2 i=1n(aibi)2
n , m ≤ 1 0 5 n,m\le10^5 n,m105

分析

首先如果我们只选一个数 x x x满足 ∑ ( a i − x ) 2 \sum(a_i-x)^2 (aix)2最小,展开后发现这是一个关于 x x x的二次函数,显然当 x = ∑ a i n x=\frac{\sum{a_i}}{n} x=nai时取最优。
那么答案一定是将序列a分成若干部分,每部分的最优解都取这部分的平均数。
如果没有修改,可以逐个把元素插入单调栈中,若栈顶的平均值小于下一个部分的平均值,则把这两个部分合并。
有一个结论是从前往后和从后往前做得到的结果是一样的。
由于只是单次修改,我们可以先把询问离线,先从后往前做出单调栈,每次撤销后缀单调栈并加入前缀单调栈,来得到某个位置的前缀单调栈和后缀单调栈。
由于每个序列的答案唯一,且注意到最终与x分到同一段的一定是单调栈中的整段,我们只需要求出与x分到同一段的区间的左右端点。
考虑从前往后维护单调栈,在逐个加入x和后缀单调栈的元素时,前缀单调栈的元素会被逐个弹出并被合并,于是我们可以用二分来加速这个过程,从而得到左端点。
具体来说就是先二分一个左端点 L L L,注意到若一个右端点 R R R满足合并后的平均值不大于第 R + 1 R+1 R+1部分的平均值,则 R + 1 R+1 R+1必然也满足,所以右端点同样存在二分性。
二分出右端点后,判断是否满足合并后的平均值不小于第 L − 1 L-1 L1部分的平均值,来确定当前左端点是否合法即可。
时间复杂度 O ( m l o g 2 n ) O(mlog^2n) O(mlog2n)

代码

#include
#include
#include
#include
#include
#include
#define pb push_back

typedef long long LL;

const int N=100005;
const int MOD=998244353;

int n,m,ny[N],xx,yy,a[N],sum,ans1[N],ans2[N],top1,top2,lef1[N],rig1[N],lef2[N],rig2[N],ans[N],mo[N][2],cnt;
LL s[N];
struct data{int lef,rig,tim,op;}cha[N*2];
std::vector<int> vec[N];

int add(int x,int y) {return x+y<MOD?x+y:x+y-MOD;}
int dec(int x,int y) {return x-y<0?x-y+MOD:x-y;}
int mul(int x,int y) {return (LL)x*y%MOD;} 

int calc(int l,int r)
{
	int S=(s[r]-s[l-1])%MOD,len=r-l+1;
	if (xx>=l&&xx<=r) S=add(dec(S,a[xx]),yy);
	int p=mul(S,ny[len]);
	return dec(mul(mul(p,p),len),mul(2*p,S));
}

bool cmp(int l1,int r1,int l2,int r2)
{
	LL s1=s[r1]-s[l1-1],s2=s[r2]-s[l2-1];
	int len1=r1-l1+1,len2=r2-l2+1;
	if (xx>=l1&&xx<=r1) s1+=yy-a[xx];
	else if (xx>=l2&&xx<=r2) s2+=yy-a[xx];
	return (LL)s1*len2<(LL)s2*len1;
}

void build()
{
	for (int i=n;i>=1;i--)
	{
		int l=i,r=i;
		while (top2&&cmp(lef2[top2],rig2[top2],l,r))
		{
			cha[++cnt].op=1;cha[cnt].tim=i;
			cha[cnt].lef=lef2[top2];cha[cnt].rig=rig2[top2];
			r=rig2[top2];
			top2--;
		}
		cha[++cnt].op=2;cha[cnt].tim=i;
		top2++;
		lef2[top2]=l;rig2[top2]=r;
		ans2[top2]=add(ans2[top2-1],calc(l,r));
	}
}

int solve(int x,int y)
{
	xx=x;yy=y;
	top1++;top2++;
	lef1[top1]=lef2[top2]=rig1[top1]=rig2[top2]=x;
	ans1[top1]=add(ans1[top1-1],calc(x,x));
	ans2[top2]=add(ans2[top2-1],calc(x,x));
	int l1=1,r1=top1,L,R;
	while (l1<=r1)
	{
		int mid1=(l1+r1)/2,l2=1,r2=top2;
		while (l2<=r2)
		{
			int mid2=(l2+r2)/2;
			if (mid2>1&&cmp(lef2[mid2-1],rig2[mid2-1],lef1[mid1],rig2[mid2])) r2=mid2-1;
			else l2=mid2+1;
		}
		if (mid1>1&&cmp(lef1[mid1],rig2[l2-1],lef1[mid1-1],rig1[mid1-1])) r1=mid1-1;
		else l1=mid1+1,R=l2-1;
	}
	L=l1-1;
	top1--;top2--; 
	return ((LL)sum+MOD-(LL)a[x]*a[x]%MOD+(LL)y*y+(LL)ans1[L-1]+(LL)ans2[R-1]+(LL)calc(lef1[L],rig2[R]))%MOD;
}

int main()
{
	scanf("%d%d",&n,&m);
	ny[0]=ny[1]=1;
	for (int i=2;i<=n;i++) ny[i]=(LL)(MOD-MOD/i)*ny[MOD%i]%MOD;
	for (int i=1;i<=n;i++) scanf("%d",&a[i]),s[i]=s[i-1]+a[i],sum=add(sum,mul(a[i],a[i]));
	for (int i=1;i<=m;i++) scanf("%d%d",&mo[i][0],&mo[i][1]),vec[mo[i][0]].pb(i);
	build();
	printf("%d\n",add(ans2[top2],sum));
	for (int i=1;i<=n;i++)
	{
		while (cnt&&cha[cnt].tim==i)
		{
			if (cha[cnt].op==2) top2--;
			else top2++,lef2[top2]=cha[cnt].lef,rig2[top2]=cha[cnt].rig,ans2[top2]=add(ans2[top2-1],calc(lef2[top2],rig2[top2]));
			cnt--;
		}
		for (int j=0;j<vec[i].size();j++) ans[vec[i][j]]=solve(mo[vec[i][j]][0],mo[vec[i][j]][1]);
		xx=yy=0;
		int l=i,r=i;
		while (top1&&cmp(l,r,lef1[top1],rig1[top1])) l=lef1[top1],top1--;
		top1++;
		lef1[top1]=l;rig1[top1]=r;
		ans1[top1]=add(ans1[top1-1],calc(l,r));
	}
	for (int i=1;i<=m;i++) printf("%d\n",ans[i]);
	return 0;
}

你可能感兴趣的:(单调队列&单调栈)