数据结构与算法分析(三) —— AVL树的实现

本文实现了AVL树,有几个注意点:

insert 和 remove 引起的失衡都可以用同样的四种旋转进行修复(注意:remove 之后,较高一侧儿子的两棵子树可能等高,此时必须使用单旋,因此 balance 中用 >= 而不是 > 来选择单旋):

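  • Case1:k2 的左子树比右子树高 2,且失衡源于 k2 左儿子的左子树(左-左)—— 对 k2 进行左单旋(rotateWithLeftChild)
  • Case4:k1 的右子树比左子树高 2,且失衡源于 k1 右儿子的右子树(右-右)—— 对 k1 进行右单旋(rotateWithRightChild)
  • Case2:k3 的左子树比右子树高 2,且失衡源于 k3 左儿子的右子树(左-右)—— 对 k3 进行左双旋(doubleWithLeftChild):先对 k3 的左儿子进行右单旋,再对 k3 进行左单旋
  • Case3:k1 的右子树比左子树高 2,且失衡源于 k1 右儿子的左子树(右-左)—— 对 k1 进行右双旋(doubleWithRightChild):先对 k1 的右儿子进行左单旋,再对 k1 进行右单旋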
//  AvlTree class
//
//  CONSTRUCTION: with no initializer
//
//  ******************PUBLIC OPERATIONS*********************
//  void insert( x )       --> Insert x (duplicates are ignored)
//  void remove( x )       --> Remove x (nothing happens if x is absent)
//  boolean contains( x )  --> Return true if x is present
//  Comparable findMin( )  --> Return smallest item
//  Comparable findMax( )  --> Return largest item
//  boolean isEmpty( )     --> Return true if empty; else false
//  void makeEmpty( )      --> Remove all items
//  void printTree( )      --> Print tree in sorted order
//  void checkBalance( )   --> Print "OOPS!!" if an AVL invariant is violated
//  ******************ERRORS********************************
//  Throws UnderflowException as appropriate

/*
 *  Implements an AVL tree.
 *  Note that all "matching" is based on the compareTo method.
 */


public class AVLTree>
{
    //@//   Fields
    private static final int Allowed_imbalance = 1;
    private AVLNode root;


    //@//   Constructors
    public AVLTree()
    {
        root = null;
    }


    //@//   Classes
    private static class AVLNode
    {
        AVLNode(AnyType theElement)
        {
            this(theElement,null,null);
        }
        AVLNode(AnyType theElement, AVLNode lt, AVLNode rt)
        {
            element = theElement;
            left = lt;
            right = rt;
            height = 0;
        }

        AnyType element;
        AVLNode left;
        AVLNode right;
        int height;
    }


    //@//   Methods
    public void makeEmpty()
    {
        root = null;
    }

    public boolean isEmpty()
    {
        return root==null;
    }

    public void printTree()
    {
        if(isEmpty())
            System.out.println("Empty Tree");
        else
            printTree(root);
    }

    private void printTree(AVLNode t)
    {
        if(t!=null)
        {
            printTree(t.left);
            System.out.println(t.right);
            print(t.right);
        }
    }

    private int height(AVLNode t)
    {
        return t==null ? -1 : t.height;
    }

    public AnyType findMin()
    {
        if(isEmpty())
            throw new UnderflowException();
        return findMin(root).element;
    }

    private AVLNode findMin(AVLNode t)
    {
        if(t==null)
            return t;
        while(t.left!=null)
            t = t.left;
        return t;
    }

    public AnyType findMax()
    {
        if(isEmpty())
            throw new UnderflowException();
        return findMax(root).element;
    }

    private AVLNode findMax(AVLNode t)
    {
        if(t==null)
            return t;
        while(t.right!=null)
            t = t.right;
        return t;
    }

    public boolean contains(AnyType x)
    {
        return contains(x,root);
    }

    private boolean contains(AnyType x, AVLNode t)
    {
        while(t!=null)
        {
            int compareResult = x.compareTo(t.element);

            if(compareResult<0)
                t = t.left;
            else if(compareResult>0)
                t = t.right;
            else return true;
        }
        return false;
    }

    public void insert(AnyType x)
    {
        root = insert(x,root);
    }

    private AVLNode insert(AnyType x, AVLNode t)
    {
        if(t==null)
            return new AVLNode<>(x,null,null);

        int compareResult = x.compareTo(t.element);
        if(compareResult<0)
            t.left = insert(x,t.left);
        else if(compareResult>0)
            t.right = insert(x,t.right);
        else
            ;
        return balance(t);
    }

    public void remove(AnyType x)
    {
        root = remove(x,root);
    }

    private AVLNode remove(AnyType x, AVLNode t)
    {
        if(t==null)
            return t;

        int compareResult = x.compareTo(t.element);
        if(compareResult<0)
            t.left = remove(x,t.left);
        else if(compareResult>0)
            t.right = remove(x,t.right);
        else if(t.left!=null && t.right!=null)
        {
            t.element = findMin(t.right).element;
            t.right = remove(t.element, t.right);
        }
        else
            t = (t.left!=null) ? t.left : t.right;
        return balance(t);
    }

    private AVLNode balance(AVLNode t)
    {
        if(t==null)
            return t;

        if(height(t.left)-height(t.right) > Allowed_imbalance)
            if(height(t.left.left)>=height(t.left.right))
                t = rotateWithLeftChild(t);
            else
                t = doubleWithLeftChild(t);
        else if(height(t.right)-height(t.left) > Allowed_imbalance)
            if(height(t.right.right)>=height(t.right.left))
                t = rotateWithRightChild(t);
            else
                t = doubleWithRightChild(t);
        t.height = Math.max(height(t.left),height(t.right)) + 1;
        return t;
    }

    public void checkBalance()
    {
        checkBalance(root);
    }

    private int checkBalance(AVLNode t)
    {
        if(t==null)
            return -1;
        if(t!=null)
        {
            int hl = checkBalance(t.left);
            int hr = checkBalance(t.right);
            if(Math.abs(height(t.left)-height(t.right))>1 || height( t.left ) != hl || height( t.right ) != hr )
                System.out.println("OOPS!!");
        }
        return height(t);
    }

    private AVLNode rotateWithLeftChild(AVLNode k2)
    {
        AVLNode k1 = k2.left;
        k2.left = k1.right;
        k1.right = k2;
        k2.height = Math.max(height(k2.left),height(k2.right)) + 1;
        k1.height = Math.max(height(k1.left),k2.height) + 1;
        return k1;      
    }

    private AVLNode rotateWithRightChild(AVLNode k1)
    {
        AVLNode k2 = k1.left;
        k1.right = k2.left;
        k2.left = k1;
        k1.height = Math.max(height(k1.left),height(k1.right)) + 1;
        k2.height = Math.max(height(k2.left),k1.height) + 1;
        return k2;      
    }

    private AVLNode doubleWithLeftChild(AVLNode k3)
    {
        k3.left = rotateWithRightChild( k3.left );
        return rotateWithLeftChild( k3 );
    }

    private AVLNode doubleWithRightChild(AVLNode k1)
    {
        k1.right = rotateWithLeftChild( k1.right );
        return rotateWithRightChild( k1 );
    }


    //@//   Test Program
    public static void main( String [ ] args )
    {
        AVLTree t = new AVLTree<>( );
        final int SMALL = 40;
        final int NUMS = 1000000;  // must be even
        final int GAP  =   37;

        System.out.println( "Checking... (no more output means success)" );

        for( int i = GAP; i != 0; i = ( i + GAP ) % NUMS )
        {
        //    System.out.println( "INSERT: " + i );
            t.insert( i );
            if( NUMS < SMALL )
                t.checkBalance( );
        }

        for( int i = 1; i < NUMS; i+= 2 )
        {
         //   System.out.println( "REMOVE: " + i );
            t.remove( i );
            if( NUMS < SMALL )
                t.checkBalance( );
        }
        if( NUMS < SMALL )
            t.printTree( );
        if( t.findMin( ) != 2 || t.findMax( ) != NUMS - 2 )
            System.out.println( "FindMin or FindMax error!" );

        for( int i = 2; i < NUMS; i+=2 )
             if( !t.contains( i ) )
                 System.out.println( "Find error1!" );

        for( int i = 1; i < NUMS; i+=2 )
        {
            if( t.contains( i ) )
                System.out.println( "Find error2!" );
        }
    }

}

你可能感兴趣的:(Java,数据结构与算法分析)