传送门
给定一棵 n 个节点的无根树,共有 m 个操作,操作分为两种:
1.将节点 a 到节点 b 的路径上的所有点(包括 a 和 b)都染成颜色 c。
2.询问节点 a 到节点 b 的路径上的颜色段数量。
颜色段的定义是极长的连续相同颜色被认为是一段。例如 112221 由三段组成:11、222、1。
树上路径问题,首先考虑树剖。
用线段树维护区间颜色段信息
线段树的维护的信息&&基本操作
struct node
{
int l, r, lz;//lz为懒标记
int v;//颜色段数量
int lv, rv;//该结点的左边界、右边界是什么颜色
} tr[N << 2];
void pushup(int u)
{
if (tr[u << 1].rv == tr[u << 1 | 1].lv)
tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v - 1;
else
tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v;
tr[u].lv=tr[u<<1].lv;
tr[u].rv=tr[u<<1|1].rv;
}
void pushdown(int u)
{
if (tr[u].lz)
{
tr[u << 1].v = 1;
tr[u << 1].lv = tr[u << 1].rv = tr[u].lz;
tr[u << 1].lz = tr[u].lz;
tr[u << 1 | 1].v = 1;
tr[u << 1 | 1].lv = tr[u << 1 | 1].rv = tr[u].lz;
tr[u << 1 | 1].lz = tr[u].lz;
tr[u].lz = 0;
}
}
继续考虑,这颗线段树的区间修改操作是常规、没有问题的。
但是!在区间查询的时候,有很多细节需要注意:
1.在一条重链上查询时,需要注意边界
2.跨越两条重链查询时,也要处理边界问题。
这里我们需要记录上一次合并的边界信息(luu,lvv)。
#include
using namespace std;
//-----pre_def----
const double PI = acos(-1.0);
const int INF = 0x3f3f3f3f;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
#define fir(i, a, b) for (int i = (a); i <= (b); i++)
#define rif(i, a, b) for (int i = (a); i >= (b); i--)
#define endl '\n'
#define init_h memset(h, -1, sizeof h), idx = 0;
#define lowbit(x) x &(-x)
//---------------
const int N = 1e5 + 10;
int n, m;
int h[N], e[N << 1], ne[N << 1], idx;
int lans, rans, ans;
int d[N], id[N], nd[N], fa[N], dep[N], top[N], son[N], cnt, sz[N];
int Lv, Rv;
struct node
{
int l, r, lz;
int v;
int lv, rv;
} tr[N << 2];
void add(int a, int b)
{
e[idx] = b;
ne[idx] = h[a];
h[a] = idx++;
}
void dfs1(int u, int father, int depth)
{
sz[u] = 1, fa[u] = father, dep[u] = depth;
for (int i = h[u]; ~i; i = ne[i])
{
int t = e[i];
if (t == fa[u])
continue;
dfs1(t, u, depth + 1);
sz[u] += sz[t];
if (sz[son[u]] < sz[t])
{
son[u] = t;
}
}
}
void dfs2(int u, int father)
{
id[u] = ++cnt;
nd[cnt] = d[u];
top[u] = father;
if (!son[u])
return;
dfs2(son[u], father);
for (int i = h[u]; ~i; i = ne[i])
{
int t = e[i];
if (t == fa[u] || t == son[u])
continue;
dfs2(t, t);
}
}
void pushup(int u)
{
if (tr[u << 1].rv == tr[u << 1 | 1].lv)
{
tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v - 1;
}
else
{
tr[u].v = tr[u << 1].v + tr[u << 1 | 1].v;
}
tr[u].lv=tr[u<<1].lv;
tr[u].rv=tr[u<<1|1].rv;
}
void pushdown(int u)
{
if (tr[u].lz)
{
tr[u << 1].v = 1;
tr[u << 1].lv = tr[u << 1].rv = tr[u].lz;
tr[u << 1].lz = tr[u].lz;
tr[u << 1 | 1].v = 1;
tr[u << 1 | 1].lv = tr[u << 1 | 1].rv = tr[u].lz;
tr[u << 1 | 1].lz = tr[u].lz;
tr[u].lz = 0;
}
}
void build(int u, int l, int r)
{
tr[u] = {l, r, 0, 0, 0, 0};
if (l == r)
{
tr[u] = {l, r, 0, 1, nd[l], nd[r]};
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int l, int r, int k)
{
if (l <= tr[u].l && tr[u].r <= r)
{
tr[u].v = 1;
tr[u].lv = tr[u].rv = k;
tr[u].lz = k;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid)
modify(u << 1, l, r, k);
if (mid < r)
modify(u << 1 | 1, l, r, k);
pushup(u);
}
node query(int u, int l, int r)
{
if (l <= tr[u].l && tr[u].r <= r)
{
return tr[u];
}
int res = 0;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid)
{
if (mid < r)
{
node res, ltr = query(u<<1, l, r), rtr = query(u<<1|1, l, r);
res.v = ltr.v + rtr.v + (ltr.rv == rtr.lv ? -1 : 0);
res.lv = ltr.lv, res.rv = rtr.rv;
return res;
}
else
{
return query(u << 1, l, r);
}
}
else
return query(u << 1 | 1, l, r);
}
int query_path(int u, int v)
{
//node res = {0, 0, 0, 0, 0, 0};
int res = 0, luu = 0, lvv = 0;//答案,上一次u、v节点合并区间的端点颜色信息
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])
swap(u, v), swap(luu, lvv);
node tmp = query(1, id[top[u]], id[u]);
res += tmp.v;
if (luu == tmp.rv)
res--;
luu = tmp.lv;
u = fa[top[u]];
}
if (dep[u] > dep[v])
swap(u, v), swap(luu, lvv);
node tmp = query(1, id[u], id[v]);
res += tmp.v;
if (tmp.lv == luu)
res--;
if (tmp.rv == lvv)
res--;
return res;
}
void modify_path(int u, int v, int k)
{
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])
swap(u, v);
modify(1, id[top[u]], id[u], k);
u = fa[top[u]];
}
if (dep[u] < dep[v])
swap(u, v);
modify(1, id[v], id[u], k);
}
void init() {}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
int StartTime = clock();
#endif
scanf("%d%d", &n, &m);
fir(i, 1, n) scanf("%d", &d[i]);
init_h;
fir(i, 2, n)
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b);
add(b, a);
}
dfs1(1, -1, 1);
dfs2(1, 1);
build(1, 1, n);
while (m--)
{
char op[3];
int a, b, c;
scanf("%s", op);
if (*op == 'Q')
{
scanf("%d%d", &a, &b);
printf("%d\n", query_path(a, b));
}
else
{
scanf("%d%d%d", &a, &b, &c);
modify_path(a, b, c);
}
}
#ifndef ONLINE_JUDGE
printf("Run_Time = %d ms\n", clock() - StartTime);
#endif
return 0;
}
1.一条重链的序号是从小到大的,因此可以判断哪头是线段树的l,哪头是线段树的r。
2.如何在重链之间维护查询信息。