此文大部分来自以下博客
如果有需要学习请移步以下链接
【笔记】【总结】斜率DP及习题-Little_Fall
斜率优化-南枙向暖
此文主要初学的我为了加深自己印象,难免错误,请大家勿往下阅读
引入
当我们在遇到这样的DP方程时 d p [ i ] = d p [ j ] + ( x [ i ] − x [ j ] ) ∗ ( x [ i ] − x [ j ] ) dp[i]=dp[j]+(x[i]-x[j])*(x[i]-x[j]) dp[i]=dp[j]+(x[i]−x[j])∗(x[i]−x[j]),如果把右边的乘法化开的话,会有 x [ i ] ∗ x [ j ] x[i]*x[j] x[i]∗x[j]的项,它不能分解为只与i或j有关的部分。这里学习一种新的优化方法,叫做斜率优化。
例题: HDU3507 Print Article
给定n(5e5)和m(1000),以及一个长为n的数列 a i a_i ai,现在要把数列分成若干个连续段。定义每个连续段的代价是 ∑ a i 2 + m \sum{a_i}^2+m ∑ai2+m,求划分后的最小代价。
思路:
dp[i]: 表示把前i个数划分完的最小代价
sum[i]: 从 a 1 + … … + a i a_1+……+a_i a1+……+ai的和
则我们可以得到初始式子 d p [ i ] = m i n ( d p [ j ] + ( s u m [ i ] − s u m [ j ] ) 2 ) , ( 0 < = j < i , d p [ 0 ] = 0 ) dp[i]=min(dp[j]+(sum[i]-sum[j])^2),(0<=jdp[i]=min(dp[j]+(sum[i]−sum[j])2),(0<=j<i,dp[0]=0)
然后我们通过变形,设j
转化得, d p [ k ] − s u m [ k ] 2 − ( d p [ j ] − s u m [ j ] 2 ) s u m [ k ] − s u m [ j ] < 2 ∗ s u m [ i ] \frac{dp[k]-sum[k]^2-(dp[j]-sum[j]^2)}{sum[k]-sum[j]}<2*sum[i] sum[k]−sum[j]dp[k]−sum[k]2−(dp[j]−sum[j]2)<2∗sum[i]
若 y j = d p [ j ] − s u m [ j ] 2 , x j = s u m [ j ] y_j=dp[j]-sum[j]^2,x_j=sum[j] yj=dp[j]−sum[j]2,xj=sum[j],则 y k − y j ( x k − x j ) < 2 ∗ s u m [ i ] \frac{y_k-y_j}{(x_k-x_j)}<2*sum[i] (xk−xj)yk−yj<2∗sum[i],
未完待续
斜率DP算法流程总结
1.找到dp的转移式,通过斜率分析得到点的表示 ( x i , y i ) (x_i,y_i) (xi,yi)及目标斜率的表示
2.用双端队列维护一个下凸包,每当新来一个点时,末尾不断出队直到构不成上凸包。
3.选择最优解,即 k j , j + 1 k_{j,j+1} kj,j+1大于目标斜率的第一个j。当目标斜率递增时,每次更新队首,然后可以选择队首作为最优解。
4.斜率少用double,容易有误差
#include
#include
#include
#include
#include
#include
#include
#include
#define INF 0x3f3f3f3f
#define lowbit(x) x & -x
#define lson root<<1,l,mid
#define rson root<<1|1,mid+1,r
using namespace std;
typedef long long ll;
const int mod=1e9+7;
const int N=5e5+5;
int n,m,head,tail;
int dp[N],sum[N],q[N];
int getdp(int i,int j){
return dp[j]+m+(sum[i]-sum[j])*(sum[i]-sum[j]);
}
int getup(int j,int k){
return dp[j]+sum[j]*sum[j]-(dp[k]+sum[k]*sum[k]);
}
int getdown(int j,int k){
return 2*(sum[j]-sum[k]);
}
int main(){
#ifdef Mizp
freopen("in.txt","r",stdin);
#endif
while(scanf("%d%d",&n,&m)!=EOF){
sum[0]=dp[0]=0;
for(int i=1;i<=n;i++){
scanf("%d",&sum[i]);
sum[i]+=sum[i-1];
}
head=tail=0;
q[tail++]=0;
for(int i=1;i<=n;i++){
while(head+1<tail && getup(q[head+1],q[head])<=sum[i]*getdown(q[head+1],q[head]))
head++;
dp[i]=getdp(i,q[head]);
while(head+1<tail && getup(i,q[tail-1])*getdown(q[tail-1],q[tail-2])<=getup(q[tail-1],q[tail-2])*getdown(i,q[tail-1]))
tail--;
q[tail++]=i;
}
printf("%d\n",dp[n]);
}
return 0;
}