AVL 树又被称为平衡二叉搜索树、平衡二叉树。"AVL" 取自它的两位发明者 G. M. Adelson-Velsky 和 E. M. Landis 姓氏的首字母。
在这篇文章中,AVLTreeMap 类实现了 java.util.Map 接口,并利用 AVL 树结构实现了该接口的所有方法。本文还给出了测试代码。
如果按照从小到大或者从大到小的顺序向普通二叉查找树插入节点,二叉查找树就会退化成一个链表。这是二叉查找树的最差情况:搜索、插入、删除的时间复杂度都退化为 O(N)。这样就失去了用二叉查找树优化查找的意义。
即使不是最坏情况,也可能出现非常不平衡的树,使查找的时间复杂度介于 O(logN) 和 O(N) 之间。注意这里 logN 是以 2 为底 N 的对数。AVL 树在每次插入、删除后通过旋转保证每个节点左右子树的高度差不超过 1,从而把树高控制在 O(logN)。
AVLTreeMap.java
package zhangchao.avl;
import java.util.*;
/**
 * A {@link Map} backed by an AVL tree (a self-balancing binary search tree).
 *
 * <p>Keys are ordered, and compared for equality, solely by the {@link Comparator} passed to the
 * constructor (a result of {@code 0} means "same key"); whether {@code null} keys are accepted is
 * therefore up to that comparator. After every insertion or removal the tree is rebalanced with
 * rotations so that, for every node, the heights of its left and right subtrees differ by at most
 * one.
 *
 * <p>{@link #keySet()}, {@link #values()} and {@link #entrySet()} return snapshot copies rather
 * than live views. The entries inside {@link #entrySet()} are the map's own nodes, however, so
 * {@link Map.Entry#setValue} on them writes through to the map. This class performs no
 * synchronization.
 *
 * @author zhangchao
 * @param <K> key type
 * @param <V> value type
 */
public class AVLTreeMap<K,V> implements Map<K, V>{

    /** Root node; {@code null} when the map is empty. */
    private Node<K, V> root = null;

    /** Orders the keys and decides key equality. */
    private final Comparator<K> comparator;

    /** Number of entries, maintained by put/remove/clear so that size() does not walk the tree. */
    private int count = 0;

    /**
     * Creates an empty map ordered by {@code comparator}.
     *
     * @param comparator key ordering; must not be {@code null}
     * @throws NullPointerException if {@code comparator} is {@code null}
     */
    public AVLTreeMap(Comparator<K> comparator) {
        // Fail fast: a null comparator would otherwise only surface as an NPE on the second put.
        this.comparator = Objects.requireNonNull(comparator, "comparator");
    }

    @Override
    public int size() {
        return count;
    }

    @Override
    public boolean isEmpty() {
        return null == root;
    }

    @Override
    public boolean containsKey(Object keyObj) {
        return null != findNode(keyObj);
    }

    /** Linear scan over all nodes; {@code null} matches a {@code null} value. */
    @Override
    public boolean containsValue(Object value) {
        for (Node<K, V> node : nodeList()) {
            if (Objects.equals(value, node.value)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public V get(Object keyObj) {
        Node<K, V> node = findNode(keyObj);
        return null == node ? null : node.value;
    }

    /**
     * Binary-searches the tree for {@code keyObj}.
     *
     * @param keyObj key to look up
     * @return the node whose key compares equal to {@code keyObj}, or {@code null}
     */
    @SuppressWarnings("unchecked")
    private Node<K, V> findNode(Object keyObj) {
        // Erased cast: no runtime type check happens here; a wrong-typed key is only caught,
        // if at all, inside the comparator.
        K key = (K) keyObj;
        Node<K, V> current = this.root;
        while (null != current) {
            int compareResult = this.comparator.compare(key, current.key);
            if (compareResult < 0) {
                current = current.leftChild;
            } else if (compareResult > 0) {
                current = current.rightChild;
            } else {
                return current;
            }
        }
        return null;
    }

    /**
     * Makes {@code newChild} take {@code oldChild}'s place under {@code parent}.
     *
     * @param parent   parent of {@code oldChild}; {@code null} means {@code oldChild} is the root
     * @param oldChild node currently linked from {@code parent}
     * @param newChild replacement node (may be {@code null})
     */
    private void replaceChild(Node<K, V> parent, Node<K, V> oldChild, Node<K, V> newChild) {
        if (null == parent) {
            this.root = newChild;
        } else if (parent.leftChild == oldChild) {
            parent.leftChild = newChild;
        } else if (parent.rightChild == oldChild) {
            parent.rightChild = newChild;
        }
    }

    /**
     * Right rotation: x (the left child of y) becomes the root of this subtree and y becomes x's
     * right child. Heights of y, x and parent are recomputed.
     *
     * @param parent grandparent (parent of y); {@code null} if y is the root
     * @param y      parent node, rotated down
     * @param x      child node (y's left child), rotated up
     */
    private void rotateRight(Node<K, V> parent, Node<K, V> y, Node<K, V> x) {
        y.leftChild = x.rightChild;
        x.rightChild = y;
        replaceChild(parent, y, x);
        // Bottom-up order: y is now below x, and x below parent.
        y.height = this.calHeight(y);
        x.height = this.calHeight(x);
        if (null != parent) {
            parent.height = this.calHeight(parent);
        }
    }

    /**
     * Left rotation: x (the right child of y) becomes the root of this subtree and y becomes x's
     * left child. Heights of y, x and parent are recomputed.
     *
     * @param parent grandparent (parent of y); {@code null} if y is the root
     * @param y      parent node, rotated down
     * @param x      child node (y's right child), rotated up
     */
    private void rotateLeft(Node<K, V> parent, Node<K, V> y, Node<K, V> x) {
        y.rightChild = x.leftChild;
        x.leftChild = y;
        replaceChild(parent, y, x);
        y.height = this.calHeight(y);
        x.height = this.calHeight(x);
        if (null != parent) {
            parent.height = this.calHeight(parent);
        }
    }

    /**
     * Recomputes z's height and, if z is out of balance, restores it with one or two rotations.
     *
     * <p>The rotation case is chosen from the heights of the heavier child's subtrees. After an
     * insertion that child always has balance +1 or -1, so this picks the same case as following
     * the insertion path. After a deletion it may have balance 0; that tie must be handled with a
     * single rotation (LL / RR), because a double rotation would leave a node out of balance.
     *
     * @param parent parent of z; {@code null} if z is the root
     * @param z      node to check; its children's stored heights must already be correct
     */
    private void rebalance(Node<K, V> parent, Node<K, V> z) {
        z.height = this.calHeight(z);
        int balance = heightOf(z.leftChild) - heightOf(z.rightChild);
        if (balance > 1) {
            Node<K, V> y = z.leftChild;
            if (heightOf(y.leftChild) >= heightOf(y.rightChild)) {
                // LL: right rotation between z and y.
                this.rotateRight(parent, z, y);
            } else {
                // LR: left rotation between y and x, then right rotation between z and x.
                Node<K, V> x = y.rightChild;
                this.rotateLeft(z, y, x);
                this.rotateRight(parent, z, x);
            }
        } else if (balance < -1) {
            Node<K, V> y = z.rightChild;
            if (heightOf(y.rightChild) >= heightOf(y.leftChild)) {
                // RR (including the tie): left rotation between z and y.
                this.rotateLeft(parent, z, y);
            } else {
                // RL: right rotation between y and x, then left rotation between z and x.
                Node<K, V> x = y.leftChild;
                this.rotateRight(z, y, x);
                this.rotateLeft(parent, z, x);
            }
        }
    }

    /**
     * Walks {@code path} from its last element back to the root, recomputing heights and
     * rebalancing every node on the way.
     *
     * @param path nodes whose heights may have changed, ordered root first; each element is the
     *             parent of the next one
     */
    private void rebalancePath(List<Node<K, V>> path) {
        for (int i = path.size() - 1; i >= 0; i--) {
            Node<K, V> parent = i > 0 ? path.get(i - 1) : null;
            this.rebalance(parent, path.get(i));
        }
    }

    /** Creates a detached leaf node holding {@code key} and {@code value}. */
    private Node<K, V> newLeaf(K key, V value) {
        Node<K, V> node = new Node<>();
        node.key = key;
        node.value = value;
        node.height = 1;
        return node;
    }

    /**
     * Associates {@code value} with {@code key}, replacing any existing value.
     *
     * @return the previous value for {@code key}, or {@code null} if the key was new
     */
    @Override
    public V put(K key, V value) {
        if (null == this.root) {
            this.root = newLeaf(key, value);
            this.count++;
            return null;
        }
        // Every node visited on the way down, root first. If the key is new, the last element is
        // the parent of the new node.
        List<Node<K, V>> path = new ArrayList<>();
        Node<K, V> current = this.root;
        int compareResult = 0;
        while (null != current) {
            compareResult = this.comparator.compare(key, current.key);
            if (compareResult == 0) {
                // The key already exists: just replace the value; the shape does not change.
                V oldValue = current.value;
                current.value = value;
                return oldValue;
            }
            path.add(current);
            current = compareResult < 0 ? current.leftChild : current.rightChild;
        }
        Node<K, V> parent = path.get(path.size() - 1);
        Node<K, V> newItem = newLeaf(key, value);
        if (compareResult < 0) {
            parent.leftChild = newItem;
        } else {
            parent.rightChild = newItem;
        }
        this.count++;
        // Only the ancestors of the new node can change height or become unbalanced.
        this.rebalancePath(path);
        return null;
    }

    /**
     * Looks up {@code key} and reports the node together with its parent.
     *
     * @param key     key to find
     * @param parents if non-null, receives every ancestor of the found node (root first); it is
     *                also filled when the key is absent
     * @return {@code [node, parent]} ({@code parent} is {@code null} for the root), or
     *         {@code null} if the key is not present
     */
    private List<Node<K,V>> getNodeAndParent(K key, List<Node<K, V>> parents) {
        Node<K, V> parent = null;
        Node<K, V> current = this.root;
        while (null != current) {
            int compareResult = this.comparator.compare(key, current.key);
            if (compareResult == 0) {
                List<Node<K, V>> result = new ArrayList<>();
                result.add(current);
                result.add(parent);
                return result;
            }
            parent = current;
            if (null != parents) {
                parents.add(parent);
            }
            current = compareResult < 0 ? current.leftChild : current.rightChild;
        }
        return null;
    }

    /**
     * Unlinks {@code node} as in a plain binary search tree, without rebalancing.
     *
     * <p>Precondition: if {@code node} is a leaf, {@code parent} is non-null. The caller handles
     * the single-node tree separately.
     *
     * @param node   node to remove
     * @param parent parent of {@code node}; {@code null} if {@code node} is the root
     * @return the key of the deepest node whose subtree height may have changed. The path from
     *         the root to that node covers every node that needs its height recomputed.
     */
    private K deleteAsBST(Node<K, V> node, Node<K, V> parent) {
        // Case 1: leaf. Simply detach it; its parent is the deepest affected node.
        if (null == node.leftChild && null == node.rightChild) {
            replaceChild(parent, node, null);
            return parent.key;
        }
        // Case 2: only a right child. Splice that child into node's place.
        if (null == node.leftChild) {
            Node<K, V> replacement = node.rightChild;
            replaceChild(parent, node, replacement);
            node.rightChild = null;
            return replacement.key;
        }
        // Case 3: node has a left child (the right child may or may not exist).
        // Replace node with the maximum node of its left subtree.
        Node<K, V> current = node.leftChild;
        Node<K, V> currentParent = node;
        while (null != current.rightChild) {
            currentParent = current;
            current = current.rightChild;
        }
        // Detach current; its (possibly empty) left subtree takes over its old slot.
        replaceChild(currentParent, current, current.leftChild);
        // Let current take node's place.
        replaceChild(parent, node, current);
        current.leftChild = node.leftChild;
        current.rightChild = node.rightChild;
        node.leftChild = null;
        node.rightChild = null;
        // current's old parent now lies on the right spine of current's new left subtree, so the
        // maximum of that subtree (or current itself, if it is empty) ends the affected path.
        Node<K, V> deepest = current;
        Node<K, V> p1 = current.leftChild;
        if (null != p1) {
            while (null != p1.rightChild) {
                p1 = p1.rightChild;
            }
            deepest = p1;
        }
        return deepest.key;
    }

    /**
     * Removes the mapping for {@code keyObj} if present and rebalances the tree.
     *
     * @return the removed value, or {@code null} if the key was not present
     */
    @Override
    public V remove(Object keyObj) {
        // Empty map: nothing to delete.
        if (null == this.root) {
            return null;
        }
        @SuppressWarnings("unchecked")
        K key = (K) keyObj;
        // Single-node tree: deleteAsBST requires a non-null parent for leaves, so handle it here.
        if (null == this.root.leftChild && null == this.root.rightChild) {
            if (this.comparator.compare(key, this.root.key) == 0) {
                V v = this.root.value;
                this.root = null;
                this.count--;
                return v;
            }
            return null;
        }
        List<Node<K, V>> nodeAndParent = this.getNodeAndParent(key, null);
        // The key is not in the map: nothing to delete.
        if (null == nodeAndParent) {
            return null;
        }
        Node<K, V> node = nodeAndParent.get(0);   // node to delete
        Node<K, V> parent = nodeAndParent.get(1); // its parent (null for the root)
        V result = node.value;
        K endKey = this.deleteAsBST(node, parent);
        this.count--;
        // Nodes whose height may have changed, root first. The "replacement" is the node that
        // took the deleted node's place:
        // 1. leaf deleted: root .. parent of the deleted node;
        // 2. only a right child: root .. replacement;
        // 3. had a left child: root .. replacement .. maximum of the replacement's left subtree.
        List<Node<K, V>> pathList = new ArrayList<>();
        List<Node<K, V>> endKeyResult = this.getNodeAndParent(endKey, pathList);
        pathList.add(endKeyResult.get(0));
        // Unlike insertion, deletion may need rotations at several levels, so walk the whole path.
        this.rebalancePath(pathList);
        return result;
    }

    /** Copies every mapping of {@code m} into this map; a {@code null} map is ignored. */
    @Override
    public void putAll(Map<? extends K, ? extends V> m) {
        if (null == m) {
            return;
        }
        for (Entry<? extends K, ? extends V> entry : m.entrySet()) {
            this.put(entry.getKey(), entry.getValue());
        }
    }

    @Override
    public void clear() {
        this.root = null;
        this.count = 0;
    }

    /** Returns every node in ascending key order, using an iterative in-order traversal. */
    private List<Node<K, V>> nodeList() {
        List<Node<K, V>> result = new ArrayList<>();
        Deque<Node<K, V>> stack = new ArrayDeque<>();
        Node<K, V> current = this.root;
        while (null != current || !stack.isEmpty()) {
            // Descend as far left as possible, remembering the way back.
            while (null != current) {
                stack.push(current);
                current = current.leftChild;
            }
            current = stack.pop();
            result.add(current);
            current = current.rightChild;
        }
        return result;
    }

    /** Returns a sorted snapshot of the keys (not a live view). */
    @Override
    public Set<K> keySet() {
        Set<K> set = new TreeSet<>(this.comparator);
        for (Node<K, V> node : nodeList()) {
            set.add(node.key);
        }
        return set;
    }

    /** Returns a snapshot of the values in ascending key order (not a live view). */
    @Override
    public Collection<V> values() {
        List<V> result = new ArrayList<>();
        for (Node<K, V> node : nodeList()) {
            result.add(node.value);
        }
        return result;
    }

    /**
     * Returns a key-sorted snapshot set of the entries. The set itself is a copy, but its
     * elements are the map's own nodes, so {@code setValue} writes through.
     */
    @Override
    public Set<Entry<K, V>> entrySet() {
        Set<Entry<K, V>> set =
                new TreeSet<>((e1, e2) -> comparator.compare(e1.getKey(), e2.getKey()));
        set.addAll(nodeList());
        return set;
    }

    /**
     * Computes the height of the subtree rooted at {@code node} by walking it level by level,
     * ignoring stored heights. Used only by the debugging helpers.
     */
    private int calHeightForCheck(Node<K,V> node) {
        if (null == node) {
            return 0;
        }
        int height = 0;
        List<Node<K,V>> currentLevel = new ArrayList<>();
        currentLevel.add(node);
        while (!currentLevel.isEmpty()) {
            height++;
            List<Node<K,V>> nextLevel = new ArrayList<>();
            for (Node<K,V> tmpNode : currentLevel) {
                if (null != tmpNode.leftChild) {
                    nextLevel.add(tmpNode.leftChild);
                }
                if (null != tmpNode.rightChild) {
                    nextLevel.add(tmpNode.rightChild);
                }
            }
            currentLevel = nextLevel;
        }
        return height;
    }

    /**
     * Prints {@code node} and its subtree, one node per line. Each line has {@code level} spaces
     * of indentation, then the prefix, the key, the parent's key (if any) and the node's balance
     * factor.
     */
    private void showTree(Node<K, V> node, Node<K, V> parent, int level, String prefix) {
        if (null == node) {
            return;
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < level; i++) {
            sb.append(" ");
        }
        sb.append(prefix);
        sb.append(node.key).append(" ");
        if (parent != null) {
            sb.append(parent.key);
        }
        int balance = calHeightForCheck(node.leftChild) - calHeightForCheck(node.rightChild);
        sb.append(" ").append(balance);
        System.out.println(sb);
        showTree(node.leftChild, node, level + 1, "left : ");
        showTree(node.rightChild, node, level + 1, "right: ");
    }

    /**
     * Prints the tree structure to standard output.
     */
    public void showTree() {
        if (null == root) {
            System.out.println("null");
        }
        showTree(root, null, 0, "root: ");
    }

    /** Throws if any node in the subtree rooted at {@code node} has a balance factor outside [-1, 1]. */
    private void checkTree(Node<K, V> node) {
        if (null == node) {
            return;
        }
        int balance = calHeightForCheck(node.leftChild) - calHeightForCheck(node.rightChild);
        if (balance < -1 || balance > 1) {
            throw new RuntimeException("balance < -1 || balance > 1");
        }
        checkTree(node.leftChild);
        checkTree(node.rightChild);
    }

    /**
     * Checks that the tree satisfies the AVL balance requirement, measuring real subtree heights
     * rather than trusting the stored ones.
     *
     * @throws RuntimeException if some node's subtrees differ in height by more than one
     */
    public void checkTree() {
        checkTree(root);
    }

    /**
     * Recomputes the height of {@code node} from its children's stored heights.
     *
     * @param node subtree root (may be {@code null})
     * @return 0 for {@code null}, otherwise 1 + the larger child height
     */
    private int calHeight(Node<K,V> node) {
        if (null == node) {
            return 0;
        }
        return Math.max(heightOf(node.leftChild), heightOf(node.rightChild)) + 1;
    }

    /** Stored height of {@code node}, or 0 for an empty subtree. */
    private int heightOf(Node<K, V> node) {
        return null == node ? 0 : node.height;
    }

    /**
     * Tree node; also serves as the map entry returned by {@link #entrySet()}. It is static, so
     * it does not keep a hidden reference to the enclosing map.
     */
    static class Node<K,V> implements Entry<K,V> {
        K key = null;
        V value = null;
        /** Height of the subtree rooted here; a leaf has height 1. */
        int height;
        Node<K, V> leftChild;
        Node<K, V> rightChild;

        @Override
        public K getKey() {
            return key;
        }

        @Override
        public V getValue() {
            return value;
        }

        @Override
        public V setValue(V tmpValue) {
            V oldValue = value;
            value = tmpValue;
            return oldValue;
        }

        public int getHeight() {
            return height;
        }
    }
}
测试代码 TestAVLTreeMap.java
其中包含与二叉查找树 Map(BstTreeMap)的性能对比。
二叉查找树Map的实现文章:https://blog.csdn.net/zhangchao19890805/article/details/128609922?spm=1001.2014.3001.5502
package zhangchao.avl.test;
import zhangchao.avl.AVLTreeMap;
import zhangchao.bst.BstTreeMap;
import java.util.*;
/**
 * Manual test driver for {@link AVLTreeMap}.
 *
 * <p>Each {@code tN} method builds a map and prints it, either with
 * {@link AVLTreeMap#showTree()} or with {@code System.out}. It then calls
 * {@link AVLTreeMap#checkTree()}, which throws if some node's subtrees differ in height by more
 * than one. {@link #t6()} also times insert, search and remove against {@code BstTreeMap}.
 */
public class TestAVLTreeMap {

    public static void main(String[] args) {
        t7();
    }

    /**
     * Integer ordering shared by every test: {@code null} sorts before all non-null values, and
     * non-null values use {@link Integer#compareTo}. This replaces the former {@code o1 - o2}
     * lambda, which overflows when the operands are far apart (e.g. {@code Integer.MIN_VALUE}
     * and 1).
     */
    private static Comparator<Integer> intComparator() {
        return Comparator.nullsFirst(Comparator.<Integer>naturalOrder());
    }

    /** Removes the root (20) from a small tree, prints it before and after, then checks balance. */
    public static void t7() {
        int[] a = {20, 10, 21, 22, 5, 15, 1};
        AVLTreeMap<Integer, String> avlTreeMap = new AVLTreeMap<>(intComparator());
        for (int key : a) {
            avlTreeMap.put(key, "__" + key);
        }
        avlTreeMap.showTree();
        avlTreeMap.remove(20);
        System.out.println("\n");
        avlTreeMap.showTree();
        avlTreeMap.checkTree();
    }

    /**
     * Times insert, search and remove for AVLTreeMap against BstTreeMap.
     *
     * <p>Keys 0..MAX-1 are inserted in ascending order, which is the worst case for an unbalanced
     * binary search tree. They are then removed in shuffled order.
     */
    public static void t6() {
        Comparator<Integer> comparator = intComparator();
        AVLTreeMap<Integer, String> avlTreeMap = new AVLTreeMap<>(comparator);
        BstTreeMap<Integer, String> bstTreeMap = new BstTreeMap<>(comparator);
        long t1;
        long t2;
        // Compare insertion.
        System.out.println("insert");
        final int MAX = 100000;
        List<Integer> list = new ArrayList<>(MAX);
        for (int i = 0; i < MAX; i++) {
            list.add(i);
        }
        t1 = System.currentTimeMillis();
        for (int key : list) {
            avlTreeMap.put(key, "__" + key);
        }
        t2 = System.currentTimeMillis();
        System.out.println("AVL:" + (t2 - t1));
        t1 = System.currentTimeMillis();
        for (int key : list) {
            bstTreeMap.put(key, "__" + key);
        }
        t2 = System.currentTimeMillis();
        System.out.println("BST:" + (t2 - t1));
        // Compare lookup.
        System.out.println("\nsearch");
        t1 = System.currentTimeMillis();
        for (int i = 0; i < MAX; i++) {
            avlTreeMap.get(i);
        }
        t2 = System.currentTimeMillis();
        System.out.println("AVL:" + (t2 - t1));
        t1 = System.currentTimeMillis();
        for (int i = 0; i < MAX; i++) {
            bstTreeMap.get(i);
        }
        t2 = System.currentTimeMillis();
        System.out.println("BST:" + (t2 - t1));
        // Compare removal, in a random order.
        System.out.println("\nremove");
        t1 = System.currentTimeMillis();
        Collections.shuffle(list);
        for (int key : list) {
            avlTreeMap.remove(key);
        }
        t2 = System.currentTimeMillis();
        System.out.println("AVL:" + (t2 - t1));
        t1 = System.currentTimeMillis();
        for (int key : list) {
            bstTreeMap.remove(key);
        }
        t2 = System.currentTimeMillis();
        System.out.println("BST:" + (t2 - t1));
        avlTreeMap.checkTree();
    }

    /** Removes an inner node (10) that has two children, then checks balance. */
    public static void t3() {
        AVLTreeMap<Integer, String> avlTreeMap = new AVLTreeMap<>(intComparator());
        int[] arr = new int[]{20, 10, 21, 5, 15, 22, 13, 16};
        for (int i : arr) {
            avlTreeMap.put(i, "__" + String.valueOf(i));
        }
        avlTreeMap.showTree();
        avlTreeMap.remove(10);
        System.out.println("\n");
        avlTreeMap.showTree();
        avlTreeMap.checkTree();
    }

    /** Builds a larger tree (with duplicate keys), removes 7, then checks balance. */
    public static void t2() {
        AVLTreeMap<Integer, String> avlTreeMap = new AVLTreeMap<>(intComparator());
        int[] arr = new int[]{8,3,6,1,2,98,2,6,150,170,160,7,52,28,75,14,
                40,86,10,21,46,25};
        for (int i : arr) {
            avlTreeMap.put(i, "__" + String.valueOf(i));
        }
        avlTreeMap.showTree();
        avlTreeMap.remove(7);
        System.out.println("\n\n\n");
        avlTreeMap.showTree();
        avlTreeMap.checkTree();
    }

    /** Exercises the read-only Map API: get (including a null key), keySet, contains*, entrySet. */
    public static void t1() {
        AVLTreeMap<Integer, String> avlTreeMap = new AVLTreeMap<>(intComparator());
        int[] arr = new int[]{8,3,6,1,2,98,2,6,150,170,160,7,52,28,75,14,
                40,86,10,21,46,25};
        for (int i : arr) {
            avlTreeMap.put(i, "__" + String.valueOf(i));
        }
        avlTreeMap.showTree();
        System.out.println(avlTreeMap.get(3));
        System.out.println(avlTreeMap.get(6));
        System.out.println(avlTreeMap.get(98));
        System.out.println(avlTreeMap.get(null));
        Set<Integer> set = avlTreeMap.keySet();
        for (Integer i : set) {
            System.out.println(i);
        }
        System.out.println();
        // Every distinct inserted key must appear in keySet(); print "false" for any that is missing.
        HashSet<Integer> hashSet = new HashSet<>();
        for (int i : arr) {
            hashSet.add(i);
        }
        for (int i : hashSet) {
            if (!set.contains(i)) {
                System.out.println(false);
            }
        }
        System.out.println(set.size() + " " + hashSet.size());
        System.out.println("containsKey 3: " + avlTreeMap.containsKey(3));
        System.out.println("containsKey 4: " + avlTreeMap.containsKey(4));
        System.out.println("containsValue __3: " + avlTreeMap.containsValue("__3"));
        System.out.println("containsValue __4: " + avlTreeMap.containsValue("__4"));
        System.out.println();
        Set<Map.Entry<Integer, String>> entrySet = avlTreeMap.entrySet();
        for (Map.Entry<Integer, String> item : entrySet) {
            System.out.println(item.getKey() + ": " + item.getValue());
        }
        avlTreeMap.checkTree();
    }
}