bzoj 3252: 攻略 (线段树+DFS序)

题目描述

传送门

题目大意:选出K条从叶子节点到根节点的路径,使路径上的权值之和最大。注意每个点的权值只能被计算一次。

题解

看到这道题的第一反应是最大费用最大流。对于每条边只有第一次流的时候有价值。
那么根据这个思路,我们其实就是每次选取一条权值之和最大的路径加入答案,因为每个点的权值只能计算一次,所以路径上的点子树中所有叶子几点都要减去这个点的价值,就是用线段树维护每个点到根的距离。
按照只有叶子节点的dfs序建树,这样每次修改就是连续区间了。

代码

#include
#include
#include
#include
#include
#define LL long long 
#define N 400003
using namespace std;
int tot,nxt[N],point[N],v[N],pos[N],l[N],r[N],pd[N],mark[N],sz,n,m,fa[N];
LL tr[N*4],val[N],sum[N],delta[N];
void add(int x,int y)
{
    tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
    tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void dfs(int x,int f)
{
    sum[x]=sum[f]+val[x];
    bool pd=false;
    for (int i=point[x];i;i=nxt[i]) {
        if (v[i]==f) continue;
        pd=true;
    }
    if (!pd) {
        l[x]=r[x]=++sz;
        pos[sz]=x;
        return;
    }
    l[x]=sz+1;
    for (int i=point[x];i;i=nxt[i]) {
        if (v[i]==f) continue;
        fa[v[i]]=x;
        dfs(v[i],x);
    }
    r[x]=sz;
}
void update(int now)
{
    if (tr[now<<1]>tr[now<<1|1]) tr[now]=tr[now<<1],mark[now]=mark[now<<1];
    else tr[now]=tr[now<<1|1],mark[now]=mark[now<<1|1];
}
void build(int now,int l,int r)
{
    if (l==r) {
        tr[now]=sum[pos[l]];
        mark[now]=pos[l];
        return;
    }
    int mid=(l+r)/2;
    build(now<<1,l,mid);
    build(now<<1|1,mid+1,r);
    update(now);
}
void pushdown(int now){
    if (delta[now]){
        tr[now<<1]+=delta[now]; delta[now<<1]+=delta[now];
        tr[now<<1|1]+=delta[now]; delta[now<<1|1]+=delta[now];
        delta[now]=0;
    }
}
void query(int now,int l,int r,int ll,int rr,LL v)
{
    if (ll<=l&&r<=rr) {
        tr[now]-=v;
        delta[now]-=v;
        return;
    }
    int mid=(l+r)/2;
    pushdown(now);
    if (ll<=mid) query(now<<1,l,mid,ll,rr,v);
    if (rr>mid) query(now<<1|1,mid+1,r,ll,rr,v);
    update(now);
}
void solve(int x)
{
    while (x) {
        if (pd[x]) break;
        pd[x]=1;
        query(1,1,sz,l[x],r[x],val[x]);
        x=fa[x];
    }
}
int main()
{
    freopen("a.in","r",stdin);
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++) scanf("%lld",&val[i]);
    for (int i=1;iint x,y; scanf("%d%d",&x,&y);
        add(x,y);
    }
    dfs(1,0);
    build(1,1,sz);
    LL ans=0;
    for (int i=1;i<=m;i++){
        ans+=tr[1]; 
        solve(mark[1]);
    }
    printf("%lld\n",ans);
}

你可能感兴趣的:(线段树)