AVLTree及其时间复杂度推导过程

/**
 *
 * @author haofan.whf
 * @version $Id: AVLNode.java, v 0.1 2018年12月19日 10:30 haofan.whf Exp $
 */
public class AVLNode extends Node {

    private int height;

    /**
     * Getter method for property height.
     *
     * @return property value of height
     */
    public int getHeight() {
        return height;
    }

    /**
     * Setter method for property height.
     *
     * @param height value to be assigned to property height
     */
    public void setHeight(int height) {
        this.height = height;
    }

    @Override
    public String toString() {
        return "Node{" +
                (this.getLeft() == null ? "" : "left=" + this.getLeft().getValue()) +
                (this.getRight() == null ? "" : ",right=" + this.getRight().getValue()) +
                (this.getParent() == null ? "" : ",parent=" + this.getParent().getValue()) +
                ",value=" + this.getValue() + ",height=" + this.getHeight() + "}";
    }
}
/**
 * AVLTree,是一个相对平衡的树
 * 普通BSTTree的操作时间复杂度是O(h),但是在输入是顺序数组的情况下,BSTTree会退化成链表,而O(h)的时间复杂度会退化成O(n)
 * 那么如何将h限定在一个合适的值,比如logN
 * AVLTree就是这样的一种树,他的所有操作时间复杂度都是O(logN)
 *
 * 下面是他时间复杂度为O(logN)推理过程:
 * 假设root的高度是h,root的左子树高度是h-1,root的右子树高度是h-2
 * 这颗树的最大高度是多少的问题其实可以看成在上面描述的树结构中固定长度最少能放进多少个节点
 * 假设T(h)为h高度的AVLTree能放的最少节点数量
 * T(h) = T(h-1) + t(h-2) + 1
 * ==>T(h) > T(h-1) + t(h-2)
 * ==>T(h) > 2 * T(h-2) //这个公式表明高度每减少2,能容纳的节点数量减半
 * ==>T(h) > 2^1 * T(h-2 * 1)
 * ==>T(h) > 2^2 * T(h - 2 * 2)
 * ==>T(h) > 2^k * T(h - 2 * k)
 * ==>考虑基本场景,T(0)可以推出k=h/2
 * ==>T(h) > 2^(h/2)
 * ==>logT(h) > h/2
 * ==>h < 2 * logT(h)
 * ==>假设T(h)=Nh高度的AVLTree能放的最少节点数量为N
 * ==>h < 2 * logN
 * 该结论表明对于节点个数为N的AVLTree的最大高度不超过2*logN,(实际好像是1.44 * logN)
 *
 * 第二种推理方式,T(h) = T(h-1) + t(h-2) + 1,不看这个1个话
 * T(h) = T(h-1) + t(h-2)其实是斐波那契数列
 * google可以知道T(h-1) + t(h-2)其实是以某个常数X为底的h次幂
 * 同样也可以推导出 h < X * logN
 * @author haofan.whf
 * @version $Id: AVLTree.java, v 0.1 2018年12月18日 10:33 haofan.whf Exp $
 */
public class AVLTree extends BinarySearchTree {

    @Override
    public void doAfterInsert(Node node) {
        //step1.当插入一个节点时,该节点的高度需要被更新,他的所有父节点(可能)需要更新
        //但是这里并不判断每个父节点是否真的需要更新,对于渐进时间复杂度来说这样做并没有意义
        while (node != null){
            updateHeight(node);
            //step2.在更新完节点高度之后(节点是否平衡依托于height,所以当判断一个节点是否平衡需要先其子节点的高度)
            //需要判定节点是否平衡,不平衡需要rebalance
            rebalanceIfNeed(node);

            //step3.继续向上
            node = node.getParent();
        }
    }

    @Override
    public void doAfterDelete(Node parent) {
        //step1.当删除一个节点时,该节点的父节点的高度需要被更新,他的所有父节点(可能)需要更新
        //但是这里并不判断每个父节点是否真的需要更新,对于渐进时间复杂度来说这样做并没有意义
        while (parent != null){
            updateHeight(parent);
            //step2.在更新完节点高度之后(节点是否平衡依托于height,所以当判断一个节点是否平衡需要先其子节点的高度)
            //需要判定节点是否平衡,不平衡需要rebalance
            rebalanceIfNeed(parent);

            //step3.继续向上
            parent = parent.getParent();
        }
    }

    /**
     * 检查当前节点是否需要重新平衡
     * @param node
     * @return
     */
    private void rebalanceIfNeed(Node node){

        int leftHeight = getNodeHeight(node.getLeft());
        int rightHeight = getNodeHeight(node.getRight());
        if(Math.abs(leftHeight - rightHeight) <= 1){
            return;
        }
        //step1.首先判断是哪颗子树重
        if(leftHeight < rightHeight){
            //右边重,两种情况
            if(getNodeHeight(node.getRight().getRight()) >= getNodeHeight(node.getRight().getLeft())){
                //1)一次leftRotate(node)
                leftRotate(node);
            }else{
                //2)一次rightRotate(node.right) + leftRotate(node)
                rightRotate(node.getRight());
                leftRotate(node);
            }
        }else{
            //左边重,类似
            if(getNodeHeight(node.getLeft().getLeft()) >= getNodeHeight(node.getLeft().getRight())){
                //1)一次rightRotate(node)
                rightRotate(node);
            }else{
                //2)一次leftRotate(node.left) + rightRotate(node)
                leftRotate(node.getLeft());
                rightRotate(node);
            }
        }
    }

    /**
     * 左旋转
     * @param node
     */
    private void leftRotate(Node node){
        int leftOrRight = leftOrRight(node);

        node.getRight().setParent(node.getParent());
        if(leftOrRight == 1){
            //node是右节点
            node.getParent().setRight(node.getRight());
        }else if(leftOrRight == -1){
            //node是左节点
            node.getParent().setLeft(node.getRight());
        }else{
            //node没有父节点
            //什么也不做
        }
        node.setParent(node.getRight());

        node.setRight(node.getRight().getLeft());
        if(node.getParent().getLeft() != null){
            node.getParent().getLeft().setParent(node);
        }
        node.getParent().setLeft(node);
        updateHeight(node);
        updateHeight(node.getParent());
    }

    /**
     * 右旋转
     * @param node
     */
    private void rightRotate(Node node){
        int leftOrRight = leftOrRight(node);

        node.getLeft().setParent(node.getParent());
        if(leftOrRight == 1){
            //node是右节点
            node.getParent().setRight(node.getLeft());
        }else if(leftOrRight == -1){
            //node是左节点
            node.getParent().setLeft(node.getLeft());
        }else{
            //node没有父节点
            //什么也不做
        }
        node.setParent(node.getLeft());

        node.setLeft(node.getLeft().getRight());
        if(node.getParent().getRight() != null){
            node.getParent().getRight().setParent(node);
        }
        node.getParent().setRight(node);
        updateHeight(node);
        updateHeight(node.getParent());
    }

    public Node createNode(){
        return new AVLNode();
    }

    /**
     * 更新节点的高度
     * @param node
     */
    private void updateHeight(Node node){
        if(node == null){
            return;
        }
        AVLNode avlNode = parseNode(node);
        avlNode.setHeight(Math.max(getNodeHeight(node.getLeft())
                , getNodeHeight(node.getRight())) + 1);
    }

    /**
     * 查询节点的高度,如果节点为null则默认返回-1
     * @param node
     * @return
     */
    private int getNodeHeight(Node node){
        if(node == null){
            return -1;
        }
        AVLNode avlNode = parseNode(node);
        return avlNode.getHeight();
    }

    /**
     * throw runtime exception unless node is AVLNode
     */
    private AVLNode parseNode(Node node){
        if(node instanceof AVLNode){
            return (AVLNode)node;
        }else {
            throw new RuntimeException("node type not match");
        }
    }

    /**
     * 对AVLTree结构进行修改后调用此方法查看结构的完整性
     * T(N) = O(N)
     * 需要遍历每个节点查看其左右子树的高度差是否<=1
     * @param root
     */
    public void checkRI(Node root){
        if(root == null){
            return;
        }
        if(Math.abs(getNodeHeight(root.getLeft()) - getNodeHeight(root.getRight())) > 1){
            throw new RuntimeException("it's not a AVLTree");
        }
        checkBasicRI(root);
        checkRI(root.getRight());
        checkRI(root.getLeft());
    }
}
/**
 *
 * @author haofan.whf
 * @version $Id: BSTTest.java, v 0.1 2018年12月11日 下午7:34 haofan.whf Exp $
 */
public class BSTTest {

    @Test
    public void avlTest(){
        AVLTree avl = new AVLTree();

        int[] insertArray = randomArray(1000);
        //insertArray = new int[]{3,2,0,5,6,9};
        Node root = new AVLNode();
        for (int i = 0; i < insertArray.length; i++) {
            //这里注意AVLTree在新增节点时有可能rebalance导致root不再是这颗AVLTree的root
            //这样在下一次insert时insert的节点有可能出现在不该出现的地方
            //推演3,2,0,5,6,9这个例子就明白了,所以在每次插入之前保证root是该Tree的root
            root = findRoot(root);
            avl.insert(root, insertArray[i]);
            avl.checkRI(root);
        }

        int[] deleteArray = randomArray(100);
        //deleteArray = new int[]{3,7,0,4,8};
        for (int i = 0; i < deleteArray.length; i++) {
            root = findRoot(root);
            avl.delete(root, deleteArray[i]);
            avl.checkRI(root);
        }

    }

    private Node findRoot(Node node){
        if(node.getParent() != null){
            return findRoot(node.getParent());
        }else{
            return node;
        }
    }

    private static int[] randomArray(int size){
        Random random = new Random();
        int[] array = new int[size];
        for (int i = 0; i < size; i++) {
            array[i] = random.nextInt(size);
        }
        return array;
    }
}

你可能感兴趣的:(AVLTree及其时间复杂度推导过程)