bzoj 4353: Play with tree (树链剖分)

题目描述

传送门

题目大意:
给你一棵包含N个节点的树,设每条边一开始的边权为0,现在有两种操作:
1)给出参数U,V,C,表示把U与V之间的路径上的边权变成C(保证C≥0)
2)给出参数U,V,C,表示把U与V之间的路径上的边权加上max(C,路径上边权最小值的相反数)。
你需要统计出每次一操作过后树中边权为0的边有多少条。

题解

树链剖分。边权下放为点权。
然后用线段树维护区间最小值,区间最小值的个数,以及区间中0的数量。
注意覆盖和增加标记的处理

代码

#include
#include
#include
#include
#include
#define N 200003
#define inf 1000000000
using namespace std; 
int tr[N*4],delta[N*4],cover[N*4],mn[N*4],ct[N*4];
int n,m,belong[N],son[N],v[N],nxt[N],point[N],size[N],tot,fa[N],deep[N],pos[N],sz;
void addedge(int x,int y)
{
    tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
    tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void dfs(int x,int f)
{
    deep[x]=deep[f]+1; size[x]=1;
    for (int i=point[x];i;i=nxt[i]){
        if (v[i]==f) continue;
        fa[v[i]]=x;
        dfs(v[i],x);
        size[x]+=size[v[i]];
        if (size[son[x]]x]=v[i];
    }
}
void dfs1(int x,int chain)
{
    pos[x]=++sz; belong[x]=chain;
    if (!son[x]) return;
    dfs1(son[x],chain);
    for (int i=point[x];i;i=nxt[i])
     if (v[i]!=fa[x]&&v[i]!=son[x])
      dfs1(v[i],v[i]);
}
void update(int now)
{
    ct[now]=0; tr[now]=0;
    mn[now]=min(mn[now<<1],mn[now<<1|1]);
    if (mn[now]==mn[now<<1]) ct[now]+=ct[now<<1];
    if (mn[now]==mn[now<<1|1]) ct[now]+=ct[now<<1|1];
    if (mn[now]==0) tr[now]=ct[now];
}
void build(int now,int l,int r)
{
    cover[now]=-inf;
    if (l==r) {
        tr[now]=ct[now]=1; mn[now]=0;
        return; 
    }
    int mid=(l+r)/2;
    build(now<<1,l,mid);
    build(now<<1|1,mid+1,r);
    update(now);
}
void change(int now,int l,int r,int val)
{
    mn[now]=val; ct[now]=(r-l+1);
    if (mn[now]==0) tr[now]=ct[now];
    else tr[now]=0;
    delta[now]=0; cover[now]=val;
}
void add(int now,int l,int r,int val)
{
    mn[now]+=val; 
    if (mn[now]==0) tr[now]=ct[now];
    else tr[now]=0;
    if(cover[now]==-inf) delta[now]+=val;
    else cover[now]+=val;
}
void pushdown(int now,int l,int r)
{
    int mid=(l+r)/2;
    if (cover[now]!=-inf) {
        change(now<<1,l,mid,cover[now]);
        change(now<<1|1,mid+1,r,cover[now]);
        cover[now]=-inf;
    }
    if (delta[now]) {
        add(now<<1,l,mid,delta[now]);
        add(now<<1|1,mid+1,r,delta[now]);
        delta[now]=0;
    }
}
void qjcover(int now,int l,int r,int ll,int rr,int val)
{
    if (ll>rr) return;
    if (ll<=l&&r<=rr) {
        change(now,l,r,val);
        return;
    }
    int mid=(l+r)/2;
    pushdown(now,l,r);
    if (ll<=mid) qjcover(now<<1,l,mid,ll,rr,val);
    if (rr>mid) qjcover(now<<1|1,mid+1,r,ll,rr,val);
    update(now);
}
void qjadd(int now,int l,int r,int ll,int rr,int val)
{
    if (ll>rr) return;
    if (ll<=l&&r<=rr) {
        add(now,l,r,val);
        return;
    }
    int mid=(l+r)/2;
    pushdown(now,l,r);
    if (ll<=mid) qjadd(now<<1,l,mid,ll,rr,val);
    if (rr>mid) qjadd(now<<1|1,mid+1,r,ll,rr,val);
    update(now);
}
void solve(int x,int y,int z)
{
    while (belong[x]!=belong[y]){
        if (deep[belong[x]]y]]) swap(x,y);
        qjcover(1,1,n,pos[belong[x]],pos[x],z);
        x=fa[belong[x]];
    }
    if (pos[x]>pos[y]) swap(x,y);
    qjcover(1,1,n,pos[x]+1,pos[y],z);
}
int qjmin(int now,int l,int r,int ll,int rr)
{
    if (ll>rr) return inf;
    if (ll<=l&&r<=rr) return mn[now];
    int mid=(l+r)/2; int ans=inf;
    pushdown(now,l,r);
    if (ll<=mid) ans=min(ans,qjmin(now<<1,l,mid,ll,rr));
    if (rr>mid) ans=min(ans,qjmin(now<<1|1,mid+1,r,ll,rr));
    return ans;
}
int find(int x,int y)
{
    int ans=inf;
    while (belong[x]!=belong[y]){
        if (deep[belong[x]]y]]) swap(x,y);
        ans=min(ans,qjmin(1,1,n,pos[belong[x]],pos[x]));
        x=fa[belong[x]];
    }
    if (pos[x]>pos[y]) swap(x,y);
    ans=min(ans,qjmin(1,1,n,pos[x]+1,pos[y]));
    return ans;
}
void solve1(int x,int y,int z)
{
    while (belong[x]!=belong[y]){
        if (deep[belong[x]]y]]) swap(x,y);
        qjadd(1,1,n,pos[belong[x]],pos[x],z);
        x=fa[belong[x]];
    }
    if (pos[x]>pos[y]) swap(x,y);
    qjadd(1,1,n,pos[x]+1,pos[y],z);
}
int main()
{
    freopen("a.in","r",stdin);
    freopen("my.out","w",stdout);
    scanf("%d%d",&n,&m);
    for (int i=1;iint x,y; scanf("%d%d",&x,&y);
        addedge(x,y);
    }
    dfs(1,0); dfs1(1,1);
//  for (int i=1;i<=n;i++) cout<<pos[i]<<" "; cout<1,1,n);
    for (int i=1;i<=m;i++) {
        int opt,x,y,z; scanf("%d%d%d%d",&opt,&x,&y,&z);
        if (opt==1) solve(x,y,z);
        if (opt==2) {
            int t=find(x,y);
            solve1(x,y,max(z,-t));
        }
        printf("%d\n",tr[1]-1);
    }
}

你可能感兴趣的:(树链剖分)