顾名思义,就是合并两个同构(就是维护的区间长度一样)线段树,其实也没啥比较nb的算法,就是一个一个节点的合并,但是如果在n个要合并的线段树里,如果一共有m个元素,则配合动态开点,复杂度会均摊成一个惊人的 O ( m l o g n ) O(mlogn) O(mlogn)所以,在多次合并的均摊复杂度是非常优秀的.另外线段树合并还可以和线段树分裂一起构成维护一组线段树森林的方法
我们每次合并一个点,就是综合两个线段树表示相同区间的两个节点的信息,然后整合成一个,删去另一个,这时,我们可以有一个垃圾回收处理,如下:
//bac数组就是垃圾桶数组,如果里面有节点,就优先取出用掉,要是没有就另起新点
inline int newnod() {return (cnt?bac[cnt--]:++tot);}
inline void del(int p) {bac[++cnt] = p; tr[p].l = tr[p].r = tr[p].val = 0;}
合并函数可以点点进行直接合并,如果这样不方便,也可以只针对叶子节点进行直接合并,其他节点通过pushup操作得出.(总之是两个线段树所有节点都遍历一边)
例一: P4556 Vani有约会 雨天的尾巴 线段树合并模板
如题是模板题,我们讲z种不同的物资针对每个节点维护一个权值线段树(即每个节点一个).然后按照树上差分的思想,对于路径(x,y)加上z物资一件,就让x和y的权值线段树z位置加一,lca(x,y)和fa(lca(x,y))的权值z位置减一,最后dfs一边执行线段树合并,就行啦.
#include
#include
#include
#include
#include
#include
#include
#include
#define ll long long
using namespace std;
const int N = 2e5 + 5;
const int Z = 1e5 + 2;
int n, m;
int he[N], ver[N], ne[N], tot;
int d[N];
queue<int> q;
int f[N][30];
inline void add(int x, int y)
{
ver[++tot] = y;
ne[tot] = he[x];
he[x] = tot;
}
void bfs()
{
d[1] = 1;
q.push(1);
while (q.size())
{
int te = q.front();
q.pop();
for (int i = he[te]; i; i = ne[i])
{
int v = ver[i];
if (d[v]) continue;
d[v] = d[te] + 1;
f[v][0] = te;
for (int j = 1; j < 30; j++)
f[v][j] = f[f[v][j - 1]][j - 1];
q.push(v);
}
}
}
int lca(int x, int y)
{
if (d[x] > d[y]) swap(x, y);
for (int i = 29; i >= 0; i--)
{
if (d[f[y][i]] < d[x]) continue;
y = f[y][i];
}
if (x == y) return x;
for (int i = 29; i >= 0; i--)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
struct Node
{
int l, r;
int val;
int id;
}tr[N * 40];
int cnt, top;
int rt[N], bac[N*40], ans[N];
inline int newnod() { return cnt ? bac[cnt--] : ++tot; }
inline void del(int p) { bac[++cnt] = p; tr[p].l = tr[p].r = tr[p].val = 0; }
inline void pushup(int p)
{
if (tr[tr[p].l].val >= tr[tr[p].r].val)
{
tr[p].val = tr[tr[p].l].val;
tr[p].id = tr[tr[p].l].id;
}
else
{
tr[p].val = tr[tr[p].r].val;
tr[p].id = tr[tr[p].r].id;
}
}
void insert(int &p, int pos, int v, int l = 1, int r = Z)
{
if (!p) p = newnod();
if (l == r)
{
tr[p].val += v;
tr[p].id = l;
return;
}
int mid = (l + r) >> 1;
if (pos <= mid) insert(tr[p].l, pos, v, l, mid);
else insert(tr[p].r, pos, v, mid + 1, r);
pushup(p);
}
int merge(int x, int y, int l = 1, int r = Z)
{
if (!x || !y) return x + y;
int mid = (l + r) >> 1;
if (l == r)
tr[x].val += tr[y].val, tr[x].id = l;
else
{
tr[x].l = merge(tr[x].l, tr[y].l, l, mid);
tr[x].r = merge(tr[x].r, tr[y].r, mid + 1, r);
pushup(x);
}
del(y);
return x;
}
void print(int p, int l = 1, int r = Z)
{
cout << l << " " << r << " " << tr[p].val << " " << tr[p].id << endl;
if (l == r) return;
int mid = (l + r) >> 1;
print(tr[p].l, l, mid);
print(tr[p].r, mid + 1, r);
}
void dfs_mg(int cur, int fa)
{
for (int i = he[cur]; i; i = ne[i])
{
int y = ver[i];
if (y == fa) continue;
dfs_mg(y, cur);
rt[cur] = merge(rt[cur], rt[y]);
}
if (tr[rt[cur]].val)
ans[cur] = tr[rt[cur]].id;
return;
}
int main()
{
cin >> n >> m;
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
add(x, y); add(y, x);
}
bfs();
while (m--)
{
int x, y, t;
scanf("%d%d%d", &x, &y, &t);
insert(rt[x], t, 1);
insert(rt[y], t, 1);
int _lca = lca(x, y);
insert(rt[_lca], t, -1);
insert(rt[f[_lca][0]], t, -1);
}
dfs_mg(1, 0);
for (int i = 1; i <= n; i++)
printf("%d\n", ans[i]);
return 0;
}
例二: P1600 天天爱跑步
这个题写了很长时间,想了半天,才从之前做的一个题里收到启发,观察这个题,如果我们对每一个点都针对n秒维护一个权值线段树,那么,针对一个路路径,线段树向父节点合并就有两种情况:
1.路径是从子节点到父节点,这时,我们必须要让子节点线段树的所有元素都"整体向前移一位",即1秒的数量变成2秒的数量,2秒的数量变成3秒的数量…依次类推
2.路径是从父节点到子节点,这时,我们必须要让子节点线段树的所有元素都"整体向后移一位",即2秒的数量变成1秒的数量,3秒的数量变成2秒的数量…依次类推
但是这种操作很难在极短时间内进行,这时我们不如建立一个整体值 s a n san san,针对1情况,我们每上一层 s a n san san值都加一,我们往线段树中压入的是形式值,而实际值为 s a n san san+形式值,比如,在一层某个节点的 s a n san san值为6,这时我们在这里压入一个0秒,我们修改该点的权值线段树,但是不是让0位置+1,而是让 0 − s a n = − 6 0-san=-6 0−san=−6位置加一,因为-6是这一层0的形式值.这时我们往上走两层,这时 s a n san san值等于8,此时我们查询2秒的个数,这时我们其实是查2秒在这一层的形式值的个数,即 2 − s a n = − 6 2-san=-6 2−san=−6,这时我们在前插入的0秒,在这产生了影响,总而言之,我们在一个点插入形式值后,这个形式值,会依据整体值 s a n san san的不同在各层产生不同影响.然后我们必须让这课树的每一层的 s a n san san值统一,我们发现树的深度是一个比较好的天然 s a n san san值.于是乎,针对每个路径我们拆成两部分 < x , l c a >
insert(rt[x], idn(0-dep2(x)), 1);
insert(rt[f[_lca][0]], idn(0-dep2(x)), -1);//第一部分
tt = dep1[x] - dep1[_lca]) + (dep1[y] - dep1[_lca];//路径长度
insert(rt[y], idn(tt-dep1[y]), 1);
insert(rt[_lca], idn(tt-dep1[y]), -1);//第二部分
ps.这可能是迄今为止自己琢磨出的最震撼的算法了QAQ
下面是ac代码:
#include
#include
#include
#include
#include
#include
#include
#include
#define ll long long
#define max(x, y) ((x)>(y)?(x):(y))
using namespace std;
const int N = 6e5 + 5;
const int Z = 6e5 + 5;
int n, m;
int he[N], ver[N], ne[N], tot;
int dep1[N], deep;
int d[N];
int q[N], qh, ql;
int f[N][30];
inline void add(int x, int y)
{
ver[++tot] = y;
ne[tot] = he[x];
he[x] = tot;
}
inline int idn(int n) {return n + 3e5+5;}
inline int dep2(int n) {return deep - dep1[n];}
void bfs()
{
d[1] = 1;
q[++qh] = 1;
while (qh > ql)
{
int te = q[++ql];
for (int i = he[te]; i; i = ne[i])
{
int v = ver[i];
if (d[v]) continue;
d[v] = d[te] + 1;
f[v][0] = te;
for (int j = 1; j < 30; j++)
f[v][j] = f[f[v][j - 1]][j - 1];
dep1[v] = dep1[te] + 1;
deep = max(dep1[v], deep);
q[++qh] = v;
}
}
}
int lca(int x, int y)
{
if (d[x] > d[y]) swap(x, y);
for (int i = 29; i >= 0; i--)
{
if (d[f[y][i]] < d[x]) continue;
y = f[y][i];
}
if (x == y) return x;
for (int i = 29; i >= 0; i--)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
struct Node
{
int l, r;
int val;
}tr[N * 40];
int cnt, top;
int rt[N], bac[N*40];
inline int newnod() { return cnt ? bac[cnt--] : ++top; }
inline void del(int p) { bac[++cnt] = p; tr[p].l = tr[p].r = tr[p].val = 0; }
void insert(int &p, int pos, int v, int l = 1, int r = Z)
{
if (!p) p = newnod();
tr[p].val += v;
if (l == r) return;
int mid = (l + r) >> 1;
if (pos <= mid) insert(tr[p].l, pos, v, l, mid);
else insert(tr[p].r, pos, v, mid + 1, r);
}
int merge(int x, int y, int l = 1, int r = Z)
{
if (!x || !y) return x + y;
tr[x].val += tr[y].val;
int mid = (l + r) >> 1;
tr[x].l = merge(tr[x].l, tr[y].l, l, mid);
tr[x].r = merge(tr[x].r, tr[y].r, mid + 1, r);
del(y);
return x;
}
int qt[N], ans[N];
bool flag = 0;
int ask(int p, int k, int l = 1, int r = Z)
{
if (l == r) return tr[p].val;
int mid = (l + r) >> 1;
if (k <= mid) return ask(tr[p].l, k, l, mid);
else return ask(tr[p].r, k, mid+1, r);
}
void dfs_mg(int cur, int fa)
{
for (int i = he[cur]; i; i = ne[i])
{
int y = ver[i];
if (y == fa) continue;
dfs_mg(y, cur);
rt[cur] = merge(rt[cur], rt[y]);
}
int tt = qt[cur];
ans[cur] += ask(rt[cur], idn(tt - (flag ? dep1[cur] : dep2(cur))));
return;
}
struct qy
{
int lca, y;
int lca_m, y_m;
}qr[N];
int main()
{
cin >> n >> m;
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
add(x, y); add(y, x);
}
bfs();
for (int i = 1; i <= n; i++) scanf("%d", &qt[i]);
for (int i = 1; i <= m; i++)
{
int x, y;
scanf("%d%d", &x, &y);
int _lca = lca(x, y);
insert(rt[x], idn(0-dep2(x)), 1);
insert(rt[f[_lca][0]], idn(0-dep2(x)), -1);
qr[i].lca = _lca; qr[i].y = y; qr[i].y_m = (dep1[x] - dep1[_lca]) + (dep1[y] - dep1[_lca]);
}
dfs_mg(1, 0);
cnt = top = 0;
for (int i = 0; i < N * 40; i++) tr[i].l = tr[i].r = tr[i].val = 0;
memset(rt, 0, sizeof(rt));
for (int i = 1; i <= m; i++)
{
int y = qr[i].y, _lca = qr[i].lca;
int tt = qr[i].y_m;
insert(rt[y], idn(tt-dep1[y]), 1);
insert(rt[_lca], idn(tt-dep1[y]), -1);
}
flag = 1;
dfs_mg(1, 0);
for (int i = 1; i <= n; i++)
printf("%d ", ans[i]);
puts("");
return 0;
}
例三:P3224 [HNOI2012]永无乡
一个模板中的模板了.结合并查集找根即可.
下面是ac代码:
#include
#include
#include
#include
#include
#include
#include
#include
#define ll long long
using namespace std;
const int N = 2e5 + 5;
int fa[N];
int fi(int x)
{
if (x == fa[x]) return x;
return fa[x] = fi(fa[x]);
}
struct Node
{
int l, r;
int val;
}tr[N * 40];
int cnt, top;
int rt[N], bac[N*40];
inline int newnod() { return cnt ? bac[cnt--] : ++top; }
inline void del(int p) { bac[++cnt] = p; tr[p].l = tr[p].r = tr[p].val = 0; }
int n, m;
void insert(int &p, int pos, int v, int l = 1, int r = n)
{
if (!p) p = newnod();
tr[p].val += v;
if (l == r) return;
int mid = (l + r) >> 1;
if (pos <= mid) insert(tr[p].l, pos, v, l, mid);
else insert(tr[p].r, pos, v, mid + 1, r);
}
int merge(int x, int y, int l = 1, int r = n)
{
if (!x || !y) return x + y;
int mid = (l + r) >> 1;
tr[x].val += tr[y].val;
tr[x].l = merge(tr[x].l, tr[y].l, l, mid);
tr[x].r = merge(tr[x].r, tr[y].r, mid + 1, r);
del(y);
return x;
}
void mmerge(int x, int y)
{
fa[y] = x;
rt[x] = merge(rt[x], rt[y]);
}
int ask(int p, int k, int l = 1, int r = n)
{
if (l == r) return l;
int mid = (l + r) >> 1;
if (tr[tr[p].l].val >= k) return ask(tr[p].l, k, l, mid);
else return ask(tr[p].r, k - tr[tr[p].l].val, mid+1, r);
}
void Debug_print(int p, int l = 1, int r = n)
{
cout << l << " " << r << " " << tr[p].val << endl;
if (l == r) { return;}
int mid = (l + r) >> 1;
Debug_print(tr[p].l, l, mid);
Debug_print(tr[p].r, mid + 1, r);
}
void De()
{
for (int i = 1; i <= n; i++)
cout << fa[i] << " ";
cout << endl;
for (int i = 1; i <= n; i++)
{
cout << "--------------" << endl;
cout << i << endl;
for (int j = 1; j <= n; j++)
Debug_print(rt[i], j);
cout << "\n--------------" << endl;
}
}
int _rank[N];
int main()
{
cin >> n >> m;
for (int i = 1; i <= n; i++)
{
fa[i] = i;
int te; scanf("%d", &te);
insert(rt[i], te, 1);
_rank[te] = i;
}
while(m--)
{
int x, y;
scanf("%d%d", &x, &y);
x = fi(x); y = fi(y);
if (x != y) mmerge(x, y);
}
int q; cin >> q;
while(q--)
{
char op[5];
int x, y;
scanf("%s%d%d", op, &x, &y);
if (op[0] == 'Q')
{
x = fi(x);
if (tr[rt[x]].val < y) {puts("-1"); continue;}
int te = ask(rt[x], y);
printf("%d\n", _rank[te]);
}
else
{
x = fi(x); y = fi(y);
if (x != y) mmerge(x, y);
}
}
return 0;
}
补充一例:CF600E Lomsat gelral
这个题可以用dsu来做,但是今天发现可以用线段树合并也可以!啊啊啊,这些神奇的算法还是这么的神奇,个人觉得线段树合并比dsu还要好理解一点.
下面是ac代码:
#include
#include
#include
#define ll long long
#define int ll
using namespace std;
const int N = 1e5+5;
int top, cnt, tot;
int bac[N*40], ver[N<<1], he[N], ne[N<<1], rt[N];
int su[N];
int n;
struct Node
{
int l, r; ll val, ans;
}tr[N*40];
inline int neww() { return cnt ? bac[cnt--] : ++top; }
inline void del(int p) { bac[++cnt]; tr[p].l = tr[p].r = tr[p].val = 0; }
void add(int x, int y)
{
ver[++tot] = y;
ne[tot] = he[x];
he[x] = tot;
}
inline void pushup(int p)
{
if (tr[tr[p].l].val > tr[tr[p].r].val)
{
tr[p].val = tr[tr[p].l].val;
tr[p].ans = tr[tr[p].l].ans;
}
else if (tr[tr[p].l].val < tr[tr[p].r].val)
{
tr[p].val = tr[tr[p].r].val;
tr[p].ans = tr[tr[p].r].ans;
}
else
{
tr[p].val = tr[tr[p].l].val;
tr[p].ans = tr[tr[p].l].ans + tr[tr[p].r].ans;
}
}
void ins(int &p, int k, int val, int l = 1, int r = n)
{
if (!p) p = neww();
if (l == r){tr[p].val += val; tr[p].ans = l; return;}
int mid = (l + r) >> 1;
if (k <= mid) ins(tr[p].l, k, val, l, mid);
else ins(tr[p].r, k, val, mid+1, r);
pushup(p);
}
int merge(int x, int y, int l = 1, int r = n)
{
if (!x || !y) return x + y;
if (l == r)
{tr[x].val += tr[y].val; tr[x].ans = l;}
else
{
int mid = (l + r) >> 1;
tr[x].l = merge(tr[x].l, tr[y].l, l, mid);
tr[x].r = merge(tr[x].r, tr[y].r, mid+1, r);
pushup(x);
}
del(y);
return x;
}
int ans[N];
void dfs_mg(int cur, int fa)
{
for (int i = he[cur]; i; i = ne[i])
{
int y = ver[i];
if (y == fa) continue;
dfs_mg(y, cur);
rt[cur] = merge(rt[cur], rt[y]);
}
ans[cur] = tr[rt[cur]].ans;
return;
}
signed main()
{
cin >> n;
for (int i = 1; i <= n; i++)
{
scanf("%lld", &su[i]);
ins(rt[i], su[i], 1);
}
for (int i = 1; i < n; i++)
{
int x, y; scanf("%lld%lld", &x, &y);
add(x, y); add(y, x);
}
dfs_mg(1, 0);
for (int i = 1; i <= n; i++)
printf("%lld ", ans[i]);
puts("");
return 0;
}