POJ 3237(树链剖分 + 线段树)

题目链接:我是链接!

题意:在一棵树上,点和点之间的路径有边权,有三种操作 CHANGE把第i条边的权值改为a,NEGATE把第a个点到第b个点的边的权值全部取相反数,QUERY求第a个点到第b个点的最大边权。

分析:树链剖分的模板题了,之前写的用树状数组维护就可以了,这次稍微复杂一点用线段树+懒操作来维护。但是考虑到需要加lazy标记,所以除了区间最大值以外还需要维护一个区间最小值。然后每次对区间内的数全部取相反数,应该是lazy ^= 1,这个地方写错了wa了一发(因为取反再取反就不用下传了嘛,还是很简单的)。

代码:

#include 
#include 
#include 
#include 
#include 
#include 
#define ll long long
#define inf 0x3f3f3f3f
#define ld d<<1
#define rd d<<1|1
using namespace std;
const int N = 10000 + 5;
typedef pairpii;
pii mp[N];
int cos[N];
int fa[N],dep[N],son[N],siz[N],dfn[N],top[N],tot;
int n;
vectorg[N];
void dfs1(int x,int f,int d)
{
    fa[x] = f;
    dep[x] = d;
    son[x] = -1;
    siz[x] = 1;
    for(int i = 0; i < g[x].size(); i++)
    {
        int v = g[x][i];
        if(v == f) continue;
        dfs1(v,x,d+1);
        if(son[x] == -1 || siz[v] > siz[son[x]]) son[x] = v;
        siz[x] += siz[v];
    }
    return;
}
void dfs2(int x,int t)
{
    top[x] = t;
    dfn[x] = tot++;
    if(son[x] == -1) return;
    dfs2(son[x],t);
    for(int i = 0; i < g[x].size(); i++)
    {
        int v = g[x][i];
        if(v == fa[x] || v == son[x]) continue;
        dfs2(v,v);
    }
    return;
}
struct node
{
    int maxx,minn,lazy;
}tr[N<<2];
void build(int d,int l,int r)
{
    tr[d].lazy = 0;
    if(l == r) return;
    int m = (l+r) >>1;
    build(ld,l,m);
    build(rd,m+1,r);
}
void pushup(int d)
{
        tr[d].maxx = max(tr[ld].maxx,tr[rd].maxx);
    tr[d].minn = min(tr[ld].minn,tr[rd].minn);
    return;
}
void pushdown(int d)
{
    if(tr[d].lazy == 0) return;
    int m1 = tr[ld].maxx, m2 = tr[ld].minn;
    tr[ld].maxx = max(-m1,-m2);
    tr[ld].minn = min(-m1,-m2);
    tr[ld].lazy ^= 1;
    m1 = tr[rd].maxx, m2 = tr[rd].minn;
    tr[rd].maxx = max(-m1,-m2);
    tr[rd].minn = min(-m1,-m2);
    tr[rd].lazy ^= 1;
    tr[d].lazy ^= 1;
}
void updatep(int d,int l,int r,int p,int v)
{
    if(l == r && r == p)
    {
        tr[d].maxx = v;
        tr[d].minn = v;
        return;
    }
    pushdown(d);
    int m = (l+r) >> 1;
    if(p <= m) updatep(ld,l,m,p,v);
    else updatep(rd,m+1,r,p,v);
    pushup(d);
    return;
}
void updateseg(int d,int l,int r,int ql,int qr)
{
    if(ql <= l && r <= qr)
    {
        int m1 = tr[d].maxx ,  m2 = tr[d].minn;
        tr[d].maxx = max(-m1,-m2);
        tr[d].minn = min(-m1,-m2);
        tr[d].lazy = tr[d].lazy^1;
        return;
    }
    pushdown(d);
    int m = (l+r) >> 1;
    if(ql <= m) updateseg(ld,l,m,ql,qr);
    if(m < qr) updateseg(rd,m+1,r,ql,qr);
    pushup(d);
    return;
}
int query(int d,int l,int r,int ql,int qr)
{

    if(ql <= l && r <= qr)
    {
        return tr[d].maxx;
    }
    pushdown(d);
    int m = (l+r) >> 1;
    int ans = -inf;
    if(ql <=  m) ans = max(ans,query(ld,l,m,ql,qr));
    if(m < qr) ans = max(ans,query(rd,m+1,r,ql,qr));
    return ans;
}
void solven(int x,int y)
{
    int fx = top[x], fy = top[y];
    while(fx != fy)
    {
        if(dep[fx] < dep[fy]) swap(fx,fy), swap(x,y);
        updateseg(1,1,n,dfn[fx],dfn[x]);
        x = fa[fx];
        fx = top[x];
    }
    if(x != y)
    {
        if(dep[x] > dep[y]) swap(x,y);
        updateseg(1,1,n,dfn[x]+1,dfn[y]);
    }
    return;
}
int solvem(int x,int y)
{
    int fx = top[x], fy = top[y];
    int ans = -inf;
    while(fx != fy)
    {
        if(dep[fx] < dep[fy]) swap(fx,fy), swap(x,y);
        ans = max(ans,query(1,1,n,dfn[fx],dfn[x]));
        x = fa[fx];
        fx = top[x];
    }
    if(x != y)
    {
        if(dep[x] > dep[y]) swap(x,y);
        ans = max(ans,query(1,1,n,dfn[x]+1,dfn[y]));
    }
    return ans;
}
int main()
{
    int times;
    scanf("%d",×);
    while(times --)
    {
        scanf("%d",&n);
        for(int i = 1; i <= n; i++) g[i].clear();
        tot = 1;
        for(int i = 1; i < n; i++)
        {
            int u,v,w; scanf("%d%d%d",&u,&v,&w);
            g[u].push_back(v);
            g[v].push_back(u);
            mp[i] = make_pair(u,v);
            cos[i] = w;
        }
        dfs1(1,0,0);
        dfs2(1,1);
        build(1,1,n);
        for(int i = 1; i < n; i++)
        {
            int u = mp[i].first, v = mp[i].second;
            if(dep[u] < dep[v]) swap(u,v), swap(mp[i].first, mp[i].second);
            updatep(1,1,n,dfn[u],cos[i]);
        }
        char op[10];
        while(~scanf(" %s",op))
        {
            int a,b;
            if(op[0] == 'D') break;
            scanf("%d%d",&a,&b);
            if(op[0] == 'C')
            {
                int u = mp[a].first;
                updatep(1,1,n,dfn[u],b);
            }
            else if(op[0] == 'N')   solven(a,b);
            else   printf("%d\n",solvem(a,b));
        }
    }
}

 

你可能感兴趣的:(POJ 3237(树链剖分 + 线段树))