将红黑树作为一个基础的类模板,通过给这个类模板传递不同的参数,从而控制它所实现的容器。
最主要的点是用自己的map和set通过传递不同的模板参数控制红黑树第二个模板参数 T 来确定传入的到底是 Key 还是 pair
红黑树的第二个参数T,通过传入参数的不同,控制红黑树中到底存储什么类型的变量。
给红黑树传的第三个模板参数 KeyOfT 是一个类,这个类中重载了 operator(),它实例化出的对象可以当做函数来使用,并且这个函数可以实现多种不同数据类型的大小比较,通过传进来类的不同,来使用不同的比较方式。
如图,KeyOfT示例化出的对象kot来进行树中的大小比较,当传 SetKeyOfT 时,就使用 key 作为比较依据;当传 MapKeyOfT 时,就使用 kv.first 作为大小比较依据。
首先,需要写一个迭代器的类模板,里面有迭代器变量主要需要使用的一些功能,如++,->,*,!=,==等等这些,通过传递不同的参数,来构造不同类型的迭代器。
迭代器的内部成员变量就是一个红黑树结点类型的指针,通过重载++,--,->等运算符作为成员函数。而模板参数传 Ref 和 Ptr 的原因是:当构造的是 const 类型或 非const 类型的迭代器时,编译器可以自己推导出返回值的类型。因为普通迭代器 iterator 和const类型的迭代器 const_iterator 不是同一个类型!
在红黑树中根据模板参数的不同重定义两个不同的迭代器:普通迭代器和 const 迭代器。当使用普通迭代器时,相当于给迭代器模板传
stl中 set 里的 key 不允许改变,因此无论是 iterator 还是 const_iterator,都需要使用红黑树中定义的 const_iterator 迭代器来初始化。
而map中,key不允许改变,T却可以改变,直接在模板参数中加 const 修饰K,普通迭代器用iterator,const 迭代器用 const_iterator修饰。
将insert 的返回值设置为 pair 类型,第一个 first 参数为 Node* 类型,因为无论是普通迭代器还是 const类型的迭代器,都是用 Node* 类型的变量来初始化的,只要是相近类型,pair的拷贝构造函数都能进行构造推导,推导为正确的类型!
#pragma once
#pragma once
#include
#include
#include
using namespace std;
typedef enum Color
{
RED, // 红色
BLACK // 黑色
}Color;
template
struct RBTreeNode
{
RBTreeNode(const T& data)
:_right(nullptr)
, _left(nullptr)
, _parent(nullptr)
, _col(RED)
, _data(data)
{}
RBTreeNode* _right;
RBTreeNode* _left;
RBTreeNode* _parent;
// 颜色:枚举类型
Color _col;
T _data;
};
template
struct __TreeIterator
{
typedef RBTreeNode Node;
typedef __TreeIterator self;
Node* _node;
__TreeIterator(Node* node)
:_node(node)
{}
self& operator--()
{
if (_node->_left)
{
Node* cur = _node->_left;
while (cur->_right)
{
cur = cur->_right;
}
_node = cur;
}
else
{
// 等于空,说明当前结点已经访问结束了。
// 向上回溯,直到孩子是父亲左节点的孩子
Node* cur = _node;
Node* parent = _node->_parent;
while (parent && cur == parent->_left)
{
cur = parent;
parent = parent->_parent;
}
_node = parent;
}
return *this;
}
self& operator++()
{
if (_node->_right)
{
Node* cur = _node->_right;
while (cur->_left)
{
cur = cur->_left;
}
_node = cur;
}
else
{
// 等于空,说明当前结点已经访问结束了。
// 向上回溯,直到孩子是父亲左节点的孩子
Node* cur = _node;
Node* parent = _node->_parent;
while (parent && cur == parent->_right)
{
cur = parent;
parent = parent->_parent;
}
_node = parent;
}
return *this;
}
bool operator==(const self& it)
{
return _node == it._node;
}
bool operator!=(const self& it)
{
return _node != it._node;
}
Ptr operator->()
{
return &_node->_data;
}
Ref operator*()
{
return _node->_data;
}
};
template
class RBTree
{
typedef RBTreeNode Node;
public:
typedef __TreeIterator iterator;
typedef __TreeIterator const_iterator;
KeyOfT kot;
iterator begin()
{
Node* cur = _root;
while (cur->_left)
{
cur = cur->_left; /*_t = cur; // 这个 _t 是根节点,不能随便变*/
}
return iterator(cur); //return *this;迭代器这里不能返回引用
}
iterator end()
{
return iterator(nullptr);
}
const_iterator begin() const
{
Node* cur = _root;
while (cur->_left)
{
cur = cur->_left;
}
return const_iterator(cur);
}
const_iterator end() const
{
return const_iterator(nullptr);
}
pair Insert(const T& data)
{
if (_root == nullptr)
{
_root = new Node(data);
_root->_col = BLACK;
return make_pair(_root, true);
}
//走到此处,说明根不为空
Node* parent = nullptr;
Node* cur = _root;
while (cur)
{
if (kot(cur->_data) < kot(data))
{
parent = cur;
cur = cur->_right;
}
else if (kot(cur->_data)> kot(data))
{
parent = cur;
cur = cur->_left;
}
else // 相等
{
return make_pair(cur, false);
}
}
cur = new Node(data); // 初始化的时候就是红色结点
cur->_parent = parent;
// 调整父子之间的关系
if (kot(parent->_data) < kot(data))
parent->_right = cur;
else if (kot(parent->_data) > kot(data))
parent->_left = cur;
Node* grandparent = parent->_parent;
while (parent && parent->_col == RED)
{
grandparent = parent->_parent;
Node* uncle = nullptr;
// 调整父亲和叔叔与爷爷的关系
if (kot(parent->_data) > kot(grandparent->_data))
{
grandparent->_right = parent;
uncle = grandparent->_left;
}
else if (kot(parent->_data) < kot(grandparent->_data))
{
grandparent->_left = parent;
uncle = grandparent->_right;
}
// 新插入结点的父亲是红色结点,需要调整
// 调整父亲和叔叔的左右
if (uncle && uncle->_col == RED) // 叔叔存在并且为红色
{
grandparent->_col = RED;
uncle->_col = parent->_col = BLACK;
cur = grandparent;
parent = cur->_parent;
}
else if (uncle == nullptr || uncle->_col == BLACK) // 叔叔不存在或者存在且为黑色,需要调整加变色
{
// 旋转都是将不均衡变为相对均衡,然后旋转后的父节点变为黑色,分到叔叔那边的结点变为红色
// 因为这种情况,爷爷本来就是黑色
if (parent == grandparent->_left)
{
if (cur == parent->_left)
{
RotateR(grandparent);
parent->_col = BLACK;
grandparent->_col = RED;
}
else if (cur == parent->_right)
{
RotateL(parent);
RotateR(grandparent);
cur->_col = BLACK;
grandparent->_col = RED;
}
}
else
{
if (cur == parent->_left)
{
RotateR(parent);
RotateL(grandparent);
grandparent->_col = RED;
cur->_col = BLACK;
}
else if (cur == parent->_right)
{
RotateL(grandparent);
grandparent->_col = RED;
parent->_col = BLACK;
}
}
break;
}
}
_root->_col = BLACK;
return make_pair(cur, true);
}
bool IsValidRBTree()
{
Node* pRoot = _root;
// 空树也是红黑树
if (nullptr == pRoot)
return true;
// 检测根节点是否满足情况
if (BLACK != pRoot->_col)
{
cout << "违反红黑树性质二:根节点必须为黑色" << endl;
return false;
}
// 获取任意一条路径中黑色节点的个数
size_t blackCount = 0;
Node* pCur = pRoot;
while (pCur)
{
if (BLACK == pCur->_col)
blackCount++;
pCur = pCur->_left;
}
// 检测是否满足红黑树的性质,k用来记录路径中黑色节点的个数
size_t k = 0;
return _IsValidRBTree(pRoot, k, blackCount);
}
bool _IsValidRBTree(Node* pRoot, size_t k, const size_t blackCount)
{
//走到null之后,判断k和black是否相等
if (nullptr == pRoot)
{
if (k != blackCount)
{
cout << "违反性质四:每条路径中黑色节点的个数必须相同" << endl;
return false;
}
return true;
}
// 统计黑色节点的个数
if (BLACK == pRoot->_col)
k++;
// 检测当前节点与其双亲是否都为红色
Node* pParent = pRoot->_parent;
if (pParent && RED == pParent->_colo && RED == pRoot->_col)
{
cout << "违反性质三:没有连在一起的红色节点" << endl;
return false;
}
return _IsValidRBTree(pRoot->_left, k, blackCount) &&
_IsValidRBTree(pRoot->_right, k, blackCount);
}
void Order()
{
_Order(_root);
cout << endl;
}
void _Order(Node* root)
{
if (root == nullptr)
return;
_Order(root->_left);
cout << root->_kv.first << " ";
_Order(root->_right);
}
void RotateR(Node* parent)
{
Node* subL = parent->_left;
Node* subLR = subL->_right;
Node* parentParent = parent->_parent;
subL->_right = parent;
parent->_parent = subL;
parent->_left = subLR;
if (subLR)
subLR->_parent = parent;
if (parent == _root)
{
_root = subL;
subL->_parent = nullptr;
}
else
{
if (parentParent->_left == parent)
{
parentParent->_left = subL;
}
else if (parentParent->_right == parent)
{
parentParent->_right = subL;
}
subL->_parent = parentParent;
}
}
void RotateL(Node* parent)
{
Node* subR = parent->_right;
Node* subRL = subR->_left;
Node* parentParent = parent->_parent;
subR->_left = parent;
parent->_parent = subR;
if (subRL)
subRL->_parent = parent;
parent->_right = subRL;
if (parent == _root)
{
_root = subR;
subR->_parent = nullptr;
}
else
{
// 先找到 parent 是父亲的左子树还是右子树
if (parentParent->_left == parent) // 左子树
{
parentParent->_left = subR;
}
else if (parentParent->_right == parent) // 右子树
{
parentParent->_right = subR;
}
subR->_parent = parentParent;
}
}
private:
Node* _root = nullptr;
};
#pragma once
#include"RBTree.h"
namespace zyb
{
template
class map
{
public:
struct MapKeyOfT
{
const K& operator()(const pair& kv)
{
return kv.first;
}
};
typedef typename RBTree, MapKeyOfT>::iterator iterator;
typedef typename RBTree, MapKeyOfT>::const_iterator const_iterator;
iterator begin()
{
return _t.begin();
}
iterator end()
{
return _t.end();
}
const_iterator begin() const
{
return _t.begin();
}
const_iterator end() const
{
return _t.end();
}
T& operator[](const K& key)
{
pair ret = insert(make_pair(key, T()));
return ret.first->second;
}
pair insert(const pair& kv)
{
return _t.Insert(kv);
}
private:
RBTree, MapKeyOfT> _t;
};
}
#pragma once
#include"RBTree.h"
namespace zyb
{
template
class set
{
public:
struct SetKeyOfT
{
const K& operator()(const K& key)
{
return key;
}
};
typedef typename RBTree::const_iterator iterator;
typedef typename RBTree::const_iterator const_iterator;
pair insert(const K& key)
{
return _t.Insert(key);
}
const_iterator begin() const
{
return _t.begin();
}
const_iterator end() const
{
return _t.end();
}
private:
RBTree _t;
};
}
可以使用以下用例对功能进行测试:
#include"RBTree.h"
#include"MyMap.h"
#include"MySet.h"
void test1(const zyb::set& s1)
{
zyb::set::const_iterator it = s1.begin();
while (it != s1.end())
{
cout << *it << " ";
++it;
}
cout << endl;
}
void test2(const zyb::map& m1)
{
zyb::map::const_iterator it = m1.begin();
while (it != m1.end())
{
cout << it->first << ":" << it->second << endl;
++it;
}
cout << endl;
}
int main()
{
zyb::set s1;
s1.insert(1);
s1.insert(2);
s1.insert(3);
s1.insert(3);
s1.insert(5);
s1.insert(6);
zyb::map m1;
m1.insert(make_pair("zyb", 1));
m1.insert(make_pair("jn", 1));
m1.insert(make_pair("yw", 1));
test1(s1);
test2(m1);
return 0;
}