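// 树套树 (tree of trees: a segment tree whose nodes hold balanced BSTs) — Luogu P3380
//
// 开O2过了: accepted once O2 optimization was enabled.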

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;


/// Reads the next decimal integer from stdin into `x` and returns it.
/// Non-digit characters before the number are skipped; a '-' among them makes
/// the result negative. If end of input is reached before any digit, `x` is
/// set to 0 and 0 is returned (instead of spinning forever on EOF).
inline int read(int& x) {
	int ch = getchar(); // int, not char, so EOF is distinguishable from data
	int sign = 1;
	x = 0;
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') sign = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	// Apply the sign to the out-parameter too; callers use `x`, not the return value.
	x *= sign;
	return x;
}
//void ReadFile() {
//	FILE* stream1;
//	freopen_s(&stream1,"in.txt", "r", stdin);
//	freopen_s(&stream1,"out.txt", "w", stdout);
//}

// Runs once during static initialization: stops syncing C++ streams with C
// stdio and unties cin/cout. NOTE(review): the solver itself reads with
// getchar() and writes with printf(), which this does not affect; with sync
// disabled, mixing std::cout (AVL::print) and printf may reorder output.
static decltype(nullptr) untie_streams() {
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	std::cout.tie(nullptr);
	return nullptr;
}
static auto speedup = untie_streams();


/// AVL (height-balanced) binary search tree used as an order-statistics multiset.
///
/// Each distinct value lives in one node with a copy counter `c`; `sub` is the
/// number of elements in the node's subtree counting duplicates, which is what
/// findKth / findRank descend on. The tree owns its nodes (allocated in
/// GetNode, freed in _remove / clear), so copying is disabled.
template <class _Type>
class AVL
{
public:
	template <class _T>
	struct AVLNode {
		AVLNode() :data(), l(nullptr), r(nullptr), h(1), c(1), sub(1) {}
		_T data;
		AVLNode* l, * r;
		int h, c, sub; // height, number of copies of `data`, subtree element count (duplicates included)
	};
	typedef AVLNode<_Type>  value_Type;
	typedef AVLNode<_Type>* point;
public:
	AVL() :m_root(nullptr), m_new(nullptr), m_tot(0) {}
	~AVL() { clear(m_root); }
	// Nodes are raw owning pointers; a member-wise copy would alias (and double-free) them.
	AVL(const AVL&) = delete;
	AVL& operator=(const AVL&) = delete;
public:
	// Allocates a node holding `input`, records it in m_new and counts it in m_tot.
	typename AVL<_Type>::point GetNode(_Type&& input);

	// Height of a subtree, 0 when empty (name kept as-is for compatibility).
	int  GetHight(typename AVL<_Type>::point node) {
		if (node == nullptr) return 0;
		return node->h;
	}
	// Element count of a subtree including duplicates, 0 when empty.
	int  GetSubCnt(typename AVL<_Type>::point node) {
		if (node == nullptr) return 0;
		return node->sub;
	}
	// Recomputes `h` and `sub` of `node` from its children and its own copy count.
	void flush(typename AVL<_Type>::point node) {
		if (node == nullptr) return;
		node->h = std::max(GetHight(node->l), GetHight(node->r)) + 1;
		node->sub = GetSubCnt(node->l) + GetSubCnt(node->r) + node->c;
	}
public:
	// Inserts one copy; returns the newly created node, or nullptr when the value already existed.
	typename AVL<_Type>::point insert(_Type input);
	typename AVL<_Type>::point _insert(typename AVL<_Type>::point cur, _Type&& input);

	// Removes one copy of `input`; does nothing when it is absent.
	void remove(_Type input);
	typename AVL<_Type>::point _remove(typename AVL<_Type>::point cur, _Type&& input);

	typename AVL<_Type>::point findKth(int k);    // node holding the k-th smallest element (1-based), or nullptr
	int findRank(_Type x);                         // 1 + number of stored elements < x
	typename AVL<_Type>::point findPre(_Type x);  // node with the largest value < x, or nullptr
	typename AVL<_Type>::point findNext(_Type x); // node with the smallest value > x, or nullptr
	typename AVL<_Type>::point find(_Type x);     // node holding x, or nullptr

public:
	void  rotate_left(typename AVL<_Type>::point& input);  // left rotation; `input` is updated to the new subtree root
	void  rotate_right(typename AVL<_Type>::point& input); // right rotation; `input` is updated to the new subtree root
	typename AVL<_Type>::point revise(typename AVL<_Type>::point& node); // rebalance `node` and refresh its fields
	void clear(typename AVL<_Type>::point& input);         // frees a subtree and nulls `input`
	// In-order walk: appends every element (with duplicates) to m_data and prints each distinct value to stdout.
	void print(typename AVL<_Type>::point& input) {
		if (input == nullptr) return;
		print(input->l);
		m_data.insert(m_data.end(), input->c, input->data);
		std::cout << input->data << " ";
		print(input->r);
	}
public:
	typename AVL<_Type>::point   m_root;
	typename AVL<_Type>::point   m_new; // node created by the most recent insert(), nullptr if none
	int		m_tot; // number of nodes, i.e. distinct values currently stored
	std::vector<_Type> m_data; // output buffer filled by print()
};

template <class _Type>
typename AVL<_Type>::point AVL<_Type>::GetNode(_Type&& input) {
	// `new` throws on failure instead of returning nullptr, so no null check is needed.
	point ret = new value_Type();
	ret->data = input;
	m_tot++;
	return m_new = ret;
}

template <class _Type>
void AVL<_Type>::clear(typename AVL<_Type>::point& input) {
	if (input == nullptr) return;
	clear(input->l);
	clear(input->r);
	if (m_new == input) m_new = nullptr; // do not leave m_new dangling
	delete input;                        // nodes come from `new` in GetNode
	m_tot--;
	input = nullptr;
}

template <class _Type>
typename AVL<_Type>::point AVL<_Type>::insert(_Type input)
{
	// GetNode sets m_new; it stays nullptr when only a copy counter is bumped.
	m_new = nullptr;
	m_root = _insert(m_root, std::forward<_Type>(input));
	return m_new;
}

template <class _Type>
typename AVL<_Type>::point AVL<_Type>::_insert(typename AVL<_Type>::point p, _Type&& input)
{
	// Empty slot reached: the value is new.
	if (p == nullptr) return GetNode(std::forward<_Type>(input));

	// Existing value: count one more copy; the shape is unchanged, so no rebalance is needed here.
	if (p->data == input) {
		p->c++;
		p->sub++;
		return p;
	}
	if (p->data > input) p->l = _insert(p->l, std::forward<_Type>(input));
	else p->r = _insert(p->r, std::forward<_Type>(input));

	return revise(p);
}

template <class _Type>
typename AVL<_Type>::point AVL<_Type>::revise(typename AVL<_Type>::point& p) {
	if (p == nullptr) return p;
	int hl = GetHight(p->l), hr = GetHight(p->r), diff = hl - hr;
	if (diff > 1) {
		// Left-heavy. LL (including equal grandchild heights, possible after a delete): single right rotation.
		if (GetHight(p->l->l) >= GetHight(p->l->r)) {
			rotate_right(p);
		}
		else {
			// LR
			rotate_left(p->l);
			rotate_right(p);
		}
	}
	else if (diff < -1) {
		// Right-heavy. Only a strictly taller inner grandchild needs the double (RL) rotation;
		// with equal heights a double rotation can leave the old right child unbalanced.
		if (GetHight(p->r->l) > GetHight(p->r->r)) {
			rotate_right(p->r);
			rotate_left(p);
		}
		else rotate_left(p); // RR
	}
	flush(p);
	return p;
}

template <class _Type>
void AVL<_Type>::remove(_Type input)
{
	m_root = _remove(m_root, std::forward<_Type>(input));
}

template <class _Type>
typename AVL<_Type>::point AVL<_Type>::_remove(typename AVL<_Type>::point p, _Type&& input)
{
	if (p == nullptr) return p; // value not present
	if (p->data > input) {
		p->l = _remove(p->l, std::forward<_Type>(input));
	}
	else if (p->data < input) {
		p->r = _remove(p->r, std::forward<_Type>(input));
	}
	else if (p->c > 1) {
		// Several copies: drop one. The node stays, so m_tot (a node count) is unchanged.
		p->c--;
	}
	else if (p->l && p->r) {
		// Last copy, two children: take over the in-order predecessor's value and count,
		// then unlink the predecessor node (now holding `input`) from the left subtree.
		point pre = p->l;
		while (pre->r) pre = pre->r;
		std::swap(p->data, pre->data);
		p->c = pre->c;
		pre->c = 1; // exactly one copy, so the recursive call deletes that node
		p->l = _remove(p->l, std::forward<_Type>(input));
	}
	else {
		// Last copy, at most one child: splice the node out.
		point child = p->l ? p->l : p->r;
		if (m_new == p) m_new = nullptr;
		delete p;
		m_tot--;
		return child;
	}
	return revise(p);
}

template <class _Type>
void  AVL<_Type>::rotate_left(typename AVL<_Type>::point& node)
{
	if (node == nullptr || node->r == nullptr) return;
	point r = node->r;
	point rl = r->l;
	r->l = node;
	node->r = rl;
	node = r;

	// Child first: the new root's fields depend on it.
	flush(node->l);
	flush(node);
}

template <class _Type>
void  AVL<_Type>::rotate_right(typename AVL<_Type>::point& node)
{
	if (node == nullptr || node->l == nullptr) return;
	point l = node->l;
	point lr = l->r;
	l->r = node;
	node->l = lr;
	node = l;

	flush(node->r);
	flush(node);
}

template <class _Type>
typename AVL<_Type>::point AVL<_Type>::findKth(int k)
{
	// Ranks count duplicates, so bound k by the root's element count; m_tot counts
	// distinct nodes and is not comparable to k.
	if (k < 1 || k > GetSubCnt(m_root)) return nullptr;
	point cur = m_root;
	while (cur) {
		int left = GetSubCnt(cur->l);
		if (k <= left) {
			cur = cur->l;
		}
		else if (k <= left + cur->c) {
			return cur;
		}
		else {
			k -= left + cur->c;
			cur = cur->r;
		}
	}
	return nullptr;
}

template <class _Type>
int AVL<_Type>::findRank(_Type x) {
	point cur = m_root;

	int ans = 0; // elements < x seen so far
	while (cur) {
		if (cur->data > x) {
			cur = cur->l;
		}
		else if (cur->data < x) {
			ans += GetSubCnt(cur->l) + cur->c;
			cur = cur->r;
		}
		else break;
	}
	if (cur == nullptr) return ans + 1;
	return ans + GetSubCnt(cur->l) + 1;
}

template <class _Type>
typename AVL<_Type>::point AVL<_Type>::findPre(_Type x)
{
	point cur = m_root, ret = nullptr;
	while (cur) {
		if (cur->data < x) ret = cur, cur = cur->r;
		else cur = cur->l;
	}
	return ret;
}

template <class _Type>
typename AVL<_Type>::point AVL<_Type>::findNext(_Type x)
{
	point cur = m_root, ret = nullptr;
	while (cur) {
		if (cur->data > x) ret = cur, cur = cur->l;
		else cur = cur->r;
	}
	return ret;
}

template <class _Type>
typename AVL<_Type>::point AVL<_Type>::find(_Type x)
{
	point cur = m_root;
	while (cur) {
		if (cur->data == x) return cur;
		else if (cur->data > x) cur = cur->l;
		else cur = cur->r;
	}
	return nullptr;
}



/// Value-keyed splay tree used as an order-statistics multiset.
///
/// Each distinct value lives in one node with a copy counter `c`; `sub` is the
/// subtree element count including duplicates. Nodes keep parent pointers, and
/// most lookups splay the node they return to the root. The tree owns its
/// nodes (allocated in GetNode, freed in remove / clear), so copying is disabled.
template <typename _Type>
class Splay {
private:
	template <typename _T>
	struct SplayNode {
		SplayNode() :data(), l(nullptr), r(nullptr), fa(nullptr), c(1), sub(1) {}
		_T data;
		SplayNode* l, * r, * fa;
		int c, sub; // number of copies of `data`, subtree element count (duplicates included)
	};
public:
	typedef SplayNode<_Type>  value_Type;
	typedef SplayNode<_Type>* point;
public:
	Splay() :m_root(nullptr), m_tot(0) {}
	~Splay() { clear(m_root); }
	// Nodes are raw owning pointers; a member-wise copy would alias (and double-free) them.
	Splay(const Splay&) = delete;
	Splay& operator=(const Splay&) = delete;
	void clear(typename Splay<_Type>::point& input); // frees a subtree and nulls `input`
public:
	// Element count of a subtree including duplicates, 0 when empty.
	int  GetSubCnt(typename Splay<_Type>::point node) {
		if (node == nullptr) return 0;
		return node->sub;
	}
	// Recomputes `sub` of `node` from its children and its own copy count.
	void flush(typename Splay<_Type>::point node) {
		if (node == nullptr) return;
		node->sub = GetSubCnt(node->l) + GetSubCnt(node->r) + node->c;
	}
	typename Splay<_Type>::point GetNode(_Type&& input);
public:
	typename Splay<_Type>::point insert(_Type input); // insert one copy; returns its node (splayed to the root)
	void remove(_Type input);                          // remove one copy; no-op if absent
	typename Splay<_Type>::point find(_Type x);       // node holding x (splayed to the root), or nullptr
	// Rotates `input` upward until its parent is `target` (nullptr = make it the root).
	typename Splay<_Type>::point splay(typename  Splay<_Type>::point& input, typename  Splay<_Type>::point target = nullptr);
	typename Splay<_Type>::point findKth(int k);      // node holding the k-th smallest element (1-based), or nullptr
	int findRank(_Type x);                             // 1 + number of stored elements < x
	typename Splay<_Type>::point findPre(_Type x);    // node with the largest value < x, or nullptr
	typename Splay<_Type>::point findNext(_Type x);   // node with the smallest value > x, or nullptr
	void reverse(int l, int r);                        // incomplete, see definition
public:
	void  rotate_left(typename  Splay<_Type>::point& input);  // left rotation; `input` becomes the new subtree root
	void  rotate_right(typename Splay<_Type>::point& input);  // right rotation; `input` becomes the new subtree root
private:
	typename Splay<_Type>::point   m_root;
	int m_tot; // number of stored elements, duplicates included
};

template <class _Type>
typename Splay<_Type>::point Splay<_Type>::GetNode(_Type&& input) {
	// `new` throws on failure instead of returning nullptr, so no null check is needed.
	point ret = new value_Type();
	ret->data = input;
	m_tot++;
	return ret;
}

template <class _Type>
void Splay<_Type>::clear(typename Splay<_Type>::point& input) {
	// Iterative teardown: a splay tree can degenerate into an O(n)-long chain,
	// so recursion could overflow the stack. Right-rotate away every left child;
	// a node without a left child is freed and the walk continues to its right.
	point cur = input;
	while (cur != nullptr) {
		if (cur->l != nullptr) {
			point left = cur->l;
			cur->l = left->r;
			left->r = cur;
			cur = left;
		}
		else {
			point next = cur->r;
			m_tot -= cur->c;
			delete cur;
			cur = next;
		}
	}
	input = nullptr;
}

template <class _Type>
void Splay<_Type>::reverse(int l, int r)
{
	// NOTE(review): incomplete. This only performs the usual isolation step for
	// the rank range [l, r] (assuming a sentinel element at rank 1): rank l is
	// splayed to the root and rank r + 2 directly beneath it. No reversal is
	// applied — nodes carry no lazy-reverse flag, and swapping children would
	// break the value ordering the other operations rely on.
	point lo = findKth(l);
	point hi = findKth(r + 2);
	if (lo == nullptr || hi == nullptr || lo == hi) return;
	splay(lo);
	splay(hi, lo);
}

template <class _Type>
typename Splay<_Type>::point Splay<_Type>::findKth(int k)
{
	// Ranks count duplicates; the root's `sub` is the total element count.
	if (k < 1 || k > GetSubCnt(m_root)) return nullptr;
	point cur = m_root;
	while (cur) {
		int left = GetSubCnt(cur->l);
		if (k <= left) {
			cur = cur->l;
		}
		else if (k <= left + cur->c) {
			return splay(cur);
		}
		else {
			k -= left + cur->c;
			cur = cur->r;
		}
	}
	return nullptr;
}

template <class _Type>
int Splay<_Type>::findRank(_Type x) {
	point cur = m_root;

	int ans = 0; // elements < x seen so far
	while (cur) {
		if (cur->data > x) {
			cur = cur->l;
		}
		else if (cur->data < x) {
			ans += GetSubCnt(cur->l) + cur->c;
			cur = cur->r;
		}
		else break;
	}
	if (cur == nullptr) return ans + 1;

	// Read the left subtree size before splaying changes the shape.
	ans += GetSubCnt(cur->l) + 1;
	splay(cur);
	return ans;
}

template <class _Type>
typename Splay<_Type>::point Splay<_Type>::findPre(_Type x)
{
	point cur = m_root, ret = nullptr;
	while (cur) {
		if (cur->data < x) ret = cur, cur = cur->r;
		else cur = cur->l;
	}
	return splay(ret);
}

template <class _Type>
typename Splay<_Type>::point Splay<_Type>::findNext(_Type x)
{
	point cur = m_root, ret = nullptr;
	while (cur) {
		if (cur->data > x) ret = cur, cur = cur->l;
		else cur = cur->r;
	}
	return splay(ret);
}

template <class _Type>
typename Splay<_Type>::point Splay<_Type>::find(_Type x)
{
	point cur = m_root;
	while (cur) {
		if (cur->data == x) {
			break;
		}
		else if (cur->data > x) {
			cur = cur->l;
		}
		else {
			cur = cur->r;
		}
	}
	return splay(cur);
}

template <class _Type>
typename Splay<_Type>::point Splay<_Type>::insert(_Type input)
{
	if (m_root == nullptr) return m_root = GetNode(std::forward<_Type>(input));
	point cur = m_root;
	while (cur) {
		if (cur->data == input) {
			// Existing value: one more copy. Update `sub` directly, because when
			// cur is already the root splay() performs no rotation and never flushes it.
			cur->c++;
			cur->sub++;
			m_tot++;
			break;
		}
		else if (cur->data > input) {
			if (cur->l) cur = cur->l;
			else {
				cur->l = GetNode(std::forward<_Type>(input));
				cur->l->fa = cur;
				cur = cur->l;
				break;
			}
		}
		else {
			if (cur->r) cur = cur->r;
			else {
				cur->r = GetNode(std::forward<_Type>(input));
				cur->r->fa = cur;
				cur = cur->r;
				break;
			}
		}
	}
	// Splaying flushes every ancestor on the path, fixing their stale counts.
	return splay(cur);
}

template <class _Type>
void Splay<_Type>::remove(_Type input)
{
	point p = find(input); // splays the matching node to the root
	if (p == nullptr) return;
	m_tot--;
	if (--p->c > 0) {
		// Other copies remain; p is the root, so refreshing its own count suffices.
		flush(p);
		return;
	}

	if (p->l == nullptr || p->r == nullptr) {
		m_root = p->l ? p->l : p->r;
		if (m_root) m_root->fa = nullptr;
	}
	else {
		// Split into two subtrees, splay the maximum of the left one to its root
		// (it then has no right child) and hang the right subtree there.
		point l = p->l;
		point r = p->r;
		l->fa = r->fa = nullptr;

		m_root = l;
		l = findPre(p->data);
		l->r = r;
		r->fa = l;
	}
	delete p;
	flush(m_root);
}

template <class _Type>
typename Splay<_Type>::point Splay<_Type>::splay(typename  Splay<_Type>::point& input, typename  Splay<_Type>::point target)
{
	if (input == nullptr) return nullptr;

	while (input->fa != target) {
		point fa = input->fa, ffa = fa->fa;
		bool isLeft = fa->l == input;
		if (ffa == target) {
			// zig: the parent is directly below the target
			if (isLeft) rotate_right(fa);
			else rotate_left(fa);
		}
		else {
			bool faIsLeft = ffa->l == fa;
			if (faIsLeft && isLeft) {
				// zig-zig (LL): rotate the grandparent first
				rotate_right(ffa);
				rotate_right(fa);
			}
			else if (!faIsLeft && !isLeft) {
				// zig-zig (RR)
				rotate_left(ffa);
				rotate_left(fa);
			}
			else if (faIsLeft && !isLeft) {
				// zig-zag (LR)
				rotate_left(fa);
				rotate_right(ffa);
			}
			else {
				// zig-zag (RL)
				rotate_right(fa);
				rotate_left(ffa);
			}
		}
	}
	if (target == nullptr) m_root = input;
	return input;
}

template <class _Type>
void  Splay<_Type>::rotate_left(typename Splay<_Type>::point& node)
{
	if (node == nullptr || node->r == nullptr) return;
	point r = node->r;
	point rl = r->l;

	// Re-attach r to node's former parent.
	if (node->fa) {
		if (node->fa->l == node) node->fa->l = r;
		else node->fa->r = r;
	}
	r->fa = node->fa;
	node->fa = r;

	r->l = node;
	node->r = rl;
	if (rl) rl->fa = node;

	node = r;

	flush(node->l);
	flush(node);
}

template <class _Type>
void  Splay<_Type>::rotate_right(typename Splay<_Type>::point& node)
{
	if (node == nullptr || node->l == nullptr) return;
	point l = node->l;
	point lr = l->r;

	// Re-attach l to node's former parent.
	if (node->fa) {
		if (node->fa->l == node) node->fa->l = l;
		else node->fa->r = l;
	}
	l->fa = node->fa;
	node->fa = l;

	l->r = node;
	node->l = lr;
	if (lr) lr->fa = node;

	node = l;
	flush(node->r);
	flush(node);
}


// Array capacity (positions are 1-based; sized for n up to about 5e4) and the
// sentinel (INT_MAX) printed when a predecessor / successor does not exist.
const int maxn = 50000 + 7;
const int INF = 2147483647;
int arr[maxn]; // arr[1..n]: current value at each position
int n, m;      // array length, number of operations

/// Segment tree over positions 1..n of the global `arr`. Node `idx` covering
/// [l, r] owns an AVL multiset `root[idx]` holding arr[l..r], so each range
/// query combines per-node AVL queries over the O(log n) covering nodes.
class segment {
public:
	typedef long long int ll;
	segment() {}

	// Fills node `idx` (covering [l, r]) with arr[l..r], then builds both halves.
	void build(int l, int r, int idx) {
		for (int i = l; i <= r; i++) {
			root[idx].insert(arr[i]);
		}
		if (l == r) return;
		int mid = (l + r) >> 1;
		int lson = idx * 2, rson = lson + 1;
		build(l, mid, lson);
		build(mid + 1, r, rson);
	}

	// Replaces one copy of arr[x] by y in every node of [l, r] that intersects [cl, cr].
	void _modify(int l, int r, int idx, int cl, int cr, int x, int y) {
		if (l > cr || r < cl) return;
		root[idx].remove(arr[x]);
		root[idx].insert(y);
		if (l == r) return;

		int mid = (l + r) >> 1, lson = idx * 2, rson = lson + 1;

		if (mid >= cl) _modify(l, mid, lson, cl, cr, x, y);
		if (mid < cr) _modify(mid + 1, r, rson, cl, cr, x, y);
	}
	// Point update: arr[x] = y.
	void modify(int x, int y) {
		_modify(1, n, 1, x, x, x, y);
		arr[x] = y;
	}

	// Number of elements strictly smaller than k within arr[cl..cr].
	int queryrank(int p, int l, int r, int cl, int cr, int k)
	{
		if (l > cr || r < cl) return 0;
		if (l >= cl && r <= cr) {
			return root[p].findRank(k) - 1;
		}
		int mid = (l + r) >> 1;
		return queryrank(p * 2, l, mid, cl, cr, k) + queryrank(p * 2 + 1, mid + 1, r, cl, cr, k);
	}

	// k-th smallest value in arr[u..v]: the largest `mid` in [0, 1e8] with fewer than
	// k elements < mid. NOTE(review): assumes every value lies in [0, 1e8] — confirm
	// against the input constraints.
	int querynum(int u, int v, int k)
	{
		int l = 0, r = 1e8;
		while (l < r)
		{
			int mid = (l + r + 1) / 2;
			int c = queryrank(1, 1, n, u, v, mid);
			if (c < k)
				l = mid;
			else
				r = mid - 1;
		}
		return r;
	}
	// Largest value < k within arr[cl..cr], or -INF when there is none.
	int querypre(int p, int l, int r, int cl, int cr, int k)
	{
		if (l > cr || r < cl) return -INF;

		if (l >= cl && r <= cr) {
			AVL<int>::point ret = root[p].findPre(k);
			if (ret != nullptr) return ret->data;
			return -INF;
		}
		int mid = (l + r) >> 1;
		return std::max(querypre(p * 2, l, mid, cl, cr, k), querypre(p * 2 + 1, mid + 1, r, cl, cr, k));
	}
	// Smallest value > k within arr[cl..cr], or INF when there is none.
	int querysuf(int p, int l, int r, int cl, int cr, int k)
	{
		if (l > cr || r < cl) return INF;

		if (l >= cl && r <= cr) {
			AVL<int>::point ret = root[p].findNext(k);
			if (ret != nullptr) return ret->data;
			return INF;
		}
		int mid = (l + r) >> 1;
		return std::min(querysuf(p * 2, l, mid, cl, cr, k), querysuf(p * 2 + 1, mid + 1, r, cl, cr, k));
	}

public:
	// A recursive segment tree rooted at index 1 (children 2i, 2i+1) over n < maxn
	// leaves uses indices below 4 * n, so 4 * maxn trees suffice (was 8 * maxn).
	AVL<int> root[maxn << 2];
};
segment tree;

/// Reads n, m and arr[1..n], builds the tree, then answers m operations:
/// 1 l r k -> rank of k in [l, r]; 2 l r k -> k-th smallest in [l, r];
/// 3 pos k -> arr[pos] = k; 4 l r k -> predecessor of k; anything else -> successor of k.
int main()
{
	// ReadFile();  // uncomment (with ReadFile) to redirect I/O to files when debugging
	read(n);
	read(m);
	for (int i = 1; i <= n; ++i) read(arr[i]);
	tree.build(1, n, 1);

	for (int remaining = m; remaining > 0; --remaining)
	{
		int opt;
		read(opt);
		if (opt == 3) {
			int pos, val;
			read(pos);
			read(val);
			tree.modify(pos, val);
			continue;
		}

		int l, r, k;
		read(l);
		read(r);
		read(k);
		int answer;
		switch (opt) {
		case 1: answer = tree.queryrank(1, 1, n, l, r, k) + 1; break;
		case 2: answer = tree.querynum(l, r, k); break;
		case 4: answer = tree.querypre(1, 1, n, l, r, k); break;
		default: answer = tree.querysuf(1, 1, n, l, r, k); break;
		}
		printf("%d\n", answer);
	}
	return 0;
}

// Related topics (你可能感兴趣的): 树论 (tree data structures), c++