【题目】
LOJ
给定一个长度为 n n n的序列 A A A,以及 m m m个操作,每次操作将一个 A i A_i Ai修改为 k k k,修改是独立的。每次修改后要求求出一个单调不下降的序列 B i B_i Bi,使得 ∑ i = 1 n ( A i − B i ) 2 \sum_{i=1}^n(A_i-B_i)^2 ∑i=1n(Ai−Bi)2最小,并输出最小值。特别地, B B B可以是分数的形式,但答案对 998244353 998244353 998244353取模。
n ≤ 3 × 1 0 5 , m ≤ 1 0 5 , k , A i ≤ 1 0 9 n\leq 3\times 10^5,m\leq 10^5,k,A_i\leq 10^9 n≤3×105,m≤105,k,Ai≤109
【解题思路】
考虑没有修改的时候怎么求。由于要求 B B B单调不降,那么我们不难发现,若 A i > A i + 1 A_i>A_{i+1} Ai>Ai+1,则 B i = B i + 1 B_{i}=B_{i+1} Bi=Bi+1。我们将有这些关系的位置全部缩起来,不难发现, B B B全部取 A A A的平均值最优。那么我们用单调栈维护缩起来的块,使得块的平均值单调不降,每次新加入一个元素我们用栈顶和新元素合并直到满足不下降的性质即可。
还有一个重要的性质是合并的顺序并不会影响最后的结果,那么不妨考虑一个修改会对答案产生什么影响。
首先要将这个位置左右两边的单调栈维护出来,再考虑合并上这个数,下面我们将一个整块看作一个元素。
这个地方我想的时候是猜了一下这个答案关于弹栈次数的东西是一个凸的函数,而且两边的函数叠加还是凸的,然后就可以三分套三分了!(然而这个函数可能会有平的地方,于是凉凉)
考虑左边,如果我们二分一下弹栈次数,显然弹栈次数越大越有可能使栈合法。而若当前弹栈次数不合法(高于后面或低于前面),我们就要和右边进行合并来抬高(或降低)这个段,显然这个也是可以二分的。注意这里两个二分是独立的,即我左边所有都合完了仍不合法,再在右边二分,于是复杂度就是一个log的了。
写着写着发现自己假了,右边二分了一个界以后,左边可能就不需要合并完了,于是还是一个二分套二分。
复杂度 O ( n log 2 n ) O(n\log ^2 n) O(nlog2n)
特别地,我们也许还需要维护一下栈的回退。
【参考代码】
#include
#define pb push_back
using namespace std;
typedef long long ll;
const int N=1e5+10,mod=998244353;
namespace IO
{
int read()
{
int ret=0;char c=getchar();
while(!isdigit(c)) c=getchar();
while(isdigit(c)) ret=ret*10+(c^48),c=getchar();
return ret;
}
void write(int x){if(x>9)write(x/10);putchar(x%10^48);}
void writeln(int x){write(x);putchar('\n');}
}
using namespace IO;
namespace Math
{
int inv[N];
int mul(int x,int y){return 1ll*x*y%mod;}
int sqr(int x){return mul(x,x);}
int upm(int x){return x>=mod?x-mod:(x<0?x+mod:x);}
void up(int &x,int y){x=upm(x+y);}
void initmath(){inv[1]=1;for(int i=2;i<N;++i)inv[i]=mod-mul(mod/i,inv[mod%i]);}
}
using namespace Math;
namespace DreamLolita
{
int n,Q,sum,top1,top2;
int a[N],ans[N],st1[N],st2[N],f[N],g[N];
ll s[N];//s should use long long
vector<int>oper[N];
struct data
{
int val,id;
data(int _v=0,int _i=0):val(_v),id(_i){}
};
vector<data>qr[N];
bool cmp(int l1,int r1,int l2,int r2,int x,int y){return (s[r1]-s[l1]+x)*(r2-l2)>(s[r2]-s[l2]+y)*(r1-l1);}
int calc(int l,int r,int x){return mod-(((ll)s[r]-s[l]+x)%mod)*(((ll)s[r]-s[l]+x)%mod)%mod*inv[r-l]%mod;}
void init()
{
n=read();Q=read();
for(int i=1;i<=n;++i) a[i]=read(),s[i]=s[i-1]+a[i],sum=upm(sum+sqr(a[i]));
qr[1].pb(data(a[1],0));ans[0]=sum;
for(int i=1,x,y;i<=Q;++i)
{
x=read();y=read();qr[x].pb(data(y,i));
ans[i]=upm(upm(sum+mul(mod-a[x],a[x]))+sqr(y));
//cerr<
}
for(int i=1;i<=n;++i)
{
//printf("%d %d %d\n",st1[top1-1],st1[top1],cmp(st1[top1-1],st1[top1],st1[top1],i,0,0));
for(;top1 && cmp(st1[top1-1],st1[top1],st1[top1],i,0,0);) oper[i].pb(st1[top1--]);
st1[++top1]=i;f[top1]=upm(f[top1-1]+calc(st1[top1-1],st1[top1],0));
reverse(oper[i].begin(),oper[i].end());
//printf("%d %d %d\n",i,top1,f[top1]);
}
//cerr<
}
void solve()
{
st2[0]=n;
for(int i=n;i;--i)
{
--top1;
for(auto j:oper[i]) st1[++top1]=j,f[top1]=upm(f[top1-1]+calc(st1[top1-1],j,0));;
if(i^n)
{
for(;top2 && cmp(i,st2[top2],st2[top2],st2[top2-1],0,0);--top2);
st2[++top2]=i;g[top2]=upm(g[top2-1]+calc(i,st2[top2-1],0));
}
//printf("%d %d\n",top1,top2);
//for(int i=1;i<=top1;++i) printf("%d ",st1[top1]); puts("!");
//for(int i=1;i<=top2;++i) printf("%d ",st2[top2]); puts("?");
for(auto j:qr[i])
{
int l=1,r=top1,delta=j.val-a[i],res=0;
while(l<=r)
{
int mid=(l+r)>>1;
if(cmp(st1[mid-1],st1[mid],st1[mid],i,0,delta)) r=mid-1;
else res=mid,l=mid+1;
}
//printf("now:%d\n",res);
if(!top2 || !cmp(st1[res],i,i,st2[top2-1],delta,0)) up(ans[j.id],upm(upm(calc(st1[res],i,delta)+f[res])+g[top2]));
else
{
int l=1,r=top2-1,ret=0;
while(l<=r)
{
int mid=(l+r)>>1;
int L=1,R=res,rep=0;
//cerr<
while(L<=R)
{
//cerr<
int Mid=(L+R)>>1;
if(cmp(st1[Mid-1],st1[Mid],st1[Mid],st2[mid],0,delta)) R=Mid-1; else rep=Mid,L=Mid+1;
}
//printf("res:%d %d %d\n",rep,mid,delta);
if(cmp(st1[rep],st2[mid],st2[mid],st2[mid-1],delta,0)) r=mid-1; else ret=mid,l=mid+1;
}
//printf("zzz:%d\n",ret);
int L=1,R=res,rep=0;
while(L<=R)
{
int Mid=(L+R)>>1;
if(cmp(st1[Mid-1],st1[Mid],st1[Mid],st2[ret],0,delta)) R=Mid-1; else rep=Mid,L=Mid+1;
}
//printf("res:%d %d %d\n",rep,ret,delta);
//cerr<
up(ans[j.id],upm(upm(calc(st1[rep],st2[ret],delta)+f[rep])+g[ret]));
//cerr<
}
}
}
for(int i=0;i<=Q;++i) writeln(ans[i]);
}
void solution()
{
initmath();init();solve();
}
}
int main()
{
#ifdef Durant_Lee
freopen("LOJ3059.in","r",stdin);
freopen("LOJ3059.out","w",stdout);
#endif
DreamLolita::solution();
return 0;
}