显然易得dp方程:f[x]=min{f[y]+(d[x]-d[y])*px+qx},其中y是x的祖先且d[x]-d[y]<=lx。
然后就可以得到对于一定定点z(就是z在等式左边),两个点x,y(d[x]>d[y])且y更优的条件为:(f[x]-f[y])/(d[x]-d[y])>=pz,那么令不等式左边那个为S(x,y)即(x,y)的斜率。
考虑一条链的情况,显然可以离线后cdq分治。
树上的话,不妨称之为树上cdq分治(好吧承认是我自己yy的名字)。就找到重心G,然后维护G到root的f[]的一个凸壳,来更新G和G的子树。时间复杂度O(Nlog^2N)。
AC代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define ll long long #define N 200005 using namespace std; int n,tp,all,fst[N],nxt[N],fa[N],sz[N],q[N],stk[N]; ll len[N],f[N],a[N],b[N],c[N],d[N]; double slp[N]; bool ok[N]; ll read(){ ll x=0; char ch=getchar(); while (ch<'0' || ch>'9') ch=getchar(); while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); } return x; } bool cmp(const int &x,const int &y){ return d[x]-c[x]>d[y]-c[y]; } int findct(int x){ int y,u=x,v; sz[x]=1; for (y=fst[x]; y; y=nxt[y]) if (ok[y]){ v=findct(y); sz[x]+=sz[y]; if ((sz[y]<<1)>all) u=v; } return u; } void dfs(int x){ int y; stk[++tp]=x; for (y=fst[x]; y; y=nxt[y]) if (ok[y]) dfs(y); } double getk(int x,int y){ return (double)(f[x]-f[y])/(d[x]-d[y]); } void calc(int x,int r){ int l=1,mid; while (l<r){ mid=(l+r)>>1; if (slp[mid]>=a[x]) l=mid+1; else r=mid; } f[x]=min(f[x],f[q[l]]+(d[x]-d[q[l]])*a[x]+b[x]); } void solve(int rt){ int x=findct(rt),y; ok[x]=0; if (x!=rt){ all=sz[rt]-sz[x]; solve(rt); tp=0; dfs(x); sort(stk+1,stk+tp+1,cmp); int i=1,now=fa[x],tail=1; q[1]=now; while (i<=tp && d[stk[i]]-d[q[1]]>c[stk[i]]) i++; for (; i<=tp; i++){ while (now!=rt && d[stk[i]]-d[fa[now]]<=c[stk[i]]){ now=fa[now]; while (tail>1 && getk(q[tail],now)>=slp[tail-1]) tail--; slp[tail]=getk(q[tail],now); q[++tail]=now; } calc(stk[i],tail); } } if (x==rt){ tp=0; dfs(x); } int i; for (i=1; i<=tp; i++) if (d[y=stk[i]]-d[x]<=c[stk[i]]) f[y]=min(f[y],f[x]+(d[y]-d[x])*a[y]+b[y]); for (y=fst[x]; y; y=nxt[y]) if (ok[y]){ all=sz[y]; solve(y); } } int main(){ scanf("%d",&n); int i; scanf("%d",&i); for (i=2; i<=n; i++){ scanf("%d",&fa[i]); nxt[i]=fst[fa[i]]; fst[fa[i]]=i; len[i]=read(); a[i]=read(); b[i]=read(); c[i]=read(); d[i]=d[fa[i]]+len[i]; } for (i=1; i<=n; i++) ok[i]=1; memset(f,0x3f,sizeof(f)); f[1]=0; all=n; solve(1); for (i=2; i<=n; i++) printf("%lld\n",f[i]); }
by lych
2016.5.4