【HNOI】 攻城略池 tree-dp

  【题目大意】

    给定一棵树,边有边权,每个节点有一些兵,现在叶子节点在0时刻被占领,并且任意节点在x被占领,那么从x+1开始,每单位时间产生一个兵,兵会顺着父亲节点一直走到根(1),其中每经过一个节点,该节点的兵储量减少1,问所有节点都被攻陷的最晚时间。

  【数据范围】

    n<=10^5.

  首先我们可以设每个节点被攻占的时间为w[i],那么对于一个父节点x,我们可以二分一个答案,来判断这个答案是否合法,那么假设我们分到的值为time,那么对于x子树中所有的节点p,每个节点的贡献为max(0,time-w[p]-d[x][p])。那么我们只需要判断贡献和与x节点的兵的数量就可以了。

  那么对于一个点,我们可以用一颗平衡树来存这个点为根节点的子树中所有节点的w[p]+d[x][p]值,那么对于一个二分到的答案,我们只需要判断树中小于time的和,再用time*size减去就好了。对于一个节点,我们只需要启发式合并他所有的子节点的平衡树就好了。

  反思:开始写的时候没看long long,结果发现连第二个测试点都过不去,然后开了之后就出了各种各样的问题,开始是return 0没有改成return 1LL,后来发现平衡树中维护的sum值有的时候没有被更新,我是在插入和删除的时候修改的这个,和size一起修改,也不知道哪儿错了,后来就直接在询问的时候维护sum,然后还是不行,后来能加上的地方都加上了才过了。。。。

  

//By BLADEVIL

#include <cstdio>

#include <algorithm>

#define maxn 100010

#define LL long long



using namespace std;



LL n,l,tot,save;

LL a[maxn],pre[maxn<<1],other[maxn<<1],last[maxn],len[maxn<<1];

LL w[maxn],flag[maxn],que[maxn],dis[maxn],rot[maxn];

LL left[maxn<<5],right[maxn<<5],key[maxn<<5],size[maxn<<5],sum[maxn<<5];



void connect(LL x,LL y,LL z) {

    pre[++l]=last[x];

    last[x]=l;

    other[l]=y;

    len[l]=z;

}



void left_rotate(LL &t) {

    LL k=right[t];

    right[t]=left[k];

    left[k]=t;

    size[k]=size[t];

    sum[k]=sum[t];

    size[t]=size[left[t]]+size[right[t]]+1LL;

    sum[t]=sum[left[t]]+sum[right[t]]+key[t];

    t=k;

}



void right_rotate(LL &t) {

    LL k=left[t];

    left[t]=right[k];

    right[k]=t;

    size[k]=size[t];

    sum[k]=sum[t];

    size[t]=size[left[t]]+size[right[t]]+1LL;

    sum[t]=sum[left[t]]+sum[right[t]]+key[t];

    t=k;

}



void maintain(LL &t,int flag) {

    if (!flag) {

        if (size[left[left[t]]]>size[right[t]])

            right_rotate(t); else 

        if (size[right[left[t]]]>size[right[t]])

            left_rotate(left[t]),right_rotate(t); else return ;

    } else {

        if (size[right[right[t]]]>size[left[t]])

            left_rotate(t); else 

        if (size[left[right[t]]]>size[left[t]])

            right_rotate(right[t]),left_rotate(t); else return ;

    }

    maintain(left[t],0); maintain(right[t],1);

    maintain(t,1); maintain(t,0);

    //sum[t]=sum[left[t]]+sum[right[t]]+key[t];

}



void t_insert(LL &t,LL v) {

    if (!t) {

        t=++tot;

        left[t]=right[t]=0LL;

        size[t]=1LL;

        key[t]=sum[t]=v;

    } else {

        size[t]++; sum[t]+=v;

        if (v<key[t]) t_insert(left[t],v); else t_insert(right[t],v);

        maintain(t,v>=key[t]);

    }

}

/*

LL t_delete(LL &t,LL v) {

    size[t]--;

    if ((v==key[t])||((v>key[t])&&(!right[t]))||((v<key[t])&&(!left[t]))) {

        save=key[t];

        if ((!left[t])||(!right[t]))

            t=left[t]+right[t]; else key[t]=t_delete(left[t],v+1LL);

    } else {

        if (v<key[t]) return t_delete(left[t],v); else return t_delete(right[t],v);

    }

    sum[t]=sum[left[t]]+sum[right[t]]+key[t];

    return save;

}

*/



LL t_delete(LL &t,LL v) {

    if ((v==key[t])||((v>key[t])&&(!right[t]))||((v<key[t])&&(!left[t]))) {

        save=key[t];

        if ((!left[t])||(!right[t])) {

            t=left[t]+right[t]; 

            sum[t]=sum[left[t]]+sum[right[t]]+key[t];

        }else key[t]=t_delete(left[t],v+1LL);

        //tmp = key[t];

    } else {

        if (v<key[t]) save =  t_delete(left[t],v); else save = t_delete(right[t],v);

    }

    //size[t]=size[left[t]]+size[right[t]]+1;

    sum[t]=sum[left[t]]+sum[right[t]]+key[t];

    return save;

}



void combine(LL &t1,LL &flag1,LL t2,LL flag2) {

    if (size[t1]<size[t2]) swap(t1,t2),swap(flag1,flag2);

    while (t2) {

        t_insert(t1,key[t2]+flag2-flag1);

        t_delete(t2,key[t2]);

    }

}



LL judge(LL t,LL time){

    sum[t]=sum[left[t]]+sum[right[t]]+key[t];

    if (!t) return 0LL;

    if (key[t]<=time) 

        return judge(right[t],time)+(size[left[t]]+1LL)*time-sum[left[t]]-key[t]; else 

        return judge(left[t],time);

}



void work() {

    LL h=0LL,t=1LL;

    que[1]=1LL; dis[1]=1LL;

    while (h<t) {

        LL cur=que[++h];

        for (LL p=last[cur];p;p=pre[p]) {

            if (dis[other[p]]) continue;

            que[++t]=other[p];

            dis[other[p]]=dis[cur]+1LL;

        }

    }

    //for (LL i=1;i<=n;i++) printf("%d ",que[i]); printf("\n");

    for (LL i=n;i;i--) {

        LL cur=que[i];

        for (LL p=last[cur];p;p=pre[p]) {

            if (dis[other[p]]<dis[cur]) continue;

            combine(rot[cur],flag[cur],rot[other[p]],flag[other[p]]+len[p]);

        }

        if ((!rot[cur])||(!a[cur])) {

            t_insert(rot[cur],-flag[cur]);

            //printf("%lld %lld %lld\n",cur,rot[cur],flag[cur]);

            //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);

            continue;

        }

        //printf("%lld %lld %lld\n",cur,rot[cur],flag[cur]);

        //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);

        LL l=1LL,r=1LL<<30,mid,ans;

        while (l<=r) {

            //printf("%d %d\n",l,r);    

            mid=l+r>>1LL;

            //if (cur==1) printf("%lld %lld\n",l,r);

            if (judge(rot[cur],mid-flag[cur])>=a[cur]) r=mid-1LL,ans=mid; else l=mid+1LL;

        }

        w[cur]=ans;

        //if (cur==1) printf("|%lld\n",judge(rot[cur],-6));

        t_insert(rot[cur],w[cur]-flag[cur]);

        //printf("%lld %lld %lld\n",cur,rot[cur],flag[cur]);

        //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);

    }

    //printf("%lld %lld\n",rot[1],flag[1]);

    //for (LL i=1;i<=20;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);

    //for (LL i=1;i<=n;i++) printf("%lld ",w[i]); printf("\n");

    LL ans=0LL;

    for (int i=1;i<=n;i++) ans=max(ans,w[i]);

    printf("%lld\n",ans);

}



void check() {

    LL t1=0,t2=0,flag1=0,flag2=0;

    for (LL i=1;i<=5;i++) t_insert(t1,i),t_insert(t2,i); t_insert(t2,6);

    combine(t1,flag1,t2,flag2);

    printf("%lld\n",t1);

    for (LL i=1;i<=20;i++) printf("%ld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);

    return ;

    LL t=0;

    for (LL i=1;i<=10;i++) t_insert(t,i);

    t_delete(t,7); printf("%lld\n",t);

    for (LL i=1;i<=10;i++) printf("%lld %lld %lld %lld %lld %lld\n",i,left[i],right[i],size[i],key[i],sum[i]);

    printf("%lld\n",judge(t,9));

}



int main() {

    //check(); return 0;

    freopen("conquer.in","r",stdin); freopen("conquer.out","w",stdout);

    scanf("%lld",&n);

    for (LL i=1;i<=n;i++) scanf("%lld",&a[i]);

    for (LL i=1;i<n;i++) {

        LL x,y,z; scanf("%lld%lld%lld",&x,&y,&z);

        connect(x,y,z); connect(y,x,z);

    }

    work();

    fclose(stdin); fclose(stdout);

    return 0;

}

 

 

你可能感兴趣的:(tree)