题目链接
Flute 很喜欢柠檬。它准备了一串用树枝串起来的贝壳,打算用一种魔法把贝壳变成柠檬。贝壳一共有 N (1 ≤ N
≤ 100,000) 只,按顺序串在树枝上。为了方便,我们从左到右给贝壳编号 1..N。每只贝壳的大小不一定相同,
贝壳 i 的大小为 si(1 ≤ si ≤10,000)。变柠檬的魔法要求,Flute 每次从树枝一端取下一小段连续的贝壳,并
选择一种贝壳的大小 s0。如果 这一小段贝壳中 大小为 s0 的贝壳有 t 只,那么魔法可以把这一小段贝壳变成 s
0t^2 只柠檬。Flute 可以取任意多次贝壳,直到树枝上的贝壳被全部取完。各个小段中,Flute 选择的贝壳大小 s
0 可以不同。而最终 Flute 得到的柠檬数,就是所有小段柠檬数的总和。Flute 想知道,它最多能用这一串贝壳
变出多少柠檬。请你帮忙解决这个问题。
朴素dp是 O(n2) O ( n 2 ) 的
设 dp[i] d p [ i ] 表示前 i i 个可能产生的最大价值,直接暴力转移
一个显然的结论是: 把一段的首尾大小强制相同不会使答案变差
于是我们对于每一种大小分别考虑:
sum[i] s u m [ i ] 表示与 i i 号贝壳大小相同的 i i 的前缀的贝壳总数
首先从前往后 dp 值一定是不降的,对于两个可能的决策点 j1,j2(j1<j2) j 1 , j 2 ( j 1 < j 2 )
由于后面那一坨平方项也是单增的,并且增加得很快,那么当 j1 j 1 的转移由于 j2 j 2 时, j1 j 1 在之后的决策过程中也必定会优于 j2 j 2
一个想法就是开一个栈,每次如果栈顶不优于第2个元素的时候就弹栈
但是可能会出现第3个元素能在更早的时间比栈顶优的情况,综合考虑的话我们就是要维护一个当前位置元素比它上面的元素优时的 sum[i] s u m [ i ] 的最小值 递增的一个单调栈,这样我们每次从栈顶作为决策点就一定是最优的
其实每个决策点就是一个单调递增的二次函数
二分出函数图像的”交点”来做最优决策
代码:
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define Set(a,b) memset(a,b,sizeof(a))
#define POW(a) ((a)*(a))
int n;
const int N=1e5+10;
const int MAXN=10200;
typedef long long ll;
int s[N];
int *st[MAXN];int top[MAXN];
int num[MAXN];
int pool[N<<1];
ll dp[N];
int lst[MAXN],pre[N],sum[N];
inline ll calc(int p,int x,int S){return dp[p-1]+1ll*S*POW(1ll*(x-sum[p]+1));}
#define INF 1e9
inline int query(int p,int q,int S)
{
register int l=1,r=n,pos=n+1;
while(l<=r){
register int mid=l+r>>1;
if(calc(p,mid,S)>=calc(q,mid,S)) r=mid-1,pos=mid;
else l=mid+1;
}
return pos;
}
int main()
{
scanf("%d",&n);
for(register int i=1;i<=n;++i) scanf("%d",&s[i]),++num[s[i]];
int h=0;
for(register int i=0;i2;
}
register ll ans=0;
for(register int i=1;i<=n;++i){
pre[i]=lst[s[i]];lst[s[i]]=i;sum[i]=sum[pre[i]]+1;
register int x=s[i];
while(top[x]>1&&query(st[x][top[x]-1],st[x][top[x]],x)<=query(st[x][top[x]],i,x)) --top[x];
st[x][++top[x]]=i;
register int p,q;
for(p=st[x][top[x]],q=st[x][top[x]-1];top[x]>1;--top[x],p=q,q=st[x][top[x]-1]) if(calc(p,sum[i],x)>calc(q,sum[i],x)) break;
dp[i]=calc(p,sum[i],x);
ans=max(ans,dp[i]);
}
printf("%lld\n",ans);
}