首先对于任意a[i]都减去一个i转化为单调不下降序列方便操作,设新的d[i]=a[i]-i。然后第一问二分或者BIT水过去,设f[i]表示i为结尾的最长序列。
然后设g[i]为保留i时的最优值,那么可以得到方程g[i]=min{g[j]+a[j]*(i-j)+(i-j)*(i-j+1)/2},其中j<i且d[j]<=d[i]且f[j]+1=f[i]。
然后把和j有关的弄到一起,就可以令c[i]=g[i]-i*(i+1)/2-a[i]*(i+1),然后上式简化为g[i]=min{i*d[j]+c[j]}+a[i]+b[i]+i*(i-1)/2。。得到j<k且k比j优的斜率表达式为: (c[j]-c[k])/(d[k]-d[j])>i。但是条件太多了,只能用CDQ分治。
首先假设用f[i]=x的去更新f[i]=x+1的,并运用CDQ分治。对于区间(l,r),按照d[i]进行排序,对(l,mid)维护一个斜率的凸包,每次二分查找即可。时间复杂度O(Nlog^2N)。
AC代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define inf 1000000000 #define N 200005 #define ll long long using namespace std; int n,m,cnt,cnt1,cnt2,tp,a[N],b[N],d[N],f[N],hash[N]; int fst[N],nxt[N],h[N],u[N],v[N]; struct node{ int x,y; }p[N]; ll gas[N],g[N],c[N]; bool cmp(int x,int y){ return d[x]<d[y] || (d[x]==d[y] && c[x]>c[y]); } int find(int x){ int l=1,r=n+1,mid; while (l<r){ mid=(l+r)>>1; if (hash[mid]<x) l=mid+1; else r=mid; } return l; } void ins(int x,ll y){ for (; x<=n+1; x+=x&-x) c[x]=max(c[x],y); } int getmax(int x){ ll mx=-inf; for (; x; x-=x&-x) mx=max(mx,c[x]); return mx; } void add(int aa,int bb){ nxt[bb]=fst[aa]; fst[aa]=bb; } double getk(int x,int y){ if (!y) return -1e20; return (double)(c[x]-c[y])/(d[y]-d[x]); } void link(int x){ if (tp && d[x]==d[h[tp]]) tp--; while (tp>1 && getk(h[tp],x)>getk(h[tp-1],h[tp])) tp--; h[++tp]=x; } void qry(int x){ if (!tp) return; h[++tp]=0; int l=1,r=tp,mid; while (l<r){ mid=(l+r)>>1; if (getk(h[mid],h[mid+1])>x) l=mid+1; else r=mid; } l=h[l]; g[x]=min(g[x],(ll)x*d[l]+c[l]); tp--; } void solve(int l,int r){ if (l>=r) return; int i,j=1,mid=(l+r)>>1; solve(l,mid); solve(mid+1,r); cnt1=cnt2=tp=0; for (i=l; i<=mid; i++) if (!p[i].y) u[++cnt1]=p[i].x; for (i=mid+1; i<=r; i++) if (p[i].y) v[++cnt2]=p[i].x; if (!cnt1 || !cnt2) return; sort(u+1,u+cnt1+1,cmp); sort(v+1,v+cnt2+1,cmp); for (i=1; i<=cnt2; i++){ while (j<=cnt1 && d[u[j]]<=d[v[i]]) link(u[j++]); qry(v[i]); } } int main(){ scanf("%d",&n); int i; for (i=1; i<=n; i++){ scanf("%d",&a[i]); d[i]=hash[i+1]=a[i]-i; } for (i=1; i<=n; i++) scanf("%d",&b[i]); for (i=1; i<=n+1; i++) c[i]=-inf; sort(hash+1,hash+n+2); f[1]=1; ins(find(0),0); for (i=1; i<=n; i++){ int t=find(d[i]); f[i]=getmax(t)+1; ins(t,(ll)f[i]); } int len=getmax(n+1); printf("%d ",len); ll ans=(ll)inf*inf; memset(fst,-1,sizeof(fst)); for (i=n; i>=0; i--) if (f[i]>=0) add(f[i],i); for (i=1; i<=n; i++) gas[i]=gas[i-1]+i; for (i=fst[0]; i>=0; i=nxt[i]) c[i]=0; for (i=1; i<=n; i++) g[i]=ans; for (i=0; i<len; i++){ int j=fst[i],k=fst[i+1]; cnt=0; while (j>=0 || k>=0) if (j>=0 && (k<0 || j<k)){ p[++cnt].x=j; p[cnt].y=0; j=nxt[j]; } else{ p[++cnt].x=k; p[cnt].y=1; k=nxt[k]; } solve(1,cnt); for (k=fst[i+1]; k>=0; k=nxt[k]){ g[k]+=gas[k-1]+a[k]+b[k]; c[k]=g[k]+gas[k]-(ll)a[k]*(k+1); } } for (i=0; i<=n; i++) if (f[i]==len) ans=min(ans,g[i]+(ll)a[i]*(n-i)+gas[n-i]); printf("%lld\n",ans); return 0; }
2016.2.17