HDU 5052(树链剖分+线段树)

链接:http://acm.hdu.edu.cn/showproblem.php?pid=5052

题意:树上每个点有一个权值对应商品价格,每次查询两个点之间的路径,从X走到Y,可以选择在一个点买入商品在另一个点卖出商品(卖出点一定要在买入点的后面),并且每次走过这条路径后,这条路径上的所有点的商品价格都会上涨V。对每个查询求出从走到Y所能获得的最大收益(不小于0,可以选择不买入)。

分析:这道题跟POJ3728(链接:http://poj.org/problem?id=3728)非常相似,但是POJ的题并没有点更新的操作,所以可以离线以后用带权并查集查出答案。但是这道题每次走过后需要对整条链上的点进行更新,所以离线是肯定不行的。但是同样可以仿照那道题的思路,用线段树来维护相应的值。

首先考虑如果是线性的一维数组,求X到Y的最大收益,就是线段树的区间合并问题。每个区间维护一个最大值,最小值,和当前区间的答案,合并区间的时候分别考虑左右子区间对于答案的贡献,以及最小值在左最大值在右的情况,就是合并后区间的答案。解决了一维数组线性情况的问题之后,考虑树上的情况。 

从X走到Y可以把答案分为三种情况,设X和Y的最近公共祖先为LCA

1。在X到LCA的路上买入和卖出

2。在LCA到Y的路上买入和卖出

3.。在X到LCA的路上买入,在LCA到Y的路上卖出

和一维数组不同的地方在于,从LCA到Y用树剖做的话与题意其实是反向的,所以不仅要维护一个从下走到上的答案,还需要维护一个从上走到下的答案。然后每次将新查询出的答案与前一个合并。到X和Y走到同一条重链上的时候,分情况讨论一下X和Y的上下关系,合并方式略有区别。

树上点权更新,由于是区间更新,如果整个区间都要更新,则只会影响到区间的最大值和最小值,而答案完全不会受到影响,所以使用lazy操作,当查询或者更新时下放。

昨天写HDU4718卡了很久,今天1A感觉很开心

代码:

#include 
#include 
#include 
#include 
#define ld d<<1
#define rd d<<1|1
#define lson ld,l,m
#define rson rd,m+1,r
using namespace std;
const int N = 50000 +5;
int n,m;
vectorg[N];
int val[N];
int dfn[N],rnk[N],dep[N],fa[N],son[N],siz[N],top[N];
int tot;
void dfs1(int u,int f,int d)
{
    fa[u] = f;
    son[u] = -1;
    siz[u] = 1;
    dep[u] = d;
    for(int i = 0; i < g[u].size(); i++)
    {
        int v = g[u][i];
        if(v == f) continue;
        dfs1(v,u,d+1);
        siz[u] += siz[v];
        if(son[u] == -1 || siz[v] > siz[son[u]]) son[u] = v;
    }
    return;
}
void dfs2(int u,int t)
{
    dfn[u] = tot;
    rnk[tot++] = u;
    top[u] = t;
    if(son[u] == -1) return;
    dfs2(son[u],t);
    for(int i = 0; i > 1;
    build(lson);
    build(rson);
    tr[d] = merge(tr[ld],tr[rd]);
    return;
}
node query(int d,int l,int r,int ql,int qr)
{
    if(ql <= l && r<= qr)
    {
        return tr[d];
    }
    int m = (l + r) >> 1;
    pushdown(d);
    node ans1 = node(), ans2 = node();
    if(m < ql) return query(rson,ql,qr);
    if(qr <= m) return query(lson,ql,qr);
     ans1 = query(lson,ql,qr); ans2 = query(rson,ql,qr);
    return merge(ans1,ans2);
}
void update(int d,int l,int r,int ql,int qr,int v)
{
    if(ql <= l && r <= qr)
    {
        tr[d].minn += v;
        tr[d].maxx += v;
        tr[d].lz += v;
        return;
    }
    int m = (l+r) >> 1;
    pushdown(d);
    if(ql <= m) update(lson,ql,qr,v);
    if(m  dep[fy])
        {
            if(fl1 == 0)
            {
                ans1 = query(1,1,n,dfn[fx],dfn[x]);
                fl1 = 1;
            }
            else ans1 = merge(query(1,1,n,dfn[fx],dfn[x]),ans1);
            update(1,1,n,dfn[fx],dfn[x],v);
            x = fa[fx], fx = top[x];
        }
        else
        {
            if(fl2 == 0)
            {
                ans2 = query(1,1,n,dfn[fy],dfn[y]);
                fl2 = 1;
            }
            else ans2 = merge(query(1,1,n,dfn[fy],dfn[y]),ans2);
            update(1,1,n,dfn[fy],dfn[y],v);
            y = fa[fy], fy = top[y];
        }
    }
    int res = 0;
    if(dep[x] <= dep[y])
    {
        node ans3 = query(1,1,n,dfn[x],dfn[y]);
        res = ans3.downans;
        if(fl2 != 0)
        {
            ans3 = merge(ans3,ans2);
            res = max(res,ans3.downans);
        }
        if(fl1 != 0)
        {
           res = max(res,max(ans1.upans,ans3.downans));
           res = max(res,ans3.maxx - ans1.minn);
        }
        update(1,1,n,dfn[x],dfn[y],v);
    }
    else
    {
        node ans3 = query(1,1,n,dfn[y],dfn[x]);
        if(fl1 != 0) ans3 = merge(ans3,ans1);
        res = ans3.upans;
        if(fl2 != 0)
        {
            res =max(res,max(ans3.upans,ans2.downans));
            res = max(res,ans2.maxx - ans3.minn);
        }
        update(1,1,n,dfn[y],dfn[x],v);
    }
    return res;
}
int main()
{
    int times; scanf("%d",×);
    while(times --)
    {
        scanf("%d",&n); tot =1;
        for(int i = 1; i <= n; i++) scanf("%d",&val[i]),g[i].clear();
        for(int i = 1; i < n; i++)
        {
            int u,v; scanf("%d%d",&u,&v);
            g[u].push_back(v);
            g[v].push_back(u);
        }
        dfs1(1,0,0);
        dfs2(1,1);
        build(1,1,n);
        scanf("%d",&m);
        for(int i = 0; i < m; i++)
        {
            int x,y,z; scanf("%d%d%d",&x,&y,&z);
            printf("%d\n",solve(x,y,z));
        }
    }
    return 0;
}

 

你可能感兴趣的:(HDU 5052(树链剖分+线段树))