本文实现了 AVL 树，有一点需要注意：
insert 和 remove 引起的失衡都可以用同样的旋转（单旋转或双旋转）修复，因此两者在递归返回时都调用同一个 balance 方法。代码如下：
// AvlTree class
//
// CONSTRUCTION: with no initializer
//
// ******************PUBLIC OPERATIONS*********************
// void insert( x ) --> Insert x (ignored if an equal item is present)
// void remove( x ) --> Remove x (does nothing if x is absent)
// boolean contains( x ) --> Return true if x is present
// Comparable findMin( ) --> Return smallest item
// Comparable findMax( ) --> Return largest item
// boolean isEmpty( ) --> Return true if empty; else false
// void makeEmpty( ) --> Remove all items
// void printTree( ) --> Print tree in sorted order
// void checkBalance( ) --> Print "OOPS!!" for each node that is unbalanced or has a stale height
// ******************ERRORS********************************
// Throws UnderflowException as appropriate
/*
* Implements an AVL tree.
* Note that all "matching" is based on the compareTo method.
*/
public class AVLTree>
{
//@// Fields
private static final int Allowed_imbalance = 1;
private AVLNode root;
//@// Constructors
public AVLTree()
{
root = null;
}
//@// Classes
private static class AVLNode
{
AVLNode(AnyType theElement)
{
this(theElement,null,null);
}
AVLNode(AnyType theElement, AVLNode lt, AVLNode rt)
{
element = theElement;
left = lt;
right = rt;
height = 0;
}
AnyType element;
AVLNode left;
AVLNode right;
int height;
}
//@// Methods
public void makeEmpty()
{
root = null;
}
public boolean isEmpty()
{
return root==null;
}
public void printTree()
{
if(isEmpty())
System.out.println("Empty Tree");
else
printTree(root);
}
private void printTree(AVLNode t)
{
if(t!=null)
{
printTree(t.left);
System.out.println(t.right);
print(t.right);
}
}
private int height(AVLNode t)
{
return t==null ? -1 : t.height;
}
public AnyType findMin()
{
if(isEmpty())
throw new UnderflowException();
return findMin(root).element;
}
private AVLNode findMin(AVLNode t)
{
if(t==null)
return t;
while(t.left!=null)
t = t.left;
return t;
}
public AnyType findMax()
{
if(isEmpty())
throw new UnderflowException();
return findMax(root).element;
}
private AVLNode findMax(AVLNode t)
{
if(t==null)
return t;
while(t.right!=null)
t = t.right;
return t;
}
public boolean contains(AnyType x)
{
return contains(x,root);
}
private boolean contains(AnyType x, AVLNode t)
{
while(t!=null)
{
int compareResult = x.compareTo(t.element);
if(compareResult<0)
t = t.left;
else if(compareResult>0)
t = t.right;
else return true;
}
return false;
}
public void insert(AnyType x)
{
root = insert(x,root);
}
private AVLNode insert(AnyType x, AVLNode t)
{
if(t==null)
return new AVLNode<>(x,null,null);
int compareResult = x.compareTo(t.element);
if(compareResult<0)
t.left = insert(x,t.left);
else if(compareResult>0)
t.right = insert(x,t.right);
else
;
return balance(t);
}
public void remove(AnyType x)
{
root = remove(x,root);
}
private AVLNode remove(AnyType x, AVLNode t)
{
if(t==null)
return t;
int compareResult = x.compareTo(t.element);
if(compareResult<0)
t.left = remove(x,t.left);
else if(compareResult>0)
t.right = remove(x,t.right);
else if(t.left!=null && t.right!=null)
{
t.element = findMin(t.right).element;
t.right = remove(t.element, t.right);
}
else
t = (t.left!=null) ? t.left : t.right;
return balance(t);
}
private AVLNode balance(AVLNode t)
{
if(t==null)
return t;
if(height(t.left)-height(t.right) > Allowed_imbalance)
if(height(t.left.left)>=height(t.left.right))
t = rotateWithLeftChild(t);
else
t = doubleWithLeftChild(t);
else if(height(t.right)-height(t.left) > Allowed_imbalance)
if(height(t.right.right)>=height(t.right.left))
t = rotateWithRightChild(t);
else
t = doubleWithRightChild(t);
t.height = Math.max(height(t.left),height(t.right)) + 1;
return t;
}
public void checkBalance()
{
checkBalance(root);
}
private int checkBalance(AVLNode t)
{
if(t==null)
return -1;
if(t!=null)
{
int hl = checkBalance(t.left);
int hr = checkBalance(t.right);
if(Math.abs(height(t.left)-height(t.right))>1 || height( t.left ) != hl || height( t.right ) != hr )
System.out.println("OOPS!!");
}
return height(t);
}
private AVLNode rotateWithLeftChild(AVLNode k2)
{
AVLNode k1 = k2.left;
k2.left = k1.right;
k1.right = k2;
k2.height = Math.max(height(k2.left),height(k2.right)) + 1;
k1.height = Math.max(height(k1.left),k2.height) + 1;
return k1;
}
private AVLNode rotateWithRightChild(AVLNode k1)
{
AVLNode k2 = k1.left;
k1.right = k2.left;
k2.left = k1;
k1.height = Math.max(height(k1.left),height(k1.right)) + 1;
k2.height = Math.max(height(k2.left),k1.height) + 1;
return k2;
}
private AVLNode doubleWithLeftChild(AVLNode k3)
{
k3.left = rotateWithRightChild( k3.left );
return rotateWithLeftChild( k3 );
}
private AVLNode doubleWithRightChild(AVLNode k1)
{
k1.right = rotateWithLeftChild( k1.right );
return rotateWithRightChild( k1 );
}
//@// Test Program
public static void main( String [ ] args )
{
AVLTree t = new AVLTree<>( );
final int SMALL = 40;
final int NUMS = 1000000; // must be even
final int GAP = 37;
System.out.println( "Checking... (no more output means success)" );
for( int i = GAP; i != 0; i = ( i + GAP ) % NUMS )
{
// System.out.println( "INSERT: " + i );
t.insert( i );
if( NUMS < SMALL )
t.checkBalance( );
}
for( int i = 1; i < NUMS; i+= 2 )
{
// System.out.println( "REMOVE: " + i );
t.remove( i );
if( NUMS < SMALL )
t.checkBalance( );
}
if( NUMS < SMALL )
t.printTree( );
if( t.findMin( ) != 2 || t.findMax( ) != NUMS - 2 )
System.out.println( "FindMin or FindMax error!" );
for( int i = 2; i < NUMS; i+=2 )
if( !t.contains( i ) )
System.out.println( "Find error1!" );
for( int i = 1; i < NUMS; i+=2 )
{
if( t.contains( i ) )
System.out.println( "Find error2!" );
}
}
}