bzoj 4127: Abs(树链剖分+线段树)

4127: Abs

Time Limit: 40 Sec   Memory Limit: 256 MB
Submit: 366   Solved: 129
[ Submit][ Status][ Discuss]

Description

 给定一棵树,设计数据结构支持以下操作

    1 u v d  表示将路径 (u,v) 加d

    2 u v  表示询问路径 (u,v) 上点权绝对值的和

Input

第一行两个整数n和m,表示结点个数和操作数
接下来一行n个整数a_i,表示点i的权值

接下来n-1行,每行两个整数u,v表示存在一条(u,v)的边

接下来m行,每行一个操作,输入格式见题目描述

Output

对于每个询问输出答案

Sample Input

4 4
-4 1 5 -2
1 2
2 3
3 4
2 1 3
1 1 4 3
2 1 3
2 3 4

Sample Output

10
13
9

HINT

对于100%的数据,n,m <= 10^5 且 0<= d,|a_i|<= 10^8


Source

[ Submit][ Status][ Discuss] 题解:树链剖分+线段树维护区间绝对值之和。

因为D大于0,所有一个负数只可能变成正数一次,那么我们可以维护一下区间中最大的负数,然后每次都通过线段树区间查询的方式找出最大负数及所在的位置,如果最大的负数+d大于,说明区间中负数的个数发生改变,所有需要更新负数的个数,同时把统计区间最大负数的数组的值改为-INF,然后把当前点的sum值改为他的相反数,因为最后算区间绝对值和的时候是用过计算正负数查的方式给区间加上一个值,并打标记,一个负数变成正数,他的绝对值变成v-abs(x),所有在区间修改之前事先找出所有会从负数变成正数的数,通过点修改的方式进行处理。

#include<iostream>    
#include<cstdio>    
#include<cstring>    
#include<algorithm>    
#include<cmath>    
#define N 400003    
#define LL long long
#define inf 1000000000    
using namespace std;    
int n,m,sz,tr[N],w[N];    
int tot,next[N],point[N],v[N],deep[N],son[N],size[N],belong[N],fa[N],pos[N];    
LL val[N],sum[N],a[N],maxn[N],delta[N],cnt[N];  
struct data
{
    int dis;
    LL c;
};
void add(int x,int y)    
{    
    tot++; next[tot]=point[x]; point[x]=tot; v[tot]=y;    
    tot++; next[tot]=point[y]; point[y]=tot; v[tot]=x;    
}    
void dfs(int x,int f,int dep)    
{    
    deep[x]=dep; size[x]=1;     
    for (int i=point[x];i;i=next[i])    
     if (v[i]!=f)    
     {    
        dfs(v[i],x,dep+1);    
        fa[v[i]]=x;    
        size[x]+=size[v[i]];    
        if (size[son[x]]<size[v[i]])    
         son[x]=v[i];    
     }    
}    
void build(int x,int chain)    
{    
    belong[x]=chain;    
    pos[x]=++sz;  val[sz]=a[x];    
    int k=0;    
    if (son[x]==0) return;    
    build(son[x],chain);    
    for (int i=point[x];i;i=next[i])    
     if (v[i]!=fa[x]&v[i]!=son[x])    
      build(v[i],v[i]);    
}    
void update(int x)    
{    
    sum[x]=sum[x<<1]+sum[x<<1|1];    
    cnt[x]=cnt[x<<1]+cnt[x<<1|1];   
    maxn[x]=max(maxn[x<<1],maxn[x<<1|1]);  
    if (maxn[x<<1]>maxn[x<<1|1])  //记录最大负数所在的位置
     tr[x]=tr[x<<1];  
    else  tr[x]=tr[x<<1|1];  
}   
void pushdown(int now,int l,int r)
{
    if (delta[now]==0) return ;
    int mid=(l+r)/2;
    sum[now<<1]+=(mid-l+1-2*cnt[now<<1])*delta[now];
    delta[now<<1]+=delta[now];
    sum[now<<1|1]+=(r-mid-2*cnt[now<<1|1])*delta[now];
    delta[now<<1|1]+=delta[now];
    if (maxn[now<<1]!=-inf) maxn[now<<1]+=delta[now];
    if (maxn[now<<1|1]!=-inf)  maxn[now<<1|1]+=delta[now];
    delta[now]=0;
} 
void buildtree(int now,int l,int r)    
{    
    if (l==r)    
     {    
        if(val[l]>=0)  maxn[now]=-inf,tr[now]=l,cnt[now]=0;   
        else  maxn[now]=val[l],tr[now]=l,cnt[now]=1;
        sum[now]=abs(val[l]); 
        w[l]=now;   
        return;    
     }    
    int mid=(l+r)/2;    
    buildtree(now<<1,l,mid);    
    buildtree(now<<1|1,mid+1,r);    
    update(now);    
}    
void query(int now,int l,int r,int ll,int rr,LL vv)    
{    
    if (l>=ll&&r<=rr)    
     {    
        sum[now]+=(long long)(r-l+1-2*cnt[now])*vv; 
        if (maxn[now]!=-inf) maxn[now]+=vv; 
        delta[now]+=vv;    
        return;    
     }    
    int mid=(l+r)/2;    
    pushdown(now,l,r);    
    if (ll<=mid) query(now<<1,l,mid,ll,rr,vv);    
    if (rr>mid)  query(now<<1|1,mid+1,r,ll,rr,vv);    
    update(now);    
}    
LL qjsum(int now,int l,int r,int ll,int rr)    
{    
    if (l>=ll&&r<=rr)    
     return sum[now];    
    int mid=(l+r)/2;  
    pushdown(now,l,r);  
    LL ans=0;    
    if (ll<=mid) ans+=qjsum(now<<1,l,mid,ll,rr);    
    if (rr>mid) ans+=qjsum(now<<1|1,mid+1,r,ll,rr);    
    return ans;    
}    
data qjmax(int now,int l,int r,int ll,int rr)  //查区间最大值时需要注意,刚开始我想只返回位置,但是发现所在位置的数可能已经改变,但是他所在区间的标记并未下放,也就是只有区间维护的最大值更改了,那么如果返回位置,通过位置上的数判断是否会变为正数,就会出错,那么索性位置和区间最值一起返回,比较时用区间最值,更改时更改返回的位置
{  
    if (l>=ll&&r<=rr)  
    {
     data u;
     u.dis=tr[now];
     u.c=maxn[now];
     return u;
    }
    pushdown(now,l,r);  
    int mid=(l+r)/2;  
    data lc,rc;
    bool p=false,q=false;
    if (ll<=mid)   lc=qjmax(now<<1,l,mid,ll,rr),p=true;
    if (rr>mid)    rc=qjmax(now<<1|1,mid+1,r,ll,rr),q=true;
    if (!p)  return rc;
    if (!q)  return lc;
    if (lc.c>rc.c)  return lc;  
    return rc;  
}  
void pointchange(int now,int l,int r,LL x,LL v)  
{  
    if (l==r)  
     {  
        cnt[now]=0;  maxn[now]=-inf;   
        sum[now]=-sum[now];  
        return;  
     }  
    pushdown(now,l,r);  
    int mid=(l+r)/2;  
    if (x<=mid)  pointchange(now<<1,l,mid,x,v);  
    else  pointchange(now<<1|1,mid+1,r,x,v);  
    update(now);  
}  
void solve1(int x,int y,LL vv)    
{    
    while (belong[x]!=belong[y])    
     {    
        if (deep[belong[x]]<deep[belong[y]]) swap(x,y);    
        data t=qjmax(1,1,n,pos[belong[x]],pos[x]);  
        while (abs(t.c)<=vv)  
         {  
            pointchange(1,1,n,t.dis,vv);  
            t=qjmax(1,1,n,pos[belong[x]],pos[x]); 
         }    
        query(1,1,n,pos[belong[x]],pos[x],vv);  
        x=fa[belong[x]];    
     }    
    if (deep[x]>deep[y]) swap(x,y);  
    data t=qjmax(1,1,n,pos[x],pos[y]);  
        while (abs(t.c)<=vv)  
         {  
            pointchange(1,1,n,t.dis,vv);  
            t=qjmax(1,1,n,pos[x],pos[y]);  
         }      
    query(1,1,n,pos[x],pos[y],vv);    
}    
LL solve2(int x,int y)    
{    
    LL ans=0;    
    while (belong[x]!=belong[y])    
     {    
        if (deep[belong[x]]<deep[belong[y]]) swap(x,y);    
        ans+=qjsum(1,1,n,pos[belong[x]],pos[x]);    
        x=fa[belong[x]];    
     }    
    if (deep[x]>deep[y]) swap(x,y);    
    ans+=qjsum(1,1,n,pos[x],pos[y]);    
    return ans;    
}    
int main()    
{    
    scanf("%d%d",&n,&m);    
    for (int i=1;i<=n;i++)  scanf("%lld",&a[i]);    
    for (int i=1;i<n;i++)    
     {    
        int x,y; scanf("%d%d",&x,&y);    
        add(x,y);    
     }    
    dfs(1,0,1);  
    build(1,1);  
    buildtree(1,1,n); 
    for (int i=1;i<=m;i++)    
     { 
        int op; scanf("%d",&op);    
        int x,y; scanf("%d%d",&x,&y);
        if (op==1)    
         {    
            LL z;    scanf("%lld",&z);    
            solve1(x,y,z);    
         }    
        else  printf("%lld\n",solve2(x,y));      
     }    
}    



你可能感兴趣的:(bzoj 4127: Abs(树链剖分+线段树))