set
/* RBTree.h*/
#pragma once
#include <cstddef>
#include <utility>
namespace sjy
{
    /// Node colour of a red-black tree.
    enum Color
    {
        RED,
        BLACK
    };

    /// One red-black tree node. A freshly constructed node is RED, which is
    /// the colour RBTree::Insert expects for a newly linked leaf.
    template <typename T>
    struct RBTreeNode
    {
        RBTreeNode(const T& data = T())
            :_left(nullptr)
            , _right(nullptr)
            , _parent(nullptr)
            , _data(data)
            , _col(RED)
        {}
        RBTreeNode<T>* _left;
        RBTreeNode<T>* _right;
        RBTreeNode<T>* _parent;
        T _data;
        Color _col;
    };

    /// Bidirectional in-order iterator over an RBTree.
    ///
    /// Relies on the layout maintained by RBTree::HeaderComeBack: the tree's
    /// sentinel header has _left -> minimum and _right -> maximum, and the
    /// minimum's _left / maximum's _right point back at the header. end() is
    /// the header itself. The header keeps its default RED colour and has no
    /// parent; the root also has no parent but is always BLACK, which is how
    /// the header is told apart from real nodes.
    template <typename T, typename Ref, typename Ptr>
    struct __TreeIterator
    {
        typedef RBTreeNode<T> Node;
        typedef __TreeIterator<T, Ref, Ptr> self;
        typedef __TreeIterator<T, T&, T*> iterator;

        __TreeIterator(Node* node)
            :_node(node)
        {}

        /// Converts a mutable iterator into this iterator type (so an
        /// iterator can initialise a const_iterator). For the mutable
        /// iterator itself this is simply the copy constructor.
        __TreeIterator(const iterator& it)
            :_node(it._node)
        {}

        /// Moves to the in-order successor; the maximum advances to end().
        self& operator++()
        {
            if (_node->_right == nullptr)
            {
                // No right subtree: climb until we arrive from a left child.
                Node* cur = _node;
                Node* parent = cur->_parent;
                while (parent != nullptr)
                {
                    if (parent->_left == cur)
                    {
                        _node = parent;
                        break;
                    }
                    cur = parent;
                    parent = parent->_parent;
                }
            }
            else if (_node->_right->_right != _node)
            {
                // A real right subtree: the successor is its leftmost node.
                Node* cur = _node->_right;
                while (cur->_left != nullptr)
                {
                    cur = cur->_left;
                }
                _node = cur;
            }
            else
            {
                // _right is the header, so we were at the maximum.
                _node = _node->_right;
            }
            return *this;
        }

        /// Moves to the in-order predecessor; --end() yields the maximum and
        /// the minimum steps back to end().
        self& operator--()
        {
            if (_isHeader())
            {
                // --end(): the header's _right link is the maximum element.
                _node = _node->_right;
            }
            else if (_node->_left == nullptr)
            {
                // No left subtree: climb until we arrive from a right child.
                Node* cur = _node;
                Node* parent = cur->_parent;
                while (parent != nullptr)
                {
                    if (parent->_right == cur)
                    {
                        _node = parent;
                        break;
                    }
                    cur = parent;
                    parent = parent->_parent;
                }
            }
            else if (_node->_left->_left != _node)
            {
                // A real left subtree: the predecessor is its rightmost node.
                Node* cur = _node->_left;
                while (cur->_right != nullptr)
                {
                    cur = cur->_right;
                }
                _node = cur;
            }
            else
            {
                // _left is the header, so we were at the minimum.
                _node = _node->_left;
            }
            return *this;
        }

        self operator++(int)
        {
            self old(*this);
            ++*this;
            return old;
        }

        self operator--(int)
        {
            self old(*this);
            --*this;
            return old;
        }

        Ref operator*() const
        {
            return _node->_data;
        }

        Ptr operator->() const
        {
            return &(_node->_data);
        }

        bool operator==(const self& other) const
        {
            return _node == other._node;
        }

        bool operator!=(const self& other) const
        {
            return _node != other._node;
        }

        /// True when this iterator sits on the tree's header (end()).
        bool _isHeader() const
        {
            return _node->_parent == nullptr && _node->_col == RED;
        }

        Node* _node;
    };

    /// Red-black tree storing unique T values ordered by KeyOfT()(value).
    ///
    /// @tparam K      key type used by Find
    /// @tparam T      stored element type
    /// @tparam KeyOfT functor mapping a T to its key; keys are compared with
    ///                operator< only
    ///
    /// A sentinel header node (embedded in the tree) serves as end() and links
    /// to the minimum/maximum nodes; see __TreeIterator for the layout.
    template <typename K, typename T, typename KeyOfT>
    class RBTree
    {
        typedef RBTreeNode<T> Node;
    public:
        typedef __TreeIterator<T, T&, T*> iterator;
        typedef __TreeIterator<T, const T&, const T*> const_iterator;

        // ---- construction / destruction -------------------------------

        RBTree()
            :_root(nullptr)
            , _header()   // stays RED and parentless: the iterator relies on it
            , _size(0)
        {}

        /// Deep copy. Delegating to the default constructor means that if a
        /// node copy throws part-way, ~RBTree() frees what was already built.
        RBTree(const RBTree& other)
            :RBTree()
        {
            _copy(other._root, &other._header, nullptr, _root);
            HeaderComeBack();
        }

        /// Copy-and-swap assignment (strong exception guarantee).
        RBTree& operator=(const RBTree& other)
        {
            if (this != &other)
            {
                RBTree tmp(other);
                swap(tmp);
            }
            return *this;
        }

        ~RBTree()
        {
            clear();
        }

        /// Exchanges contents with `other`. Each header lives inside its own
        /// object, so the min/max back-links are rebuilt, not swapped.
        void swap(RBTree& other)
        {
            if (this == &other)
            {
                return;
            }
            HeaderFadeAway();
            other.HeaderFadeAway();
            std::swap(_root, other._root);
            std::swap(_size, other._size);
            HeaderComeBack();
            other.HeaderComeBack();
        }

        // ---- iterators ------------------------------------------------

        iterator begin()
        {
            // An empty tree has no minimum; begin() must then equal end().
            return _header._left != nullptr ? iterator(_header._left) : end();
        }

        iterator end()
        {
            return iterator(&_header);
        }

        const_iterator begin() const
        {
            return _header._left != nullptr ? const_iterator(_header._left) : end();
        }

        const_iterator end() const
        {
            // const_iterator stores a Node*; the header is only read through it.
            return const_iterator(const_cast<Node*>(&_header));
        }

        // ---- rotations ------------------------------------------------

        /// Left rotation: parent's right child takes parent's place.
        void RotateL(Node* parent)
        {
            Node* cur = parent->_right;
            Node* grandparent = parent->_parent;

            // Hook `cur` into the slot that `parent` occupied.
            if (parent == _root)
            {
                _root = cur;
            }
            else if (grandparent->_left == parent)
            {
                grandparent->_left = cur;
            }
            else
            {
                grandparent->_right = cur;
            }

            // cur's left subtree becomes parent's right subtree.
            parent->_right = cur->_left;
            if (cur->_left != nullptr)
            {
                cur->_left->_parent = parent;
            }

            cur->_left = parent;
            cur->_parent = grandparent;
            parent->_parent = cur;
        }

        /// Right rotation: parent's left child takes parent's place.
        void RotateR(Node* parent)
        {
            Node* cur = parent->_left;
            Node* grandparent = parent->_parent;

            // Hook `cur` into the slot that `parent` occupied.
            if (parent == _root)
            {
                _root = cur;
            }
            else if (grandparent->_left == parent)
            {
                grandparent->_left = cur;
            }
            else
            {
                grandparent->_right = cur;
            }

            // cur's right subtree becomes parent's left subtree.
            parent->_left = cur->_right;
            if (cur->_right != nullptr)
            {
                cur->_right->_parent = parent;
            }

            cur->_right = parent;
            cur->_parent = grandparent;
            parent->_parent = cur;
        }

        /// Left-right double rotation around `parent`.
        void RotateLR(Node* parent)
        {
            RotateL(parent->_left);
            RotateR(parent);
        }

        /// Right-left double rotation around `parent`.
        void RotateRL(Node* parent)
        {
            RotateR(parent->_right);
            RotateL(parent);
        }

        // ---- insertion ------------------------------------------------

        /// Inserts `data` unless an element with an equal key exists.
        /// @return iterator to the new or existing element, and whether an
        ///         insertion took place.
        std::pair<iterator, bool> Insert(const T& data)
        {
            // The min/max nodes normally link to the header; detach those
            // links so the descent below only sees real nodes or nullptr.
            HeaderFadeAway();
            // Re-attach the header on every exit path, including exceptions
            // from new, T's copy constructor or the key comparisons.
            struct HeaderRestorer
            {
                RBTree* tree;
                ~HeaderRestorer() { tree->HeaderComeBack(); }
            } restorer{ this };

            KeyOfT kot;
            Node* parent = nullptr;
            Node* cur = _root;
            bool insertLeft = false;
            while (cur != nullptr)
            {
                parent = cur;
                if (kot(data) < kot(cur->_data))
                {
                    insertLeft = true;
                    cur = cur->_left;
                }
                else if (kot(cur->_data) < kot(data))
                {
                    insertLeft = false;
                    cur = cur->_right;
                }
                else
                {
                    return std::make_pair(iterator(cur), false);
                }
            }

            // No comparisons after this allocation, so nothing can throw
            // while the new node is still unlinked.
            Node* const newnode = new Node(data);
            newnode->_parent = parent;
            if (parent == nullptr)
            {
                _root = newnode;
            }
            else if (insertLeft)
            {
                parent->_left = newnode;
            }
            else
            {
                parent->_right = newnode;
            }

            // Restore the red-black invariants. `cur` is RED; the only
            // possible violation is a RED parent. A RED parent is never the
            // (BLACK) root, so it always has a parent of its own.
            cur = newnode;
            while (parent != nullptr && parent->_col == RED)
            {
                Node* grandparent = parent->_parent;
                if (parent == grandparent->_left)
                {
                    Node* uncle = grandparent->_right;
                    if (uncle != nullptr && uncle->_col == RED)
                    {
                        // Red uncle: recolour and move the violation up.
                        parent->_col = BLACK;
                        uncle->_col = BLACK;
                        grandparent->_col = RED;
                        cur = grandparent;
                        parent = cur->_parent;
                    }
                    else
                    {
                        // Black or missing uncle: rotate, then done.
                        if (cur == parent->_left)
                        {
                            RotateR(grandparent);
                            parent->_col = BLACK;
                        }
                        else
                        {
                            RotateLR(grandparent);
                            cur->_col = BLACK;
                        }
                        grandparent->_col = RED;
                        break;
                    }
                }
                else
                {
                    Node* uncle = grandparent->_left;
                    if (uncle != nullptr && uncle->_col == RED)
                    {
                        parent->_col = BLACK;
                        uncle->_col = BLACK;
                        grandparent->_col = RED;
                        cur = grandparent;
                        parent = cur->_parent;
                    }
                    else
                    {
                        if (cur == parent->_right)
                        {
                            RotateL(grandparent);
                            parent->_col = BLACK;
                        }
                        else
                        {
                            RotateRL(grandparent);
                            cur->_col = BLACK;
                        }
                        grandparent->_col = RED;
                        break;
                    }
                }
            }
            _root->_col = BLACK;
            ++_size;
            return std::make_pair(iterator(newnode), true);
        }

        // ---- header maintenance ---------------------------------------

        /// Detaches the header from the min/max nodes so the tree contains
        /// only nullptr leaves. Safe to call repeatedly.
        void HeaderFadeAway()
        {
            if (_header._left != nullptr)
            {
                (_header._left)->_left = nullptr;
            }
            if (_header._right != nullptr)
            {
                (_header._right)->_right = nullptr;
            }
            _header._left = nullptr;
            _header._right = nullptr;
        }

        /// (Re)links the header with the current min/max nodes. Idempotent:
        /// the descents stop at the header if it is already attached.
        void HeaderComeBack()
        {
            if (_root == nullptr)
            {
                _header._left = nullptr;
                _header._right = nullptr;
                return;
            }
            Node* leftmin = _root;
            while (leftmin->_left != nullptr && leftmin->_left != &_header)
            {
                leftmin = leftmin->_left;
            }
            Node* rightmax = _root;
            while (rightmax->_right != nullptr && rightmax->_right != &_header)
            {
                rightmax = rightmax->_right;
            }
            _header._left = leftmin;
            _header._right = rightmax;
            leftmin->_left = &_header;
            rightmax->_right = &_header;
        }

        // ---- lookup ---------------------------------------------------

        /// @return iterator to the element whose key equals `key`, or end().
        iterator Find(const K& key)
        {
            KeyOfT kot;
            Node* cur = _root;
            // The min/max nodes link to the header, so stop there too.
            while (cur != nullptr && cur != &_header)
            {
                if (key < kot(cur->_data))
                {
                    cur = cur->_left;
                }
                else if (kot(cur->_data) < key)
                {
                    cur = cur->_right;
                }
                else
                {
                    return iterator(cur);
                }
            }
            return end();
        }

        // ---- capacity / modifiers -------------------------------------

        size_t size() const
        {
            return _size;
        }

        bool empty() const
        {
            return _size == 0;
        }

        /// Destroys every element, leaving an empty tree.
        void clear()
        {
            HeaderFadeAway();   // min/max must not point at the header while freeing
            _clear(_root);
            _root = nullptr;
            _size = 0;
        }

    private:
        /// Recursively frees the subtree rooted at `root`.
        void _clear(Node* root)
        {
            if (root == nullptr || root == &_header)
            {
                return;
            }
            _clear(root->_left);
            _clear(root->_right);
            delete root;
        }

        /// Clones the subtree `src` under `parent`, storing each new node in
        /// `slot` before recursing so an exception never orphans a node.
        /// Links that point at `srcHeader` (the source tree's header) are
        /// treated as empty.
        void _copy(const Node* src, const Node* srcHeader, Node* parent, Node*& slot)
        {
            if (src == nullptr || src == srcHeader)
            {
                return;
            }
            slot = new Node(src->_data);
            slot->_col = src->_col;
            slot->_parent = parent;
            ++_size;
            _copy(src->_left, srcHeader, slot, slot->_left);
            _copy(src->_right, srcHeader, slot, slot->_right);
        }

        Node* _root;
        Node _header;
        size_t _size;
    };
}
/*Myset.h*/
#include "RBTree.h"
namespace sjy
{
template <typename K>
class set
{
struct SetKeyOfT
{
const K& operator()(const K& key)
{
return key;
}
};
public:
//迭代器相关
typename typedef RBTree<K, K, SetKeyOfT>::iterator iterator;
typename typedef RBTree<K,K, SetKeyOfT>::const_iterator const_iterator;
iterator begin()
{
return _tree.begin();
}
iterator end()
{
return _tree.end();
}
const_iterator begin() const
{
return _tree.begin();
}
const_iterator end() const
{
return _tree.end();
}
//插入
pair<iterator, bool> Insert(const K& key)
{
return _tree.Insert(key);
}
//查找
iterator Find(const K& key)
{
return _tree.Find(key);
}
//其他
size_t size() const
{
return _tree.size();
}
bool empty() const
{
return _tree.empty();
}
void clear()
{
_tree.clear();
}
private:
RBTree<K, K, SetKeyOfT> _tree;
};
}