平衡二叉树的 Python 实现

class Node:
    """A binary-tree node: one stored value plus left and right child links."""

    def __init__(self, val):
        self.val = val
        # A freshly created node is a leaf, so both child links start empty.
        self.left = self.right = None

    def __repr__(self):
        """Render the node as the string form of its stored value."""
        return str(self.val)

    # str() and repr() share a single implementation.
    __str__ = __repr__


def depth(node):
    """Return the number of nodes on the longest path from ``node`` to a leaf.

    An empty subtree (``None``) has depth 0; a single leaf has depth 1.
    """
    if node is None:
        return 0
    # One level for this node plus the deeper of its two subtrees.
    return 1 + max(depth(node.left), depth(node.right))


def left_rotate(root: Node):
    """Rotate the subtree at ``root`` to the left and return its new top node.

    The right child of ``root`` is promoted. ``root`` becomes the promoted
    node's left child and adopts the promoted node's former left subtree as
    its own right subtree.
    """
    pivot = root.right
    # Right-hand side is evaluated first, so the pivot's old left subtree is
    # handed to ``root`` before the pivot's left link is redirected.
    root.right, pivot.left = pivot.left, root
    return pivot


def right_rotate(root):
    """Rotate the subtree at ``root`` to the right and return its new top node.

    The left child of ``root`` is promoted. ``root`` becomes the promoted
    node's right child and adopts the promoted node's former right subtree as
    its own left subtree.
    """
    pivot = root.left
    # Right-hand side is evaluated first, so the pivot's old right subtree is
    # handed to ``root`` before the pivot's right link is redirected.
    root.left, pivot.right = pivot.right, root
    return pivot


def left_right_rotate(root):
    """Apply a left-right double rotation at ``root`` and return the new top.

    Used when the left subtree is too deep because of its *right* branch.
    It is equivalent to left-rotating ``root.left`` and then right-rotating
    ``root``: the left child's right child (the pivot) ends up on top, with
    the old left child as its left subtree and ``root`` as its right subtree.

    The previous implementation rebound ``root`` to the rotated left child
    and then right-rotated *that*, which dropped the original root (and its
    whole right subtree) from the result.

    Requires ``root.left`` and ``root.left.right`` to exist.
    """
    child = root.left
    pivot = child.right
    # The pivot's two subtrees are split between its former ancestors.
    child.right = pivot.left
    root.left = pivot.right
    # The pivot becomes the top of the subtree.
    pivot.left = child
    pivot.right = root
    return pivot


def right_left_rotate(root):
    """Apply a right-left double rotation at ``root`` and return the new top.

    Used when the right subtree is too deep because of its *left* branch.
    It is equivalent to right-rotating ``root.right`` and then left-rotating
    ``root``: the right child's left child (the pivot) ends up on top, with
    ``root`` as its left subtree and the old right child as its right subtree.

    The previous implementation left-rotated the already rotated right child
    instead of ``root``, which dropped the original root (and its whole left
    subtree) from the result.

    Requires ``root.right`` and ``root.right.left`` to exist.
    """
    child = root.right
    pivot = child.left
    # The pivot's two subtrees are split between its former ancestors.
    child.left = pivot.right
    root.right = pivot.left
    # The pivot becomes the top of the subtree.
    pivot.right = child
    pivot.left = root
    return pivot


def get_left(root: Node, p=None):
    """Follow left links from ``root`` and return ``(leftmost, its_parent)``.

    ``p`` is what gets reported as the parent when ``root`` itself is already
    the leftmost node (or is ``None``); otherwise the parent is the node one
    step above the leftmost one.
    """
    node, parent = root, p
    while node is not None and node.left is not None:
        node, parent = node.left, node
    return node, parent


def get_right(root: Node, p=None):
    """Follow right links from ``root`` and return ``(rightmost, its_parent)``.

    ``p`` is what gets reported as the parent when ``root`` itself is already
    the rightmost node (or is ``None``); otherwise the parent is the node one
    step above the rightmost one.
    """
    node, parent = root, p
    while node is not None and node.right is not None:
        node, parent = node.right, node
    return node, parent


def remove_node(root, t, parent=None):
    """Remove the node ``t`` from the subtree rooted at ``root``.

    The search descends by comparing ``t.val`` with each node's value (going
    right on equality) until it reaches the very object ``t``.

    * If ``t`` has two children, its value is overwritten with that of its
      in-order successor (the leftmost node of its right subtree) and the
      successor node is unlinked instead; the ``t`` object stays in the tree.
    * Otherwise ``t`` is unlinked and its only child, if any, takes its place.

    No rotations are performed, so the tree is not rebalanced afterwards.

    Args:
        root: top of the subtree to search.
        t: the node object to remove.
        parent: parent of ``root``, if it has one; used to relink when
            ``root`` itself is the node being removed.

    Returns:
        The node that occupies ``root``'s position after the removal. This is
        ``root`` itself unless ``root`` was ``t`` and had fewer than two
        children, in which case it is that child (or ``None``). A caller that
        passes no ``parent`` must store this value as its new root, because
        there is no parent link through which the change could be made.

    Raises:
        ValueError: if ``t`` is not reached by the search.
    """
    node = root
    while node is not None and node is not t:
        parent = node
        node = node.left if t.val < node.val else node.right
    if node is None:
        raise ValueError("node is not in the tree")

    if node.left is not None and node.right is not None:
        # Track the successor's parent starting from ``node`` itself so the
        # case "successor is the direct right child" has a valid parent.
        succ_parent, succ = node, node.right
        while succ.left is not None:
            succ_parent, succ = succ, succ.left
        node.val = succ.val
        # The successor has no left child; its right subtree replaces it.
        if succ_parent is node:
            node.right = succ.right
        else:
            succ_parent.left = succ.right
        return root

    # Zero or one child: that child (possibly None) takes the node's place.
    child = node.right if node.left is None else node.left
    if parent is not None:
        if parent.left is node:
            parent.left = child
        else:
            parent.right = child
    return child if node is root else root


def free_tree(root):
    """Visit every node of the tree bottom-up, dropping the local reference.

    Children are visited before their parent. Note that ``del`` only unbinds
    this call's local name; it does not modify the tree or any node.
    """
    if root is not None:
        for child in (root.left, root.right):
            free_tree(child)
        del root


def insert(root: Node, v):
    """Insert ``v`` into the subtree at ``root`` and return the subtree's top.

    Values smaller than a node go into its left subtree, all others
    (including equal values) into its right subtree. After the recursive
    insert, if the side that received the value is at least two levels deeper
    than the other, the subtree is rebalanced with a single or double
    rotation chosen by where ``v`` went inside the child subtree.

    Args:
        root: top of the subtree, or ``None`` for an empty subtree.
        v: value to insert; must be comparable with the stored values.

    Returns:
        The (possibly different) top node of the subtree; callers must store
        it in place of ``root``.
    """
    if root is None:
        return Node(v)
    if v < root.val:
        root.left = insert(root.left, v)
        if depth(root.left) - depth(root.right) >= 2:
            if v < root.left.val:
                # Went left of the left child: single right rotation.
                root = right_rotate(root)
            else:
                # Went right of the left child: double rotation.
                root = left_right_rotate(root)
    else:
        root.right = insert(root.right, v)
        if depth(root.right) - depth(root.left) >= 2:
            # ``>=`` rather than ``>``: a value equal to the right child's
            # was inserted to that child's right, so it needs the single left
            # rotation. Sending it to the double rotation would look for a
            # left grandchild that does not exist.
            if v >= root.right.val:
                root = left_rotate(root)
            else:
                root = right_left_rotate(root)
    return root


def walk_node(root, front=None, mid=None, back=None):
    """Traverse the tree depth-first, invoking the supplied callbacks per node.

    Args:
        root: top of the subtree to walk; a falsy value ends the walk.
        front: called with each node before its children (pre-order).
        mid: called with each node between its left and right subtrees
            (in-order).
        back: called with each node after both subtrees (post-order).

    Any callback left as ``None`` is skipped. The callbacks are forwarded to
    the recursive calls; previously they were dropped there, so only the top
    node was ever reported.
    """
    if not root:
        return
    if front:
        front(root)

    walk_node(root.left, front, mid, back)
    if mid:
        mid(root)

    walk_node(root.right, front, mid, back)
    if back:
        back(root)


def horizontal_walk(root: Node, handle):
    """Visit the tree level by level, left to right, calling ``handle(node)``.

    Each pass processes one level and collects the children of every visited
    node as the next level. Missing children are collected as ``None`` and
    skipped when their level is processed; the walk ends once a pass collects
    nothing.
    """
    level = [root]
    while level:
        upcoming = []
        for node in level:
            if node:
                handle(node)
                upcoming.extend((node.left, node.right))
        level = upcoming

class BinTree:
    """Object wrapper that owns a tree root and delegates to the tree functions.

    ``push`` inserts through ``insert`` (which rebalances); ``pop_min`` and
    ``pop_max`` detach the leftmost / rightmost node. ``remove`` performs no
    rebalancing of its own.
    """

    def __init__(self):
        self.root = None  # top node of the tree, or None while it is empty

    def push(self, val):
        """Insert ``val`` and adopt whatever top node ``insert`` returns."""
        self.root = insert(self.root, val)

    def pop_min(self):
        """Detach and return the node holding the smallest value.

        Raises:
            IndexError: if the tree is empty.
        """
        if self.root is None:
            raise IndexError("pop from an empty tree")
        min_node = get_left(self.root)[0]
        self.remove(min_node)
        return min_node

    def pop_max(self):
        """Detach and return the node holding the largest value.

        Raises:
            IndexError: if the tree is empty.
        """
        if self.root is None:
            raise IndexError("pop from an empty tree")
        max_node = get_right(self.root)[0]
        self.remove(max_node)
        return max_node

    def remove(self, node):
        """Remove ``node`` (a node object of this tree) from the tree."""
        # The root has no parent to relink through, so when it has at most
        # one child the replacement is done here by re-pointing ``self.root``.
        # The leftmost / rightmost node is the root whenever the root lacks a
        # left / right child, so pop_min and pop_max depend on this path.
        if node is self.root and (node.left is None or node.right is None):
            self.root = node.right if node.left is None else node.left
            return
        remove_node(self.root, node)

    def __del__(self):
        free_tree(self.root)


if __name__ == '__main__':
    # Demo: fill a tree with 0..9, print it level by level, then print the
    # value of the popped minimum followed by two popped maxima.
    tree = BinTree()
    for value in range(10):
        tree.push(value)
    horizontal_walk(tree.root, print)
    print(tree.pop_min().val)
    for _ in range(2):
        print(tree.pop_max().val)


 

你可能感兴趣的:(Python)