首先分治,然后答案转变为求区间[l,r]中,经过终点mid=(l+r)>>1的子串[x,y]的答案之和。
那么不妨枚举左端点为x,那么显然可以得到区间[x,mid]的最小值u和最大值v。同时维护两个指针j,k,表示最远的j使得[mid+1,j]的最小值没有u小,k维护最大值。那么:
1.对于所有y∈[mid+1,min(j,k)],最小值为u,最大值为v,直接利用高斯求和得到答案;
2.对于所有y∈[max(j,k)+1,r],最小值为[mid+1,y]的最小值,最大值为[mid+1,y]的最大值,预处理所有[mid+1,r]的答案得到;
3.不妨设j<k,那么对于所有y∈[j+1,k],最小值为[mid+1,y]的最小值,最大值为v,预处理所有[mid+1,r]的最大值为这部分答案的影响即可。
时间O(NlogN)。
AC代码如下:
#include<iostream> #include<cstdio> #include<cstring> #define mod 1000000000 #define ll long long #define N 500005 using namespace std; int n,ans,a[N],c[N][2],f[N],g[N],p[N][2],q[N][2]; int read(){ int x=0; char ch=getchar(); while (ch<'0' || ch>'9') ch=getchar(); while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); } return x; } void ad(int &x,int y){ x+=y; if (x>=mod) x-=mod; } void dl(int &x,int y){ x-=y; if (x<0) x+=mod; } int getsum(int x,int y){ return ((ll)(x+y)*(y-x+1)>>1)%mod; } void solve(int l,int r){ if (l==r){ ad(ans,(ll)a[l]*a[l]%mod); return; } int mid=(l+r)>>1,i; solve(l,mid); solve(mid+1,r); c[mid][0]=c[mid][1]=a[mid]; for (i=mid-1; i>=l; i--){ c[i][0]=min(c[i+1][0],a[i]); c[i][1]=max(c[i+1][1],a[i]); } int mn=mod,mx=-mod; f[mid]=g[mid]=p[mid][0]=p[mid][1]=q[mid][0]=q[mid][1]=0; for (i=mid+1; i<=r; i++){ mn=min(mn,a[i]); mx=max(mx,a[i]); f[i]=(ll)mn*mx%mod*(i-mid)%mod; ad(f[i],f[i-1]); g[i]=(ll)mn*mx%mod; ad(g[i],g[i-1]); p[i][0]=(p[i-1][0]+mn)%mod; q[i][0]=(q[i-1][0]+mx)%mod; p[i][1]=(ll)mn*(i-mid)%mod; ad(p[i][1],p[i-1][1]); q[i][1]=(ll)mx*(i-mid)%mod; ad(q[i][1],q[i-1][1]); } int j=mid,k=mid; for (i=mid; i>=l; i--){ while (j<r && c[i][0]<a[j+1]) j++; while (k<r && c[i][1]>a[k+1]) k++; ad(ans,(ll)c[i][0]*c[i][1]%mod*getsum(mid-i+2,min(j,k)-i+1)%mod); ad(ans,((ll)g[r]*(mid-i+1)+f[r])%mod); dl(ans,((ll)g[max(j,k)]*(mid-i+1)+f[max(j,k)])%mod); if (j<k){ ad(ans,((ll)p[k][0]*(mid-i+1)+p[k][1])%mod*c[i][1]%mod); dl(ans,((ll)p[j][0]*(mid-i+1)+p[j][1])%mod*c[i][1]%mod); } else{ ad(ans,((ll)q[j][0]*(mid-i+1)+q[j][1])%mod*c[i][0]%mod); dl(ans,((ll)q[k][0]*(mid-i+1)+q[k][1])%mod*c[i][0]%mod); } } } int main(){ n=read(); int i; for (i=1; i<=n; i++) a[i]=read(); solve(1,n); printf("%d\n",ans); return 0; }
by lych
2016.4.20