树套树 (线段树+splay)

树套树,就是线段树、平衡树、树状数组等数据结构的嵌套。

最简单的是线段树套set,可以解决一些比较简单的问题,而且代码根线段树是一样的只是一些细节不太一样。

本题中用的是线段树套splay,代码较长。

树套树中的splay和单一的splay原理是一样的,只不过是建了很多的splay树,因为不止一个,所以跟板子不同的是,大部分函数都要传splay的根节点规定起点。

而线段树中存储的就是每个区间对应的splay的root节点。

只要线段树和splay板子都懂了,这一题就很好理解。

树套树 (线段树+splay)_第1张图片

const int mod = 1e9 + 7, INF = 2147483647;
const int N = 1e7+ 10;
int n, m;
struct Node {
	int s[2], p, v; // 左右儿子、父节点、值
	int size, cnt; // 子树大小、懒标记
	void init(int _v, int _p) { // 初始化函数
		v = _v, p = _p;
		cnt = size =  1;
	}
} tr[N];
int L[N], R[N], T[N], idx;
int w[N];

void pushup(int u) { // 向上更新传递,与线段树一样
	tr[u].size = tr[tr[u].s[0]].size + tr[tr[u].s[1]].size + tr[u].cnt;
}

void rotate(int x) { // 核心函数
	int y = tr[x].p, z = tr[y].p;
	int k = tr[y].s[1] == x;
	tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
	tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
	tr[x].s[k ^ 1] = y, tr[y].p = x;
	pushup(y), pushup(x);
}

void splay(int& root, int x, int k) { // 将x节点旋转到k节点下
	while(tr[x].p != k) { //
		int y = tr[x].p; // x节点的父节点
		int z = tr[y].p; // x节点的父节点的父节点
		if(z != k) // 向上旋转
			if((tr[y].s[1] == x) != (tr[z].s[1] == y)) rotate(x); // 转一次x
			else rotate(y); // 转一次y
		rotate(x); // 转一次x
	}
	if(!k) root = x; // 更新root节点
}

void upper(int& root, int v) { // 将v值节点转到根节点
	int u = root; // 根节点
	while(tr[u].s[v > tr[u].v] && tr[u].v != v) // 存在则找到v值节点,不存在则找到v值节点的前驱或者后继节点
		u = tr[u].s[v > tr[u].v]; // 向下寻找
	splay(root, u, 0); // 将u节点旋转到跟节点
}

int get_prev(int& root, int v) { // 获取v值的前驱节点,严格小于v的最大节点
	upper(root, v); // 将v值节点转到根节点
	if(tr[root].v < v) return root; // 若是该值在树中不存在,根节点就是v的前驱或者后继节点
	int u = tr[root].s[0]; // 前驱节点在左子树的最右边
	while(tr[u].s[1]) u = tr[u].s[1]; // 找到最右边的一个节点
	return u;
}

int get_next(int& root, int v) { // 获取某值的后继节点,严格大于v的最小节点
	upper(root, v); // 将v值节点转到根节点
	if(tr[root].v > v) return root; // 若是该值在树中不存在,根节点就是v的前驱或者后继节点
	int u = tr[root].s[1]; // 后继节点在右子树的最左边
	while(tr[u].s[0]) u = tr[u].s[0]; // 找到最左的节点,就是最小的节点
	return u; // 返回节点
}

void insert(int& root, int v) { // 在二叉树中插入一个值
	int u = root, p = 0; // p维护为当前节点的父节点
	while(u && tr[u].v != v) // 没找到则一直向下寻找
		p = u, u = tr[u].s[v > tr[u].v]; // 更新父节点,更新当前节点
	if(u) tr[u].cnt ++; // v值的节点已经存在则直接加一即可
	else { // 不存在则创建节点
		u = ++ idx; // 分配节点序号
		if(p) tr[p].s[v > tr[p].v] = u; // 将父节点也就是前驱节点指向当前节点
		tr[u].init(v, p); // 初始化当前节点的值、父节点信息
	}
	splay(root, u, 0); // 将u节点旋转到根节点下
}

int get_k(int root, int v) { // 获得树中有多少比v小的数
	int u = root, res = 0;
	while(u) {
		if(tr[u].v < v) res += tr[tr[u].s[0]].size + tr[u].cnt, u = tr[u].s[1];
		else u = tr[u].s[0];
	}
	return res;
}

void remove(int& root, int v) { // 删除一个值为v的节点
	int prev = get_prev(root, v), nex = get_next(root, v); // 获取该节点的前驱以及后继节点。
	splay(root, prev, 0), splay(root, nex, prev); // 将前继节点旋转到根节点,将后继节点旋转到前驱节点下面也就是根节点下面
	int w = tr[nex].s[0]; // 后继节点的左子树就是v的节点
	if(tr[w].cnt > 1) tr[w].cnt --, splay(root, w, 0); // 该节点的v不止存在一个,减一,w节点旋转到根节点
	else tr[nex].s[0] = 0, splay(root, nex, 0); // 唯一,那么直接把后继节点的左子树指向空也就是0即可
}

void update(int& root, int x, int y) { // 将一个x值点改为y值
	remove(root, x); // 先删除
	insert(root, y); // 再插入
}

void build(int u, int l, int r) {
	L[u] = l, R[u] = r; // 存储某个节点的左右边界
	insert(T[u], -INF), insert(T[u], INF); // 插入哨兵
	for(int i = l; i <= r; i ++) insert(T[u], w[i]); // 初始化线段树每个节点的平衡树
	if(l == r) return ;
	int mid = l + r >> 1;
	build(u << 1, l, mid); // 建左子树
	build(u << 1 | 1, mid + 1, r); // 建右子树
}


int query(int u, int a, int b, int x) { // 查询区间a,b之间有多少比x值小的数
	if(a <= L[u] && R[u] <= b)  return get_k(T[u], x) - 1;
	int mid = L[u] + R[u] >> 1, res = 0;
	if(a <= mid) res += query(u << 1, a, b, x); // 查询左子树中有多少是该区间并且小于x的数
	if(mid < b) res += query(u << 1 | 1, a, b, x); // 查询右子树中有多少是该区间并且小于x的数
	return res;
}

void change(int u, int p, int x) { // 将线段树中p位置数值改为x
	update(T[u], w[p], x); // 修改当前节点中平衡树中的值
	if(L[u] == R[u]) return ;
	int mid = L[u] + R[u] >> 1;
	if(p <= mid) change(u << 1, p, x); // 修改左子树
	else change(u << 1 | 1, p, x); // 修改右子树
}

int query_prev(int u, int a, int b, int x) { // 查询再该区间中x的前驱节点
	if(a <= L[u] && R[u] <= b) return tr[get_prev(T[u], x)].v; // 该函数为查找当前子树中x的前驱节点
	int mid = L[u] + R[u] >> 1, res = -INF;
	if(a <= mid) res = max(res, query_prev(u << 1, a, b, x)); // 递归左子树
	if(mid < b) res = max(res, query_prev(u << 1 | 1, a, b, x)); // 递归右子树
	return res; // 返回左右子树中的最大值
}

int query_next(int u, int a, int b, int x) { // 查询再该区间中x的后继节点
	if(a <= L[u] && R[u] <= b)  return tr[get_next(T[u], x)].v; // 该函数为查找当前子树中x的后继节点
	int mid = L[u] + R[u] >> 1, res = INF;
	if(a <= mid) res = min(res, query_next(u << 1, a, b, x));
	if(mid < b) res = min(res, query_next(u << 1 | 1, a, b, x));
	return res; // 返回左右子树中的最小值
}

int get_rank_to_tr(int a, int b, int x) { // 查找区间内排名第x的数 
	int l = 0, r = 1e8;
	while(l < r) { // 通过二分获得答案,因为只能判断某个数在区间内的排名。 
		int mid = l + r + 1 >> 1;
		if(query(1, a, b, mid) + 1 <= x) l = mid; // 
		else r = mid - 1;
	}
	return r;
}

inline void sovle() {
	cin >> n >> m;
	for(int i = 1; i <= n; i ++)cin >> w[i];
	build(1, 1, n);
	while(m --) {
		int op, a, b, x;
		cin >> op >> a >> b;
		if(op != 3) cin >> x;
		if(op == 1) cout << query(1, a, b, x) + 1 << endl;
		if(op == 2) cout << get_rank_to_tr(a, b, x) << endl;
		if(op == 3) {
			change(1, a, b);
			w[a] = b;
		}
		if(op == 4) cout << query_prev(1, a, b, x) << endl;
		if(op == 5) cout << query_next(1, a, b, x) << endl;
	}

}

你可能感兴趣的:(算法,数据结构,图论)