链接:http://acm.hdu.edu.cn/showproblem.php?pid=5052
题意:树上每个点有一个权值对应商品价格,每次查询两个点之间的路径,从X走到Y,可以选择在一个点买入商品在另一个点卖出商品(卖出点一定要在买入点的后面),并且每次走过这条路径后,这条路径上的所有点的商品价格都会上涨V。对每个查询求出从走到Y所能获得的最大收益(不小于0,可以选择不买入)。
分析:这道题跟POJ3728(链接:http://poj.org/problem?id=3728)非常相似,但是POJ的题并没有点更新的操作,所以可以离线以后用带权并查集查出答案。但是这道题每次走过后需要对整条链上的点进行更新,所以离线是肯定不行的。但是同样可以仿照那道题的思路,用线段树来维护相应的值。
首先考虑如果是线性的一维数组,求X到Y的最大收益,就是线段树的区间合并问题。每个区间维护一个最大值,最小值,和当前区间的答案,合并区间的时候分别考虑左右子区间对于答案的贡献,以及最小值在左最大值在右的情况,就是合并后区间的答案。解决了一维数组线性情况的问题之后,考虑树上的情况。
从X走到Y可以把答案分为三种情况,设X和Y的最近公共祖先为LCA
1。在X到LCA的路上买入和卖出
2。在LCA到Y的路上买入和卖出
3.。在X到LCA的路上买入,在LCA到Y的路上卖出
和一维数组不同的地方在于,从LCA到Y用树剖做的话与题意其实是反向的,所以不仅要维护一个从下走到上的答案,还需要维护一个从上走到下的答案。然后每次将新查询出的答案与前一个合并。到X和Y走到同一条重链上的时候,分情况讨论一下X和Y的上下关系,合并方式略有区别。
树上点权更新,由于是区间更新,如果整个区间都要更新,则只会影响到区间的最大值和最小值,而答案完全不会受到影响,所以使用lazy操作,当查询或者更新时下放。
昨天写HDU4718卡了很久,今天1A感觉很开心
代码:
#include
#include
#include
#include
#define ld d<<1
#define rd d<<1|1
#define lson ld,l,m
#define rson rd,m+1,r
using namespace std;
const int N = 50000 +5;
int n,m;
vectorg[N];
int val[N];
int dfn[N],rnk[N],dep[N],fa[N],son[N],siz[N],top[N];
int tot;
void dfs1(int u,int f,int d)
{
fa[u] = f;
son[u] = -1;
siz[u] = 1;
dep[u] = d;
for(int i = 0; i < g[u].size(); i++)
{
int v = g[u][i];
if(v == f) continue;
dfs1(v,u,d+1);
siz[u] += siz[v];
if(son[u] == -1 || siz[v] > siz[son[u]]) son[u] = v;
}
return;
}
void dfs2(int u,int t)
{
dfn[u] = tot;
rnk[tot++] = u;
top[u] = t;
if(son[u] == -1) return;
dfs2(son[u],t);
for(int i = 0; i > 1;
build(lson);
build(rson);
tr[d] = merge(tr[ld],tr[rd]);
return;
}
node query(int d,int l,int r,int ql,int qr)
{
if(ql <= l && r<= qr)
{
return tr[d];
}
int m = (l + r) >> 1;
pushdown(d);
node ans1 = node(), ans2 = node();
if(m < ql) return query(rson,ql,qr);
if(qr <= m) return query(lson,ql,qr);
ans1 = query(lson,ql,qr); ans2 = query(rson,ql,qr);
return merge(ans1,ans2);
}
void update(int d,int l,int r,int ql,int qr,int v)
{
if(ql <= l && r <= qr)
{
tr[d].minn += v;
tr[d].maxx += v;
tr[d].lz += v;
return;
}
int m = (l+r) >> 1;
pushdown(d);
if(ql <= m) update(lson,ql,qr,v);
if(m dep[fy])
{
if(fl1 == 0)
{
ans1 = query(1,1,n,dfn[fx],dfn[x]);
fl1 = 1;
}
else ans1 = merge(query(1,1,n,dfn[fx],dfn[x]),ans1);
update(1,1,n,dfn[fx],dfn[x],v);
x = fa[fx], fx = top[x];
}
else
{
if(fl2 == 0)
{
ans2 = query(1,1,n,dfn[fy],dfn[y]);
fl2 = 1;
}
else ans2 = merge(query(1,1,n,dfn[fy],dfn[y]),ans2);
update(1,1,n,dfn[fy],dfn[y],v);
y = fa[fy], fy = top[y];
}
}
int res = 0;
if(dep[x] <= dep[y])
{
node ans3 = query(1,1,n,dfn[x],dfn[y]);
res = ans3.downans;
if(fl2 != 0)
{
ans3 = merge(ans3,ans2);
res = max(res,ans3.downans);
}
if(fl1 != 0)
{
res = max(res,max(ans1.upans,ans3.downans));
res = max(res,ans3.maxx - ans1.minn);
}
update(1,1,n,dfn[x],dfn[y],v);
}
else
{
node ans3 = query(1,1,n,dfn[y],dfn[x]);
if(fl1 != 0) ans3 = merge(ans3,ans1);
res = ans3.upans;
if(fl2 != 0)
{
res =max(res,max(ans3.upans,ans2.downans));
res = max(res,ans2.maxx - ans3.minn);
}
update(1,1,n,dfn[y],dfn[x],v);
}
return res;
}
int main()
{
int times; scanf("%d",×);
while(times --)
{
scanf("%d",&n); tot =1;
for(int i = 1; i <= n; i++) scanf("%d",&val[i]),g[i].clear();
for(int i = 1; i < n; i++)
{
int u,v; scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(1,0,0);
dfs2(1,1);
build(1,1,n);
scanf("%d",&m);
for(int i = 0; i < m; i++)
{
int x,y,z; scanf("%d%d%d",&x,&y,&z);
printf("%d\n",solve(x,y,z));
}
}
return 0;
}