【学习笔记】线段树的扩展(线段树的合并与分裂、可持久化线段树)

  • 感觉最近研究数据结构,我的对拍能力和输出调试能力得到了显著提升……
  • 本篇文章介绍关于线段树的一些经典扩展操作。
  • 有关线段树的经典问题(势能线段树、李超线段树、线段树维护单调子序列)的总结请看神仙 x y z 32768 xyz32768 xyz32768 的这篇文章:[学习笔记]线段树骚操作选讲

1. 线段树的合并与分裂

1.1 BZOJ2212:[POI2011]Tree Rotations

  • 题目来源:BZOJ2212

题目大意:给定一棵二叉树,每个结点有两个儿子或没有儿子, n n n 个叶子结点的权值为 1 1 1 n n n 的排列。现在我们可以任意交换某些结点的左子树和右子树,要求进行一系列交换,使得最终所有叶子节点的权值按照遍历序写出来,逆序对个数最少。 n ≤ 2 × 1 0 5 n\le 2\times 10^5 n2×105

  • 考虑这道经典问题。
  • 我们知道对于一个结点,以它为根的子树的逆序对可以划分为两种:
    1. 逆序对的两个数同时属于左子树或同时属于右子树
    2. 逆序对的两个数分别属于左右子树
  • 现在我们对于这个结点有两种决策:交换孩子和不交换孩子
  • 这两个决策不影响其祖先的逆序对计算,也不会影响第一种逆序对个数,仅会影响到第二种。
  • 对于第一种我们可以递归到孩子的时候计算。
  • 对于第二种,一个直观的想法是分别计算两种决策的答案,选取更优的那个。
  • 我们可以暴力对于每个结点维护出一个升序序列表示这个结点对应子树的权值。然后归并排序计算两种的逆序对。
  • 但是这么做时间复杂度是 O ( n 2 ) O(n^2) O(n2) 的。
  • 我们可以对于每个结点维护一个动态开点的权值线段树,表示出现过的权值,然后同时遍历一遍权值线段树同样可以做到 O ( n 2 ) O(n^2) O(n2) 的时间复杂度。
  • 但是明显我们是可以优化这种做法的。
  • 一种优化是启发式合并,即我们直接把包含结点数小的那棵线段树中的所有结点取出来,一个个插入到另一棵线段树,在此过程中计算贡献。因为每个结点被插入一次,它所在的线段树的结点数就会至少扩大一倍,每次插入几个结点是 O ( log ⁡ n ) O(\log n) O(logn) 的,所以总时间复杂度是 O ( n log ⁡ 2 n ) O(n\log^2 n) O(nlog2n)。实际上,可以做到更优,也就是另一种做法。

  • 另一种很显然的优化是我们遍历到某个结点,若两棵线段树中这个结点有一棵的对应位置是空的,则没必要遍历下去。这么做实际上是 O ( n log ⁡ n ) O(n\log n) O(nlogn) 的,下面会证明,实际上就是我们常用的线段树合并的做法,保证这个时间复杂度也有一个条件,下面也会讲到。
  • 这么做的代码是这样的(因为被合并的线段树可以直接删除,这里用到了空间回收):
#include 

typedef long long s64; 

inline char nextChar()
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer + buffer_size; 
	
	if (head == tail)
	{
		fread(buffer, 1, buffer_size, stdin); 
		head = buffer; 
	}
	return *head++; 
}

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = nextChar())); 
	x = ch - '0'; 
	while (isdigit(ch = nextChar()))
		x = x * 10 + ch - '0'; 
}

template <class T>
inline void relax(T &x, const T &y)
{
	if (x < y)
		x = y; 
}

const int MaxN = 2e5 + 5; 
const int MaxS = MaxN * 10; 

int n, sze_seg, top, stk[MaxS]; 
int lc[MaxS], rc[MaxS], sze[MaxS]; 

s64 res1, res2, ans; 

inline void del(int x)
{
	lc[x] = rc[x] = sze[x] = 0; 
	stk[++top] = x; 
}

inline int get_new()
{
	return top ? stk[--top + 1]: ++sze_seg; 
}

inline void insert(int &x, int l, int r, int pos)
{
	if (!x) x = get_new();
	++sze[x]; 
	
	if (l == r) return; 
	int mid = l + r >> 1; 
	if (pos <= mid)
		insert(lc[x], l, mid, pos); 
	else
		insert(rc[x], mid + 1, r, pos); 
}
inline int merge(int x, int y, int l, int r)
{
	if (!x || !y) return x + y; 
	
	res1 += 1LL * sze[lc[x]] * sze[rc[y]]; 
	res2 += 1LL * sze[rc[x]] * sze[lc[y]]; 
	
	int mid = l + r >> 1; 
	lc[x] = merge(lc[x], lc[y], l, mid); 
	rc[x] = merge(rc[x], rc[y], mid + 1, r); 
	
	return sze[x] += sze[y], del(y), x; 
}

inline int solve()
{
	int x; 
	read(x); 
	if (!x)
	{
		int ch_l = solve(); 
		int ch_r = solve(); 
		res1 = res2 = 0; 
		
		int res = merge(ch_l, ch_r, 1, n); 
		ans += std::min(res1, res2); 
		return res; 
	}
	else
	{
		int res = 0; 
		insert(res, 1, n, x); 
		return res; 
	}
}

int main()
{
	read(n); 
	solve(); 
	std::cout << ans << std::endl; 
	return 0; 
} 
  • 这么做,相当于开始有 n n n 个只插入了一个点的权值线段树,按任意顺序合并,最后合并成一个 n n n 个点的权值线段树。问最坏的时间复杂度是多少。
  • 显然,每次合并的时间就是两棵线段树重合的结点数。
  • 考虑 n n n 个权值线段树的总结点数是 O ( n log ⁡ n ) O(n\log n) O(nlogn) 的,每次合并重合部分后,相当于删去了其中一棵树的那部分结点,且删去每个点是 O ( 1 ) O(1) O(1) 的。
  • 所以时间复杂度就是 O ( n log ⁡ n ) O(n\log n) O(nlogn)

1.2 BZOJ3545:[ONTAK2010]Peaks

  • 题目来源:BZOJ3545

题目大意:有 N N N 座山峰,每座山峰有他的高度 h i h_i hi。有些山峰之间有双向道路相连,共 M M M 条路径,每条路径有一个困难值,这个值越大表示越难走,现在有 Q Q Q 组询问,每组询问询问从点 v v v 开始只经过困难值小于等于 x x x 的路径所能到达的山峰中第 k k k 高的山峰的高度,如果无解输出 − 1 -1 1。允许离线。
N ≤ 1 0 5 ,   Q , M ≤ 5 × 1 0 5 ,   x , h i ≤ 1 0 9 N\le 10^5,~Q,M\le 5\times10^5,~x,h_i\le10^9 N105, Q,M5×105, x,hi109

  • 把询问离线,按 x x x 排序,可以用并查集维护可以到达的点集。
  • 对于每个集合,我们维护一个权值线段树,查询 k k k 大只需要在权值线段树上二分即可,也是经典套路。
  • 然后合并集合的时候用线段树合并的套路即可。
  • 时间复杂度 O ( ( n + Q ) log ⁡ n ) O((n+Q)\log n) O((n+Q)logn)
  • 关于这道题的在线版本,可以使用 K r u s k a l Kruskal Kruskal 重构树,这里不赘述,我后面还会写一篇介绍 K r u s k a l Kruskal Kruskal 重构树。
#include 

inline char nextChar()
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer + buffer_size; 
	
	if (head == tail)
	{
		fread(buffer, 1, buffer_size, stdin); 
		head = buffer; 
	}
	return *head++; 
}

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = nextChar())); 
	x = ch - '0'; 
	while (isdigit(ch = nextChar()))
		x = x * 10 + ch - '0'; 
}

inline void putChar(char ch)
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer; 
	
	if (ch == '\0')
		fwrite(buffer, 1, head - buffer, stdout); 
	
	*head++ = ch; 
	if (head == tail)
		fwrite(buffer, 1, buffer_size, stdout), head = buffer; 
}

template <class T>
inline void putint(T x)
{
	static char buf[22]; 
	static char *tail = buf; 
	if (!x) return (void)(putChar('0')); 
	if (x < 0) x = ~x + 1, putChar('-'); 
	for (; x; x /= 10) *++tail = x % 10 + '0'; 
	for (; tail != buf; --tail) putChar(*tail); 
}

const int MaxN = 5e5 + 5; 
const int MaxS = 100001 * 18; 

struct edge
{
	int u, v, w; 
	inline void scan()
	{
		read(u), read(v), read(w); 
	}
	inline bool operator < (const edge &rhs) const
	{
		return w < rhs.w; 
	}
}e[MaxN]; 

struct request
{
	int u, x, k, num; 
	inline void scan(int t)
	{
		num = t; 
		read(u), read(x), read(k); 
	}
	inline bool operator < (const request &rhs) const
	{
		return x < rhs.x; 
	}
}req[MaxN]; 

int n, m, Q, tot, val_num, val[MaxN], real[MaxN]; 
int ans[MaxN], fa[MaxN], rt[MaxN], h[MaxN]; 
int lc[MaxS], rc[MaxS], sze[MaxS]; 

inline int ufs_find(int x)
{
	return x == fa[x] ? x : fa[x] = ufs_find(fa[x]); 
}

inline void insert(int &x, int l, int r, int pos)
{
	if (!x) x = ++tot; 
	++sze[x]; 
	if (l == r) return; 
	int mid = l + r >> 1; 
	
	if (pos <= mid)
		insert(lc[x], l, mid, pos); 
	else
		insert(rc[x], mid + 1, r, pos); 
}

inline int merge(int x, int y, int l, int r)
{
	if (!x || !y) return x + y; 
	int mid = l + r >> 1; 
	lc[x] = merge(lc[x], lc[y], l, mid); 
	rc[x] = merge(rc[x], rc[y], mid + 1, r);
	return sze[x] += sze[y], lc[y] = rc[y] = sze[y] = 0, x; 
}

inline int query(int x, int l, int r, int k)
{
	if (l == r) return l; 
	int mid = l + r >> 1, rsze = sze[rc[x]]; 
	return k <= rsze ? query(rc[x], mid + 1, r, k) : query(lc[x], l, mid, k - rsze); 
}

inline void link(int x, int y)
{
	int u = ufs_find(x), v = ufs_find(y); 
	if (u == v) return; 
	fa[v] = u; 
	val[u] += val[v]; 
	
	rt[u] = merge(rt[u], rt[v], 1, val_num); 
}

int main()
{
	read(n), read(m), read(Q); 
	for (int i = 1; i <= n; ++i)
		read(h[i]), real[++val_num] = h[i]; 
	std::sort(real + 1, real + val_num + 1); 
	val_num = std::unique(real + 1, real + val_num + 1) - real - 1; 
	for (int i = 1; i <= n; ++i)
		h[i] = std::lower_bound(real + 1, real + val_num + 1, h[i]) - real; 
	
	for (int i = 1; i <= m; ++i)
		e[i].scan(); 
	for (int i = 1; i <= Q; ++i)
		req[i].scan(i); 
	std::sort(e + 1, e + m + 1); 
	std::sort(req + 1, req + Q + 1); 
	
	for (int i = 1; i <= n; ++i)
	{
		fa[i] = i, val[i] = 1; 
		insert(rt[i], 1, val_num, h[i]); 
	}
	
	int r = 0; 
	for (int i = 1; i <= Q; ++i)
	{
		while (r < m && e[r + 1].w <= req[i].x)
		{
			++r; 
			link(e[r].u, e[r].v); 
		}
		
		int x = ufs_find(req[i].u); 
		int k = req[i].k; 
		ans[req[i].num] = k <= val[x] ? query(rt[x], 1, val_num, k) : 0; 
	}
	
	real[0] = -1; 
	for (int i = 1; i <= Q; ++i)
		putint(real[ans[i]]), putChar('\n'); 
	putChar('\0'); 
	return 0; 
}

1.3 BZOJ4552:[TJOI2016&HEOI2016]排序

  • 题目来源:BZOJ4552

题目大意:给出一个 1 1 1 n n n 的排列,现在对这个序列进行 m m m 次局部排序,排序分为两种:
1.将区间 [ l , r ] [l,r] [l,r] 的数字升序排序
2.将区间 [ l , r ] [l,r] [l,r] 的数字降序排序
最后询问 q q q 位置上的数字(询问只有最后一次)。 n , m ≤ 1 0 5 n,m\le 10^5 n,m105

  • 这道题有一个 O ( n log ⁡ 2 n ) O(n\log^2 n) O(nlog2n) 的套路做法(视 n , m n,m n,m 同阶)。
  • 二分最后答案 m i d mid mid,将 ≥ m i d \ge mid mid 的数标记为 1 1 1 < m i d <mid <mid 的数标记为 0 0 0
  • 01 01 01 的排序可以用维护区间 1 1 1 的个数和区间赋值实现,然后就知道这个数是否 ≥ m i d \ge mid mid
  • 但是我们今天讨论的重点当然不是这个。
  • 我们有一个 O ( n log ⁡ n ) O(n\log n) O(nlogn) 的资瓷在线询问的做法 (似乎可以出一道良心题)
  • 我们考虑将已经排好序(升序或降序)的区间看成一个整体处理,把有序的区间用平衡树维护,并记录这个区间是升序还是降序。对于每个区间包含的数,我们用一个权值线段树维护。
  • 暴力把每个操作涉及的区间取出,然后对于需要分裂的区间,我们分裂出一个区间的前 k k k 大或前 k k k,这个过程最多会涉及两个需要分裂的区间,每次分裂只会增加 O ( log ⁡ n ) O(\log n) O(logn) 个结点。
  • 然后把涉及的区间暴力依次合并
  • 时间复杂度:因为每次分裂只会增加两个区间,所以总的合并和分裂次数都是 O ( n ) O(n) O(n) 的。同样,我们可以用均摊分析,算出线段树合并的总时间复杂度就是 O ( n log ⁡ n ) O(n\log n) O(nlogn)
  • 平衡树可以用 s e t set set 减少代码量,总时间复杂度 O ( n log ⁡ n ) O(n\log n) O(nlogn)
  • 询问可以资瓷在线,只需要找到对应区间,在权值线段树上查询第 k k k 大和第 k k k 小即可。
  • 注意分裂区间的讨论细节。
  • 为了保证空间可以使用空间回收。
//O(nlogn)
#include 

inline char nextChar()
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer + buffer_size; 
	
	if (head == tail)
	{
		fread(buffer, 1, buffer_size, stdin); 
		head = buffer; 
	}
	return *head++; 
}

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = nextChar())); 
	x = ch - '0'; 
	while (isdigit(ch = nextChar()))
		x = x * 10 + ch - '0'; 
}

const int MaxN = 1e5 + 5; 
const int MaxS = MaxN * 20; 

struct node
{
	int l, r, opt, rt; 
	node(){}
	node(int a, int b, int c, int d):
		l(a), r(b), opt(c), rt(d) {}
	inline bool operator < (const node &rhs) const
	{
		return r < rhs.r; 
	}
	inline bool operator == (const node &rhs) const
	{
		return r == rhs.r; 
	}
}; 

typedef std::set<node>::iterator set_it; 

int n, Q, tot; 
int top, stk[MaxS]; 
int lc[MaxS], rc[MaxS], sze[MaxS]; 

std::set<node> S; 

inline int get_new()
{
	return top ? stk[top--] : ++tot; 
}

inline void del(int x)
{
	lc[x] = rc[x] = sze[x] = 0; 
	stk[++top] = x; 
}

inline void insert(int &x, int l, int r, int pos)
{
	if (!x) x = ++tot;  
	++sze[x]; 
	if (l == r) return; 
	
	int mid = l + r >> 1; 
	pos <= mid ? insert(lc[x], l, mid, pos) : insert(rc[x], mid + 1, r, pos); 
}

inline int merge(int x, int y, int l, int r)
{
	if (!x || !y) return x + y; 
	if (l == r) return sze[x] += sze[y], del(y), x; 
	int mid = l + r >> 1; 
	lc[x] = merge(lc[x], lc[y], l, mid); 
	rc[x] = merge(rc[x], rc[y], mid + 1, r); 
	sze[x] = sze[lc[x]] + sze[rc[x]]; 
	return del(y), x; 
}

inline int query(int x, int l, int r, int k, int opt)
{
	if (l == r) return l; 
	int mid = l + r >> 1; 
	if (opt)
	{
		int rs = sze[rc[x]]; 
		return k <= rs ? query(rc[x], mid + 1, r, k, opt) : query(lc[x], l, mid, k - rs, opt); 
	}
	else
	{
		int ls = sze[lc[x]]; 
		return k <= ls ? query(lc[x], l, mid, k, opt) : query(rc[x], mid + 1, r, k - ls, opt); 
	}
}

inline void split(int x, int l, int r, int k, int opt, int &a, int &b)
{
	if (!x) return (void)(a = b = 0); 
	if (!k) return (void)(a = 0, b = x); 
	if (l == r) return (void)(sze[b = k == sze[x] ? 0 : get_new()] = sze[x] - k, sze[a = x] = k); 
	
	int mid = l + r >> 1; 
	if (opt)
	{
		int rs = sze[rc[x]]; 
		if (k <= rs)
			a = get_new(), b = x, split(rc[x], mid + 1, r, k, opt, rc[a], rc[b]); 
		else
			a = x, b = get_new(), split(lc[x], l, mid, k - rs, opt, lc[a], lc[b]); 
	}
	else
	{
		int ls = sze[lc[x]]; 
		if (k <= ls)
			a = get_new(), b = x, split(lc[x], l, mid, k, opt, lc[a], lc[b]); 
		else
			a = x, b = get_new(), split(rc[x], mid + 1, r, k - ls, opt, rc[a], rc[b]); 
	}
	sze[a] = sze[lc[a]] + sze[rc[a]]; 
	sze[b] = sze[lc[b]] + sze[rc[b]]; 
}

int main()
{
	read(n), read(Q); 
	for (int i = 1; i <= n; ++i)
	{
		int now = 0, x; 
		read(x); 
		insert(now, 1, n, x); 
		S.insert(node(i, i, 0, now)); 
	}
	
	for (int i = 1; i <= Q; ++i)
	{
		int opt, l, r, u, v; 
		read(opt), read(l), read(r); 
		set_it it_l = S.lower_bound(node(l, l, 0, 0)); 
		set_it it_r = S.lower_bound(node(r, r, 0, 0)); 
		
		if (it_l == it_r)
		{
			int tl = it_l->l, tr = it_l->r, to = it_l->opt, tu = it_l->rt; 
			int u, v, w;  
			S.erase(it_l); 
			split(tu, 1, n, r - tl + 1, to, u, w); 
			split(u, 1, n, r - l + 1, to ^ 1, v, u); 
			
			if (tl < l) S.insert(node(tl, l - 1, to, u)); 
			S.insert(node(l, r, opt, v)); 
			if (r < tr) S.insert(node(r + 1, tr, to, w)); 
			
			continue; 
		}
		if (l != it_l->l)
		{
			split(it_l->rt, 1, n, l - it_l->l, it_l->opt, u, v); 	
			int tl = it_l->l, tr = it_l->r, to = it_l->opt; 
			S.erase(it_l); 
			S.insert(node(tl, l - 1, to, u)); 
			it_l = S.insert(node(l, tr, opt, v)).first; 
		}
		if (r != it_r->r)
		{
			split(it_r->rt, 1, n, r - it_r->l + 1, it_r->opt, u, v); 
			int tl = it_r->l, tr = it_r->r, to = it_r->opt; 
			S.erase(it_r); 
			S.insert(node(r + 1, tr, to, v)); 
			it_r = S.insert(node(tl, r, opt, u)).first; 
		}
		
		for (set_it lst = it_l, it = ++it_l; it != S.end() && it->r <= r; lst = it, ++it)
		{
			int u = merge(lst->rt, it->rt, 1, n); 
			int tl = lst->l, tr = it->r; 
			S.erase(lst), S.erase(it);  
			it = S.insert(node(tl, tr, opt, u)).first; 
		}
	}
	
	int q_pos; 
	read(q_pos); 
	set_it it = S.lower_bound(node(q_pos, q_pos, 0, 0)); 
	std::cout << query(it->rt, 1, n, q_pos - it->l + 1, it->opt) << std::endl; 
	return 0; 
}

1.4 可持久化线段树合并

  • 更多的时候,我们需要在线访问某次合并后的线段树,这就需要我们把合并的线段树的每个版本都记录下来。
  • 因为每次合并,我们只需要改动重合结点的信息,所以我们只需要新开这些结点,其他利用合并前的版本即可,根据上面时间复杂度的证明,可以保证空间复杂度的正确。
  • 常见的应用就是直接用可持久化线段树合并来维护出 S A M SAM SAM r i g h t right right 集合。

2. 可持久化线段树

  • 可持久化的思想就是每次新建一个版本时,没有必要全部重建,只需要基于上一个版本,把有变化的结点新建出来,其余结点直接利用上一个版本的。
  • 介绍可持久化线段树的文章很多,这里就不赘述了。这里主要讲一些套路题。

2.1 Codeforces 464E:The Classic Problem

  • 题目来源:CF464E 洛谷链接

题目大意:给定一个点数、边数在 1 0 5 10^5 105 级别的无向图,求 S S S T T T 的最短路,要求对 1 0 9 + 7 10^9+7 109+7 取模,并且输出最短路路径方案。每条边的边权形如 2 x ( 0 ≤ x ≤ 1 0 5 ) 2^x(0\le x\le 10^5) 2x(0x105)

  • 看这道最短路裸题,是不是很简单呀。
  • 我们考虑对于每个点用权值线段树维护该点的最短路长度信息。
  • 然后我们考虑在一个点的最短路上加一条边权,就相当于先把进位的部分全部置为 0 0 0,再把某一位置为 1 1 1。这个涉及的结点数较少,可以用可持久化线段树实现这样的修改操作。
  • 然后我们需要查询某一位开始一共有多少个连续的 1 1 1,可以维护区间极长的 1 1 1 后缀然后 O ( log ⁡ n ) O(\log n) O(logn) 查询,修改只需要把置为 0 0 0 的结点看成删除即可,同样是 O ( log ⁡ n ) O(\log n) O(logn)
  • 对于比较大小,我们对线段树的每个区间维护一个哈希值,比较时,先比较较高位的对应区间哈希值,然后根据是否相同递归比较大小,比较一次是 O ( log ⁡ n ) O(\log n) O(logn) 的。
  • 用小根堆优化 D i j k s t r a Dijkstra Dijkstra,所以时间复杂度 O ( m log ⁡ 2 n ) O(m\log^2n) O(mlog2n)
#include 

inline char nextChar()
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer + buffer_size; 
	
	if (head == tail)
	{
		fread(buffer, 1, buffer_size, stdin); 
		head = buffer; 
	}
	return *head++; 
}

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = nextChar())); 
	x = ch - '0'; 
	while (isdigit(ch = nextChar()))
		x = x * 10 + ch - '0'; 
}

template <class T>
inline void relax(T &x, const T &y)
{
	if (x < y) x = y; 
}

const int MaxNV = 1e5 + 555; 
const int MaxNE = 2e5 + 5; 
const int MaxS = MaxNV * 100; 

const int base = 2; 
const int mod1 = 1e9 + 7; 
const int mod2 = 1e9 + 9; 

#define change(x, y) lc[x] = lc[y], rc[x] = rc[y], len[x] = len[y], sze[x] = sze[y], suf[x] = suf[y], val1[x] = val1[y], val2[x] = val2[y]

struct halfEdge
{
	int v, w; 
	halfEdge *next; 
}adj_pool[MaxNE], *adj[MaxNV], *adj_tail = adj_pool; 

int n, m, src, des, lim, ans;
int p1[MaxNV], p2[MaxNV], rt[MaxNV], pre[MaxNV]; 

int tot, len[MaxS], val1[MaxS], val2[MaxS]; 
int lc[MaxS], rc[MaxS], sze[MaxS], suf[MaxS]; 

bool vis[MaxNV]; 

inline void addEdge(int u, int v, int w)
{
	adj_tail->v = v; 
	adj_tail->w = w; 
	adj_tail->next = adj[u]; 
	adj[u] = adj_tail++; 
}

inline void print(int x, int l, int r)
{
//	printf("prt %d:l = %d, r = %d, sze = %d, suf = %d\n:", x, l, r, sze[x], suf[x]); 
	if (!x || !sze[x]) return; 
	if (l == r) return (void)(ans = (ans + p1[l]) % mod1); 
	int mid = l + r >> 1; 
	print(lc[x], l, mid), print(rc[x], mid + 1, r); 
}

inline void upt(int x, int l, int r)
{
	int lenr = r - (l + r >> 1); 
	
	sze[x] = sze[lc[x]] + sze[rc[x]]; 
	suf[x] = sze[rc[x]] == lenr ? lenr + suf[lc[x]] : suf[rc[x]]; 
	
	val1[x] = (1LL * val1[lc[x]] * p1[lenr] + val1[rc[x]]) % mod1; 
	val2[x] = (1LL * val2[lc[x]] * p2[lenr] + val2[rc[x]]) % mod2; 
}

inline void insert(int lst, int &x, int l, int r, int pos)
{
	x = ++tot; 
	change(x, lst); 
	
	if (l == r)
	{
		sze[x] = suf[x] = val1[x] = val2[x] = 1; 
		return; 
	}
	
	int mid = l + r >> 1; 
	pos <= mid ? insert(lc[lst], lc[x], l, mid, pos) : insert(rc[lst], rc[x], mid + 1, r, pos); 
	upt(x, l, r); 
}

inline void del(int lst, int &x, int l, int r, int u, int v)
{
	if (u <= l && r <= v) return (void)(x = 0); 
	
	x = ++tot; 
	change(x, lst); 
	
	int mid = l + r >> 1; 
	if (u <= mid) del(lc[lst], lc[x], l, mid, u, v); 
	if (v > mid) del(rc[lst], rc[x], mid + 1, r, u, v); 
	
	upt(x, l, r); 
}

inline int query_suf(int x, int l, int r, int pos)
{
//	printf("query:%d %d %d %d %d\n", x, l, r, pos, suf[lc[x]]); 
	if (!x || !sze[x]) return std::max(l, pos); 
	int mid = l + r >> 1; 
	if (pos > mid || suf[lc[x]] >= mid - pos + 1)
		return query_suf(rc[x], mid + 1, r, std::max(pos, mid + 1)); 
	else
		return query_suf(lc[x], l, mid, pos); 
}

inline bool cmp(int x, int y, int l, int r)
{
	if (l == r) return sze[x] > sze[y]; 
	int mid = l + r >> 1; 
	if (val1[rc[x]] == val1[rc[y]] && val2[rc[x]] == val2[rc[y]])
		return cmp(lc[x], lc[y], l, mid); 
	else
		return cmp(rc[x], rc[y], mid + 1, r); 
}

struct node
{
	int pos, u; 
	node(){}
	node(int a, int b):
		pos(a), u(b) {}
	inline bool operator < (const node &rhs) const
	{
		return cmp(u, rhs.u, 0, lim); 
	}
}; 

std::priority_queue<node> heap; 

inline void dfs(int u, int dep)
{
	if (u == src)
	{
		printf("%d\n%d ", dep, u); 
		return; 
	}
	dfs(pre[u], dep + 1); 
	printf("%d ", u); 
}

int main()
{
	read(n), read(m); 
	for (int i = 1; i <= m; ++i)
	{
		int u, v, w; 
		read(u), read(v), read(w); 
		addEdge(u, v, w); 
		addEdge(v, u, w); 
		relax(lim, w); 
	}
	lim += 200; 
	read(src), read(des); 
	
	p1[0] = p2[0] = 1; 
	for (int i = 1; i <= lim; ++i)
		p1[i] = 2LL * p1[i - 1] % mod1, p2[i] = 2LL * p2[i - 1] % mod2; 
	
	rt[src] = tot = 1, len[1] = lim + 1; 
	heap.push(node(src, 1)); 
	while (!heap.empty())
	{
		node now = heap.top(); 
		heap.pop(); 
		
		int u = now.pos; 
		if (now.u != rt[u]) continue; 
		
	//	ans = 0; 
	//	printf("u:::%d-----\n", u); 
	//	print(rt[u], 0, lim); 
	//	printf("ans:%d\n", ans); 
		
		vis[u] = true; 
		if (u == des)
		{
			print(rt[des], 0, lim); 
			printf("%d\n", ans); 
			dfs(des, 1); 
			return 0; 
		}
		
		for (halfEdge *e = adj[u]; e; e = e->next)
			if (!vis[e->v])
			{
				int nxt; 
			//	printf("(%d %d : %d %d\n", u, e->v, e->w, query_suf(rt[u], 0, lim, e->w)); 
				int pos = query_suf(rt[u], 0, lim, e->w); 
				insert(rt[u], nxt, 0, lim, query_suf(rt[u], 0, lim, e->w)); 
				if (e->w < pos)
					del(nxt, nxt, 0, lim, e->w, pos - 1); 
				if (!rt[e->v] || cmp(rt[e->v], nxt, 0, lim))
					heap.push(node(e->v, rt[e->v] = nxt)), pre[e->v] = u; 
			}
	}
	puts("-1"); 
	
	return 0; 
}

2.2 Codeforces 893F:Subtree Minimum Query

  • 题目来源:CF893F 洛谷链接

题目大意:给你一棵 n n n 个点的有根树,点有权值, m m m 次询问,每次问你某个点 u u u 的子树中距离其不超过 k k k 的点的权值的最小值。(边权均为 1 1 1,点权有可能重复, k k k 值每次询问有可能不同,强制在线 n , m ≤ 1 0 5 , a i ≤ 1 0 9 n,m\le 10^5,a_i\le10^9 n,m105,ai109

  • 简单套路题。
  • 先不考虑深度的问题,那么我们只需要查询子树内的最小点权,用 d f s dfs dfs 序在线段树上维护最值即可。
  • 考虑深度,那么实际上就是查询 d e p [ v ] ≤ d e p [ u ] + k dep[v]\le dep[u]+k dep[v]dep[u]+k 并且在 u u u 的子树中的最小点权。
  • 我们可以对每个 d e p ≤ i dep\le i depi 建立一个线段树,因为 i − 1 i-1 i1 i i i 只需要增加几个结点,总的修改数又是 O ( n ) O(n) O(n) 的,直接可持久化即可。
  • 十分简单,时间复杂度 O ( n log ⁡ n ) O(n\log n) O(nlogn)
#include 

inline char nextChar()
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer + buffer_size; 
	
	if (head == tail)
	{
		fread(buffer, 1, buffer_size, stdin); 
		head = buffer; 
	}
	return *head++; 
}

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = nextChar())); 
	x = ch - '0'; 
	while (isdigit(ch = nextChar()))
		x = x * 10 + ch - '0'; 
}

inline void putChar(char ch)
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer; 
	
	if (ch == '\0')
		fwrite(buffer, 1, head - buffer, stdout); 
	
	*head++ = ch; 
	if (head == tail)
		fwrite(buffer, 1, buffer_size, stdout), head = buffer; 
}

template <class T>
inline void putint(T x)
{
	static char buf[15]; 
	static char *tail = buf; 
	
	if (!x) putChar('0'); 
	else
	{
		for (; x; x /= 10) *++tail = x % 10 + '0'; 
		for (; tail != buf; --tail) putChar(*tail); 
	}
}

template <class T>
inline void relax(T &x, const T &y)
{
	if (x < y) x = y; 
}

template <class T>
inline void tense(T &x, const T &y)
{
	if (x > y) x = y; 
}

const int MaxNV = 1e5 + 5; 
const int MaxNE = MaxNV << 1; 
const int MaxS = MaxNV * 40; 

const int INF = 0x7fffffff; 

struct halfEdge
{
	int v; 
	halfEdge *next; 
}adj_pool[MaxNE], *adj_tail = adj_pool, *adj[MaxNV]; 

int n, Q; 
int rt, dfs_clock, max_dep; 

int ldfn[MaxNV], rdfn[MaxNV]; 
int a[MaxNV], dep[MaxNV]; 
std::vector<int> S[MaxNV]; 

int tot, seg[MaxNV]; 
int lc[MaxS], rc[MaxS], val[MaxS]; 

inline void copy(int x, int y)
{
	lc[x] = lc[y], rc[x] = rc[y], val[x] = val[y]; 
}

inline void addEdge(int u, int v)
{
	adj_tail->v = v; 
	adj_tail->next = adj[u]; 
	adj[u] = adj_tail++; 
}

inline void dfs_init(int u, int pre)
{
	ldfn[u] = ++dfs_clock; 
	
	relax(max_dep, dep[u] = dep[pre] + 1); 
	S[dep[u]].push_back(u); 
	
	for (halfEdge *e = adj[u]; e; e = e->next)
		if (e->v != pre)
			dfs_init(e->v, u); 
	
	rdfn[u] = dfs_clock; 
}

inline void insert(int lst, int &x, int l, int r, int pos, int del)
{
	copy(x = ++tot, lst); 
	tense(val[x], del); 
	
	if (l == r) return; 
	int mid = l + r >> 1; 
	if (pos <= mid)
		insert(lc[lst], lc[x], l, mid, pos, del); 
	else
		insert(rc[lst], rc[x], mid + 1, r, pos, del); 
}

inline int query_min(int x, int l, int r, int u, int v)
{
	if (!x) return INF; 
	if (u <= l && r <= v) return val[x]; 
	
	int mid = l + r >> 1, res = INF; 
	if (u <= mid)
		tense(res, query_min(lc[x], l, mid, u, v)); 
	if (v > mid)
		tense(res, query_min(rc[x], mid + 1, r, u, v)); 
	
	return res; 
}

int main()
{
	val[0] = INF; 
	
	read(n), read(rt); 
	for (int i = 1; i <= n; ++i)
		read(a[i]); 
	for (int i = 1; i < n; ++i)
	{
		int u, v; 
		read(u), read(v);
		addEdge(u, v), addEdge(v, u); 
	}
	
	dfs_init(rt, 0); 
	
	for (int i = 1; i <= max_dep; ++i)
	{
		std::vector<int> &T = S[i]; 
		
		int lst = seg[i - 1]; 
		for (int j = 0, jm = T.size(); j < jm; ++j)
		{
			int u = T[j]; 
			insert(lst, lst, 1, n, ldfn[u], a[u]); 
		}
		seg[i] = lst; 
	}
	
	read(Q); 
	
	int last_ans = 0; 
	while (Q--)
	{
		int u, k; 
		read(u), read(k); 
		u = (u + last_ans) % n + 1; 
		k = (k + last_ans) % n; 
		
		int d = std::min(max_dep, dep[u] + k); 
		
		putint(last_ans = query_min(seg[d], 1, n, ldfn[u], rdfn[u])); 
		putChar('\n'); 
	}
	
	putChar('\0'); 
	return 0; 
}

2.3 BZOJ4771:七彩树

  • 题目来源:BZOJ4771

一个 n n n 个结点的有根树,每个点有一个颜色, m m m 次询问,每次询问某一结点 u u u 中,深度不超过某个值 d e p [ u ] + d dep[u]+d dep[u]+d 的所有结点的颜色总数。对于每次询问 u , d u,d u,d 不一定相同。强制在线
n , m ≤ 100000 n,m\le 100000 n,m100000
多组询问, 所有询问的 n , m n,m n,m 之和不超过 500000 500000 500000

  • 这题和上面那题很像。
  • 同样先从没有限制深度的问题先考虑。
  • 首先我们要知道一个关于子树数颜色的小套路:我们知道所有颜色为 c c c 的结点的贡献,即他们到根结点的路径并
  • 这种路径并的贡献,也有一个套路求法,就是我们把颜色相同的结点取出来,先在各自的位置 + 1 +1 +1。接着按照 d f s dfs dfs 序排好,然后在 d f s dfs dfs 序相邻的结点的 l c a lca lca 的位置 − 1 -1 1。询问直接子树求和即可。
  • 至于路径并为啥能 d f s dfs dfs 排序这么弄,证明就不说了。网上应该挺多证明的
  • 然后这样我们直接用线段树维护子树和即可。
  • 对于限制深度的,我们同样考虑按照深度可持久化。
  • 我们对每个颜色维护一个 s e t set set,新插进来的某种颜色的结点 u u u,我们就找到已经插入过的结点中, d f s dfs dfs u u u 的前驱和后继的结点,然后利用 l c a lca lca 在线段树上改就好了。
  • 查询即在某个版本区间求和
  • 时间复杂度仍然是 O ( n log ⁡ n ) O(n\log n) O(nlogn)
#include 

inline char nextChar()
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer + buffer_size; 
	
	if (head == tail)
	{
		fread(buffer, 1, buffer_size, stdin); 
		head = buffer; 
	}
	return *head++; 
}

template <class T>
inline void read(T &x)
{
	static char ch; 
	while (!isdigit(ch = nextChar())); 
	x = ch - '0'; 
	while (isdigit(ch = nextChar()))
		x = x * 10 + ch - '0'; 
}

inline void putChar(char ch)
{
	static const int buffer_size = 2333333; 
	static char buffer[buffer_size]; 
	static const char *tail = buffer + buffer_size; 
	static char *head = buffer; 
	
	if (ch == '\0')
		fwrite(buffer, 1, head - buffer, stdout); 
	
	*head++ = ch; 
	if (head == tail)
		fwrite(buffer, 1, buffer_size, stdout), head = buffer; 
}

template <class T>
inline void putint(T x)
{
	static char buf[22]; 
	static char *tail = buf; 
	if (!x) return (void)(putChar('0')); 
	if (x < 0) x = ~x + 1, putChar('-'); 
	for (; x; x /= 10) *++tail = x % 10 + '0'; 
	for (; tail != buf; --tail) putChar(*tail); 
}

template <class T>
inline void relax(T &x, const T &y)
{
	if (x < y) x = y; 
}

typedef std::set<int>::iterator set_it; 

#define copy(x, y) sum[x] = sum[y], lc[x] = lc[y], rc[x] = rc[y]

const int MaxNV = 1e5 + 5; 
const int MaxNE = MaxNV; 
const int MaxLog = 18; 
const int MaxS = MaxNV * MaxLog * 5; 

struct halfEdge
{
	int v; 
	halfEdge *next; 
}adj_pool[MaxNE], *adj_tail, *adj[MaxNV]; 

int n, m, last_ans, dfs_clock, max_dep = 0; 
int col[MaxNV], ldfn[MaxNV], rdfn[MaxNV], idx[MaxNV], dep[MaxNV]; 

int anc[MaxNV][MaxLog + 1]; 

std::set<int> S[MaxNV]; 
std::vector<int> vec[MaxNV]; 
int tot, rt[MaxNV]; 
int lc[MaxS], rc[MaxS], sum[MaxS]; 

inline void addEdge(int u, int v)
{
	adj_tail->v = v; 
	adj_tail->next = adj[u]; 
	adj[u] = adj_tail++; 
}

inline void init()
{
	for (int i = 1; i <= tot; ++i)
		lc[i] = rc[i] = sum[i] = 0; 
	
	last_ans = max_dep = tot = dfs_clock = 0; 
	adj_tail = adj_pool; 
	for (int i = 1; i <= n; ++i)
	{
		adj[i] = NULL, vec[i].clear(), S[i].clear(); 
		
		rt[i] = 0; 
		for (int j = 0; j <= MaxLog; ++j)
			anc[i][j] = 0; 
	}
}

inline int query_lca(int u, int v)
{
	if (dep[u] < dep[v]) std::swap(u, v); 
	for (int d = dep[u] - dep[v], i = 0; d; d >>= 1, ++i)
		if (d & 1)
			u = anc[u][i]; 
	if (u == v) return u; 
	for (int i = MaxLog; i >= 0; --i)
		if (anc[u][i] != anc[v][i])
		{
			u = anc[u][i]; 
			v = anc[v][i]; 
		}
	return anc[u][0]; 
}

inline void dfs_init(int u)
{
	idx[ldfn[u] = ++dfs_clock] = u;  
	vec[dep[u] = dep[anc[u][0]] + 1].push_back(u); 
	
	relax(max_dep, dep[u]); 
	
	for (int i = 0; anc[u][i]; ++i)
		anc[u][i + 1] = anc[anc[u][i]][i]; 
	
	for (halfEdge *e = adj[u]; e; e = e->next)
		dfs_init(e->v); 
	rdfn[u] = dfs_clock; 
}

inline void insert(int lst, int &x, int l, int r, int pos, int opt)
{
	x = ++tot, copy(x, lst); 
	sum[x] += opt; 
	
	if (l == r) return; 
	
	int mid = l + r >> 1; 
	if (pos <= mid)
		insert(lc[lst], lc[x], l, mid, pos, opt); 
	else
		insert(rc[lst], rc[x], mid + 1, r, pos, opt); 
}

inline int query_sum(int x, int l, int r, int u, int v)
{
	if (!x) return 0; 
	if (u <= l && r <= v) return sum[x]; 
	int mid = l + r >> 1, res = 0; 
	if (u <= mid)
		res += query_sum(lc[x], l, mid, u, v); 
	if (v > mid)
		res += query_sum(rc[x], mid + 1, r, u, v); 
	return res; 
}

int main()
{
	int T; 
	read(T); 
	while (T--)
	{
		read(n), read(m), init(); 
		for (int i = 1; i <= n; ++i)
			read(col[i]); 
		for (int i = 2; i <= n; ++i)
		{
			read(anc[i][0]); 
			addEdge(anc[i][0], i); 
		}
		
		dfs_init(1); 
		
		for (int i = 1; i <= max_dep; ++i)
		{
			std::vector<int> &T = vec[i]; 
			
			int lst = rt[i - 1]; 
			for (int j = 0, jm = T.size(); j < jm; ++j)
			{
				int v = T[j]; 
				
				set_it it1 = S[col[v]].insert(ldfn[v]).first, it2 = it1; ++it2; 
				insert(lst, lst, 1, n, ldfn[v], 1); 
				
				bool has_pre = it1 != S[col[v]].begin(); 
				bool has_suf = it2 != S[col[v]].end(); 
				
				set_it tmp = it1; 
				
				if (has_pre)
					insert(lst, lst, 1, n, ldfn[query_lca(idx[*--it1], v)], -1); 
				if (has_suf)
					insert(lst, lst, 1, n, ldfn[query_lca(idx[*it2], v)], -1); 
				if (has_pre && has_suf)
					insert(lst, lst, 1, n, ldfn[query_lca(idx[*it1], idx[*it2])], 1); 
			}
			
			rt[i] = lst; 
		}
		
		for (int i = 1; i <= m; ++i)
		{
			int u, d; 
			read(u), read(d); 
			u ^= last_ans, d ^= last_ans; 
			
			d = std::min(max_dep, d + dep[u]); 
			
			putint(last_ans = query_sum(rt[d], 1, n, ldfn[u], rdfn[u])); 
			putChar('\n'); 
		}
	}
	
	putChar('\0'); 
	return 0; 
} 

你可能感兴趣的:(学习笔记)