昨天被学军的公开赛虐傻了,才发现自己还不会用CDQ优化DP,吓得赶紧去填坑。。。
普通的CDQ就是对二分操作,计算前半部分的插入对后半部分的询问的影响。
那么如何用CDQ优化DP呢?
看一道例题:
NOI2007 cash
不难推出平方的dp方程:
f[i] = max(f[i-1], f[j]/(R[j]*A[j]+B[j])*R[j]*A[i] + f[j]/(R[j]*A[j]+B[j])*B[i]) (j<i)
复杂度是O(n^2)的
由于这些变量中并不存在单调关系,所以使不能用单调性优化的
观察式子,我们设x[j] = f[j]/(R[j]*A[j]+B[j])*R[j],y[j] = f[j]/(R[j]*A[j]+B[j])
不难吧转移方程写成如下形式
f[i] = max(f[i-1], x[j]*A[i] + y[j]*B[i]) (j<i)
点积的形式?不难发现,这样的最大值一定是在点集(x[j],y[j])的凸包上的点,更确切的说,这个点能与点(A[i],B[i])组成凸包的偏上的切线
很明显我们可以用平衡树来维护上凸壳,然后再凸壳上找到切线。
问题似乎的到了解决,但是这样的变成复杂度太高了,几乎很难在赛场上写出并调试正确。
于是我们就有了一个强有力的替代品——CDQ分治。
我们可以看到dp的过程就是不断的计算出新的f值再用其组成新的点来放进凸包中
其实我们可以把这个过程理解成这样:计算f[1],在平面上插入(x[1],y[1]),计算f[2],在平面上插入(x[2],y[2])……
这样我们就可以用CDQ分治来优化这个过程了
具体过程如下:
Solve(l,r)
Solve(l,mid)
利用f[l~mid]更新f[mid+1,r]
Solve(mid+1,r)
这题就是利用l~mid之间的点组成的凸壳来更新mid+1~r的dp值,由于凸包上的边斜率单调,所以可以离线后把询问的斜率排序,O(n)搞出来
由于用到了排序,所以总的复杂度为O(nlog^2n)
代码很短的,而我写得很丑。。。。
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; const double Eps = 1e-12; const int Maxn = 100005; int n,i,j,cnt1,cnt2,tot; double f[Maxn], R[Maxn]; double A[Maxn], B[Maxn]; int p[Maxn]; struct Point{ double x,y; Point operator -(const Point &a)const{ return (Point){x-a.x, y-a.y}; } bool operator <(const Point &a)const{ return (x<a.x) || (x==a.x && y<a.y); } } dot[Maxn], up[Maxn], dw[Maxn], q[Maxn], con[Maxn]; double mult(Point P1,Point P2){ return P1.x*P2.y-P1.y*P2.x; } double mult(Point P0,Point P1,Point P2){ return mult(P1-P0,P2-P0); } bool cmp(const int &a,const int &b){ return A[a]*B[b]<A[b]*B[a]; } void merge(int l,int r){ for (int i=l;i<=r;i++) q[i] = dot[i]; int mid = (l+r)>>1, k=0; for (int i=l,j=mid+1;i<=mid||j<=r;k++) if (j>r || (i<=mid && q[i]<q[j])) dot[l+k]=q[i++]; else dot[l+k]=q[j++]; } void solve(int l,int r){ if (l==r){ if (f[l-1]>f[l]) f[l] = f[l-1]; dot[l].x = f[l]/(R[l]*A[l]+B[l])*R[l]; dot[l].y = f[l]/(R[l]*A[l]+B[l]); return; } int mid = (l+r)>>1; solve(l,mid); up[cnt1=1] = dot[l]; for (int i=l+1;i<=mid;i++){ while (cnt1>1 && mult(up[cnt1-1],up[cnt1],dot[i])>=-Eps) cnt1--; up[++cnt1] = dot[i]; } dw[cnt2=1] = dot[l]; for (int i=l+1;i<=mid;i++){ while (cnt2>1 && mult(dw[cnt2-1],dw[cnt2],dot[i])<=Eps) cnt2--; dw[++cnt2] = dot[i]; } tot = 0; for (int i=1;i<cnt1;i++) con[++tot] = up[i]; for (int i=cnt2;i>1;i--) con[++tot] = dw[i]; if (tot==0) con[++tot]=dot[l]; if (tot<=2){ for (int i=mid+1;i<=r;i++) for (int j=1;j<=tot;j++) if (f[i]<con[j].x*A[i]+con[j].y*B[i]) f[i]=con[j].x*A[i]+con[j].y*B[i]; } else { for (int i=mid+1;i<=r;i++) p[i] = i; sort(p+mid+1,p+r+1,cmp); for (int i=mid+1,j=1;i<=r;i++){ while ( (con[j%tot+1].y-con[j].y)*B[p[i]] + A[p[i]]*(con[j%tot+1].x-con[j].x) > Eps ) j=j%tot+1; if (f[p[i]]<con[j].x*A[p[i]]+con[j].y*B[p[i]]) f[p[i]]=con[j].x*A[p[i]]+con[j].y*B[p[i]]; } } solve(mid+1,r); merge(l,r); } int main(){ scanf("%d%lf",&n,&f[1]); for (i=1;i<=n;i++) scanf("%lf%lf%lf",&A[i],&B[i],&R[i]); solve(1,n); printf("%.3lf\n",f[n]); return 0; }
同样可以先把方程列出来:
f[i] = min(f[j]+(a[i]-a[j])^2) (j<i && a[j]>=b[i])
与上题不同的是这里多了一个a[j]>=b[i]的限制,我们逐步来分析。
如果a[i]单调,这很明显是可以用单调性来优化的。
我们可以利用CDQ强行上单调性,也就是说把l~mid强行按a值排序再更新mid+1~r——就是强行造出单调队列,再在这个队列里二分找到最优解
题目里面给了不考虑b的限制的部分分,利用上面的方法就可以了
加上b的限制也是很简单的:把mid+1~r这部分按b排序,在插入l~mid的时候,适时地查询更新mid+1~r的解,复杂度不变
#include <cmath> #include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int Maxn = 100005; typedef long long LL; LL f[Maxn]; int A[Maxn], B[Maxn]; int q[Maxn], p[Maxn]; int n,i,head,tail; bool cmpA(const int &a,const int &b){ return A[a] > A[b]; } bool cmpB(const int &a,const int &b){ return B[a] > B[b]; } LL F(int j,int k){ return ( f[j]+(LL)A[j]*A[j] )-( f[k]+(LL)A[k]*A[k] ); } LL G(int j,int k){ return (A[j]-A[k]); } LL cc(int j,int i){ return f[j] + (LL)(A[j]-A[i])*(A[j]-A[i]); } LL calc(int x){ if (head>tail) return f[n+1]; LL ret = tail; int L = 1, R = tail-1; while (L<=R){ int mid = (L+R)>>1; if (cc(q[mid],x) < cc(q[mid+1],x)) R = mid-1, ret = mid; else L = mid+1; } ret = q[ret]; return cc(ret,x); } void solve(int l,int r){ if (l==r) return; int mid = (l+r)>>1; solve(l,mid); int i,j; for (i=l;i<=r;i++) p[i] = i; sort(p+l,p+mid+1,cmpA); sort(p+mid+1,p+r+1,cmpB); head = 1; tail = 0; for (i=l,j=mid+1;i<=mid;i++){ while (j<=r && B[p[j]]>A[p[i]]){ LL tmp = calc(p[j]); if (tmp<f[p[j]]) f[p[j]] = tmp; j++; } while (head<tail && F(q[tail-1],q[tail])*G(q[tail],p[i]) <= F(q[tail],p[i])*G(q[tail-1],q[tail]) ) tail--; q[++tail] = p[i]; } while (j<=r){ LL tmp = calc(p[j]); if (tmp<f[p[j]]) f[p[j]] = tmp; j++; } solve(mid+1,r); } int main(){ scanf("%d",&n); for (i=1;i<=n;i++) scanf("%d%d",&A[i],&B[i]); //memset(f,127,sizeof(f)); for (int i=1;i<=n+1;i++) f[i] = 1e13; f[0] = 0; solve(0,n); if (f[n]>=f[n+1]) printf("-1\n"); else printf("%.4lf\n",(double)sqrt((long double)f[n]) ); return 0; }