P2486 [SDOI2011]染色 (树剖+线段树)

传送门

题意

给定一棵 n 个节点的无根树,共有 m 个操作,操作分为两种:
1.将节点 a 到节点 b 的路径上的所有点(包括 a 和 b)都染成颜色 c。
2.询问节点 a 到节点 b 的路径上的颜色段数量。
颜色段的定义是极长的连续相同颜色被认为是一段。例如 112221 由三段组成:11、222、1。

分析

树上路径问题,首先考虑树剖。
用线段树维护区间颜色段信息
线段树的维护的信息&&基本操作

struct node
{
    int l, r, lz;//lz为懒标记
    int v;//颜色段数量
    int lv, rv;//该结点的左边界、右边界是什么颜色
} tr[N << 2];
void pushup(int u)
{
    if (tr[u << 1].rv == tr[u << 1 | 1].lv)
        tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v - 1;
    else
        tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v;
    tr[u].lv=tr[u<<1].lv;
    tr[u].rv=tr[u<<1|1].rv;
}
void pushdown(int u)
{
    if (tr[u].lz)
    {
        tr[u << 1].v = 1;
        tr[u << 1].lv = tr[u << 1].rv = tr[u].lz;
        tr[u << 1].lz = tr[u].lz;
        tr[u << 1 | 1].v = 1;
        tr[u << 1 | 1].lv = tr[u << 1 | 1].rv = tr[u].lz;
        tr[u << 1 | 1].lz = tr[u].lz;
        tr[u].lz = 0;
    }
}

继续考虑,这颗线段树的区间修改操作是常规、没有问题的。
但是!在区间查询的时候,有很多细节需要注意:
1.在一条重链上查询时,需要注意边界
P2486 [SDOI2011]染色 (树剖+线段树)_第1张图片
2.跨越两条重链查询时,也要处理边界问题。
这里我们需要记录上一次合并的边界信息(luu,lvv)。
P2486 [SDOI2011]染色 (树剖+线段树)_第2张图片

代码

#include 

using namespace std;
//-----pre_def----
const double PI = acos(-1.0);
const int INF = 0x3f3f3f3f;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
#define fir(i, a, b) for (int i = (a); i <= (b); i++)
#define rif(i, a, b) for (int i = (a); i >= (b); i--)
#define endl '\n'
#define init_h memset(h, -1, sizeof h), idx = 0;
#define lowbit(x) x &(-x)

//---------------
const int N = 1e5 + 10;
int n, m;
int h[N], e[N << 1], ne[N << 1], idx;
int lans, rans, ans;
int d[N], id[N], nd[N], fa[N], dep[N], top[N], son[N], cnt, sz[N];
int Lv, Rv;
struct node
{
    int l, r, lz;
    int v;
    int lv, rv;
} tr[N << 2];
void add(int a, int b)
{
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx++;
}
void dfs1(int u, int father, int depth)
{
    sz[u] = 1, fa[u] = father, dep[u] = depth;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int t = e[i];
        if (t == fa[u])
            continue;
        dfs1(t, u, depth + 1);
        sz[u] += sz[t];
        if (sz[son[u]] < sz[t])
        {
            son[u] = t;
        }
    }
}
void dfs2(int u, int father)
{
    id[u] = ++cnt;
    nd[cnt] = d[u];
    top[u] = father;
    if (!son[u])
        return;
    dfs2(son[u], father);
    for (int i = h[u]; ~i; i = ne[i])
    {
        int t = e[i];
        if (t == fa[u] || t == son[u])
            continue;
        dfs2(t, t);
    }
}
void pushup(int u)
{
    if (tr[u << 1].rv == tr[u << 1 | 1].lv)
    {
        tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v - 1;
    }
    else
    {
        tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v;
    }
    tr[u].lv=tr[u<<1].lv;
    tr[u].rv=tr[u<<1|1].rv;
}
void pushdown(int u)
{
    if (tr[u].lz)
    {
        tr[u << 1].v = 1;
        tr[u << 1].lv = tr[u << 1].rv = tr[u].lz;
        tr[u << 1].lz = tr[u].lz;
        tr[u << 1 | 1].v = 1;
        tr[u << 1 | 1].lv = tr[u << 1 | 1].rv = tr[u].lz;
        tr[u << 1 | 1].lz = tr[u].lz;
        tr[u].lz = 0;
    }
}
void build(int u, int l, int r)
{
    tr[u] = {l, r, 0, 0, 0, 0};
    if (l == r)
    {
        tr[u] = {l, r, 0, 1, nd[l], nd[r]};
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}
void modify(int u, int l, int r, int k)
{
    if (l <= tr[u].l && tr[u].r <= r)
    {
        tr[u].v = 1;
        tr[u].lv = tr[u].rv = k;
        tr[u].lz = k;
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid)
        modify(u << 1, l, r, k);
    if (mid < r)
        modify(u << 1 | 1, l, r, k);
    pushup(u);
}
node query(int u, int l, int r)
{
    if (l <= tr[u].l && tr[u].r <= r)
    {
        return tr[u];
    }
    int res = 0;
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid)
    {
        if (mid < r)
        {
            node res, ltr = query(u<<1, l, r), rtr = query(u<<1|1, l, r);
            res.v = ltr.v + rtr.v + (ltr.rv == rtr.lv ? -1 : 0);
            res.lv = ltr.lv, res.rv = rtr.rv;
            return res;
        }
        else
        {
            return query(u << 1, l, r);
        }
    }
    else
        return query(u << 1 | 1, l, r);
}
int query_path(int u, int v)
{
    //node res = {0, 0, 0, 0, 0, 0};
    int res = 0, luu = 0, lvv = 0;//答案,上一次u、v节点合并区间的端点颜色信息
    while (top[u] != top[v])
    {
        if (dep[top[u]] < dep[top[v]])
            swap(u, v), swap(luu, lvv);
        node tmp = query(1, id[top[u]], id[u]);
        res += tmp.v;
        if (luu == tmp.rv)
            res--;
        luu = tmp.lv;
        u = fa[top[u]];
    }
    if (dep[u] > dep[v])
        swap(u, v), swap(luu, lvv);
    node tmp = query(1, id[u], id[v]);
    res += tmp.v;
    if (tmp.lv == luu)
        res--;
    if (tmp.rv == lvv)
        res--;
    return res;
}

void modify_path(int u, int v, int k)
{
    while (top[u] != top[v])
    {
        if (dep[top[u]] < dep[top[v]])
            swap(u, v);
        modify(1, id[top[u]], id[u], k);
        u = fa[top[u]];
    }
    if (dep[u] < dep[v])
        swap(u, v);
    modify(1, id[v], id[u], k);
}

void init() {}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    int StartTime = clock();
#endif
    scanf("%d%d", &n, &m);
    fir(i, 1, n) scanf("%d", &d[i]);
    init_h;
    fir(i, 2, n)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }
    dfs1(1, -1, 1);
    dfs2(1, 1);
    build(1, 1, n);
    while (m--)
    {
        char op[3];
        int a, b, c;
        scanf("%s", op);
        if (*op == 'Q')
        {
            scanf("%d%d", &a, &b);
            printf("%d\n", query_path(a, b));
        }
        else
        {
            scanf("%d%d%d", &a, &b, &c);
            modify_path(a, b, c);
        }
    }
#ifndef ONLINE_JUDGE
    printf("Run_Time = %d ms\n", clock() - StartTime);
#endif
    return 0;
}

总结

1.一条重链的序号是从小到大的,因此可以判断哪头是线段树的l,哪头是线段树的r。
2.如何在重链之间维护查询信息。

你可能感兴趣的:(树链剖分,树剖)