B+ 树删除操作

通过查阅网上资料和自己的实践,发现 B+ 树的删除有两种实现方式:

只删除叶子节点的 key(中间节点中作为分隔符的 key 保持不变)

  • 找到 key 所在叶子节点
  • 在叶子节点删除 key
  • 自底(叶子节点)向上修复树的平衡
    • 先判断该节点的 key 数目是否满足最少要求(至少 degree - 1 个),如果不满足,则需要向兄弟节点借

      • 先看该节点的左、右兄弟节点能否借出 key(兄弟节点至少有 degree 个 key 才能借),如果可以,借一个过来,这个操作也被称为“左旋”/“右旋”

      • 如果兄弟节点都不富裕(无法借出),则把该节点与相邻的兄弟节点合并

    • 向上重复上述操作

同时删除叶子节点和中间节点的 key

  • 删除分几种情况

1. 删除的 key 不在中间节点或根节点中(只在叶子节点中),直接在叶子节点中删除
2. 删除的 key 也出现在中间节点或根节点中,这又分两种情况:
2.1. 删除的 key 可以找到后继(比它大的下一个 key)
使用后继替换中间节点中被删除的 key,再在叶子节点中删除 key,然后自底向上维护树(类似 B 树)
2.2. 删除的 key 没有后继,即 key 是整棵树中的最大值
直接在该中间节点中删除对应的 key 和它右侧的 child,然后自底向上维护树(类似 B 树)

  • 自底(叶子节点)向上修复树的平衡
    • 先判断该节点的 key 数目是否满足最少要求(至少 degree - 1 个),如果不满足,则需要向兄弟节点借

      • 先看该节点的左、右兄弟节点能否借出 key(兄弟节点至少有 degree 个 key 才能借),如果可以,借一个过来,这个操作也被称为“左旋”/“右旋”

      • 如果兄弟节点都不富裕(无法借出),则把该节点与相邻的兄弟节点合并

    • 向上重复上述操作

如果有错误欢迎指出。我也不是很确定是否考虑到了所有情况,网上别人的实现也没看懂( ̄▽ ̄)"

下面是实现代码

remove_key_leaf 对应“只删除叶子节点的 key”(BPlusTree.delete 使用的就是这种方式)

remove_key_both 对应“同时删除叶子节点和中间节点的 key”(BPlusTree.delete 目前没有使用它)

import typing
import random


def log(*args):
    """Debug-trace helper: print all arguments space-separated on one line.

    All diagnostic output in this file goes through this function.
    """
    line = ' '.join(str(arg) for arg in args)
    print(line)


# noinspection DuplicatedCode
class Node:
    """One node of a B+ tree with minimum degree ``degree``.

    * Internal nodes (``leaf=False``): ``keys`` are separator keys and ``children``
      are child ``Node`` objects, with ``len(children) == len(keys) + 1``. A key
      equal to a separator is routed to the child on its right.
    * Leaf nodes (``leaf=True``): ``keys`` are the stored keys and ``children``
      holds the value for each key at the same index. Leaves are chained into a
      doubly linked list through ``prev_leaf`` / ``next_leaf``.

    Every non-root node keeps between ``degree - 1`` and ``2 * degree - 1`` keys
    (see ``check_degree``).
    """

    def __init__(self, degree, leaf):
        # True for leaf nodes (children are values), False for internal nodes.
        self.leaf = leaf
        # Keys per node:              degree - 1 <= n_key <= 2 * degree - 1
        # Children per internal node: degree <= n_children <= 2 * degree
        self.degree = degree
        self.keys = []
        self.children = []
        # Leaves form a doubly linked list.
        self.prev_leaf = None
        self.next_leaf = None

    @property
    def prov_leaf(self):
        """Backward-compatible alias for ``prev_leaf`` (the former, misspelled name)."""
        return self.prev_leaf

    @prov_leaf.setter
    def prov_leaf(self, node):
        self.prev_leaf = node

    def _child_index(self, key):
        """Return the index of the child of this internal node whose range contains ``key``.

        Keys equal to a separator go to the right child, matching ``split_child``,
        which promotes the first key of the new right leaf as the separator.
        """
        i = 0
        n = len(self.keys)
        while i < n and key >= self.keys[i]:
            i += 1
        return i

    def search_child(self, key) -> typing.Optional[str]:
        """Return the value stored under ``key`` in this subtree, or None if absent."""
        if self.leaf:
            try:
                index = self.keys.index(key)
            except ValueError:
                return None
            return self.children[index]
        return self.children[self._child_index(key)].search_child(key)

    def split_child(self, index_child):
        """Split the full child ``self.children[index_child]`` into two nodes.

        The new right-hand node is inserted at ``index_child + 1`` and a separator
        is inserted into ``self.keys`` at ``index_child``; ``self`` must not be full.

        * Leaf child: the left node keeps ``degree`` keys, the right gets the rest,
          and the separator is a *copy* of the right node's first key (leaves must
          keep every key).
        * Internal child: the middle key moves up into ``self`` and is kept by
          neither half.
        """
        log('split_child1', index_child, self.degree, self.keys, self.children)
        node_old: Node = self.children[index_child]
        node_new = Node(node_old.degree, node_old.leaf)
        log('split_child2', node_old.keys, node_old.children)

        # The new node takes the upper part of the old node's keys.
        node_new.keys = node_old.keys[self.degree:]
        if node_new.leaf:
            # Copy the new leaf's first key up as the separator.
            self.keys.insert(index_child, node_old.keys[self.degree])
            node_old.keys = node_old.keys[:self.degree]
            # Splice the new leaf into the doubly linked leaf list.
            node_new.prev_leaf = node_old
            node_new.next_leaf = node_old.next_leaf
            if node_new.next_leaf is not None:
                node_new.next_leaf.prev_leaf = node_new
            node_old.next_leaf = node_new
        else:
            # Move the middle key up. The keys from index ``degree`` onward already
            # went to node_new, so the old node keeps only the part before the middle.
            self.keys.insert(index_child, node_old.keys[self.degree - 1])
            node_old.keys = node_old.keys[:self.degree - 1]
        log('split_child3', node_new.keys)
        # Split children (values for leaves, child nodes otherwise) at the same point.
        node_new.children = node_old.children[self.degree:]
        node_old.children = node_old.children[:self.degree]
        self.children.insert(index_child + 1, node_new)
        log('split_child4', node_new.children)

    def insert_non_full(self, key, value):
        """Insert ``key``/``value`` into this subtree; ``self`` must not be full.

        A full child is split before descending into it, so the recursion never
        reaches a full node.
        """
        i = len(self.keys) - 1
        if self.leaf:
            while i >= 0 and key < self.keys[i]:
                i = i - 1
            self.keys.insert(i + 1, key)
            self.children.insert(i + 1, value)
            log('insert_non_full 1', key, self.keys)
        else:
            while i >= 0 and key < self.keys[i]:
                i = i - 1
            # i can be -1; the child to descend into is at i + 1.
            i = i + 1
            log('insert_non_full 2', key, self.keys, self.children, i)
            child: Node = self.children[i]
            if child.full():
                self.split_child(i)
                # A key equal to the new separator belongs to the right half,
                # consistent with _child_index / search_child.
                if key >= self.keys[i]:
                    i = i + 1
            child = self.children[i]
            child.insert_non_full(key, value)

    def update(self, key, value):
        """Replace the value stored under ``key``; log and do nothing if absent."""
        if self.leaf:
            try:
                index = self.keys.index(key)
            except ValueError:
                log(f"update key not find: {key}")
                return
            self.children[index] = value
        else:
            return self.children[self._child_index(key)].update(key, value)

    def remove_key_leaf(self, key):
        """Delete ``key`` from the leaves of this subtree only.

        Separator copies of ``key`` in internal nodes are left in place; they
        still separate the remaining keys correctly. After recursing, the
        affected child is repaired. Underflow of ``self`` is left to the caller.
        """
        log(f'remove_key_leaf self.keys {self.keys}')
        if self.leaf:
            try:
                index = self.keys.index(key)
            except ValueError:
                log(f'remove_key_leaf not find key: {key}')
                return
            self.keys.pop(index)
            self.children.pop(index)
            return

        if key in self.keys:
            # key is a separator here: it lives in the subtree on its right.
            i = self.keys.index(key) + 1
        else:
            i = self._child_index(key)
            log(f'remove_key_leaf search child {self.keys}, i: {i}')
        self.children[i].remove_key_leaf(key)
        self.repair_child(i)

    def remove_key_both(self, key):
        """Delete ``key`` from this subtree, also removing it where it is a separator.

        Cases:
        1. ``key`` is not a separator here: descend into the routing child, then
           repair that child.
        2. ``key`` is the separator ``self.keys[index]``:
           2.1 it has a successor: replace the separator with the successor,
               delete ``key`` from the right subtree, then repair that child.
           2.2 it has no successor (it is the largest key) and the right child is
               a leaf holding only ``key``: drop the separator and that leaf, and
               unlink the leaf from the leaf list.
        Underflow of ``self`` is left to the caller.
        """
        log(f'remove_key_both self.keys {self.keys}')
        if self.leaf:
            try:
                index = self.keys.index(key)
            except ValueError:
                log(f'remove_key_both not find key: {key}')
                return
            self.keys.pop(index)
            self.children.pop(index)
            return

        try:
            index = self.keys.index(key)
        except ValueError:
            # Case 1: key is not a separator in this node.
            i = self._child_index(key)
            log(f'remove_key_both search child {self.keys}, i: {i}')
            self.children[i].remove_key_both(key)
            self.repair_child(i)
            return

        child = self.children[index + 1]
        key_successor = child.search_child_successor(key)
        if key_successor is not None:
            # Case 2.1: the successor takes over the separator slot.
            self.keys[index] = key_successor
            child.remove_key_both(key)
            self.repair_child(index + 1)
        elif child.leaf and child.keys == [key]:
            # Case 2.2: drop the separator and the now-empty rightmost leaf, and
            # keep the leaf list consistent so nothing links to the dropped leaf.
            self.keys.pop(index)
            removed = self.children.pop(index + 1)
            self._unlink_leaf(removed)
        else:
            # No successor, but the right child is not a lone leaf holding key
            # (e.g. the separator is stale because key was already removed from
            # the leaves). Dropping the child would discard other keys, so fall
            # back to an ordinary descent, which logs if key is missing.
            child.remove_key_both(key)
            self.repair_child(index + 1)

    @staticmethod
    def _unlink_leaf(leaf):
        """Remove ``leaf`` from the doubly linked list of leaves."""
        if leaf.prev_leaf is not None:
            leaf.prev_leaf.next_leaf = leaf.next_leaf
        if leaf.next_leaf is not None:
            leaf.next_leaf.prev_leaf = leaf.prev_leaf
        leaf.prev_leaf = None
        leaf.next_leaf = None

    def search_child_successor(self, key):
        """Return the key that follows ``key`` in leaf order, or None.

        Finds the leaf holding ``key`` and returns the next key in that leaf or,
        when ``key`` is the leaf's last key, the first key of the next leaf.
        Returns None when ``key`` is the last key, or (after logging) when it is
        not stored at all.
        """
        if self.leaf:
            try:
                index = self.keys.index(key)
            except ValueError:
                log(f"search_child_successor leaf not find key {key}")
                return None

            if index + 1 < len(self.keys):
                return self.keys[index + 1]
            # The successor is in the next leaf.
            if self.next_leaf is not None:
                return self.next_leaf.keys[0]
            return None
        return self.children[self._child_index(key)].search_child_successor(key)

    def repair_child(self, child_index):
        """Restore the minimum key count of ``self.children[child_index]``.

        Borrow from the left sibling if it can spare a key, otherwise from the
        right sibling, otherwise merge with a neighbour. Merging removes one
        separator from ``self``, which may leave ``self`` under-full for its own
        parent to repair.
        """
        child = self.children[child_index]
        if child.enough():
            return
        has_right = child_index < len(self.children) - 1
        if child_index > 0 and self.children[child_index - 1].can_borrow():
            self.rotate_right(child_index)
        elif has_right and self.children[child_index + 1].can_borrow():
            self.rotate_left(child_index)
        elif has_right:
            self.merge_right_child(child_index)
        else:
            self.merge_right_child(child_index - 1)

    def rotate_left(self, child_index):
        """Move the smallest entry of the right sibling into ``self.children[child_index]``."""
        child = self.children[child_index]
        child_right = self.children[child_index + 1]
        if child.leaf:
            # Leaves move key and value; the new separator is the sibling's new first key.
            child.keys.append(child_right.keys.pop(0))
            child.children.append(child_right.children.pop(0))
            self.keys[child_index] = child_right.keys[0]
        else:
            # Internal: the separator moves down and the sibling's first key moves up.
            child.keys.append(self.keys[child_index])
            child.children.append(child_right.children.pop(0))
            self.keys[child_index] = child_right.keys.pop(0)

    def rotate_right(self, child_index):
        """Move the largest entry of the left sibling into ``self.children[child_index]``."""
        child = self.children[child_index]
        child_left = self.children[child_index - 1]
        if child.leaf:
            # Leaves move key and value; the borrowed key becomes the new separator.
            child.keys.insert(0, child_left.keys.pop(-1))
            child.children.insert(0, child_left.children.pop(-1))
            self.keys[child_index - 1] = child.keys[0]
        else:
            # Internal: the separator moves down and the sibling's last key moves up.
            child.keys.insert(0, self.keys[child_index - 1])
            child.children.insert(0, child_left.children.pop(-1))
            self.keys[child_index - 1] = child_left.keys.pop(-1)

    def merge_right_child(self, child_index):
        """Merge ``self.children[child_index + 1]`` into ``self.children[child_index]``.

        The separator between them is removed from ``self``. For internal children
        it is pulled down into the merged node. For leaves it is dropped, and the
        right leaf is unlinked from the leaf list.
        """
        log(f'merge_right_child {child_index}')
        child = self.children[child_index]
        child_right = self.children.pop(child_index + 1)
        if child.leaf:
            self.keys.pop(child_index)
            child.next_leaf = child_right.next_leaf
            if child.next_leaf is not None:
                child.next_leaf.prev_leaf = child
        else:
            child.keys.append(self.keys.pop(child_index))
        child.keys.extend(child_right.keys)
        child.children.extend(child_right.children)

    def full(self):
        """Return True when the node holds the maximum ``2 * degree - 1`` keys."""
        log('full', self.degree, len(self.keys))
        return len(self.keys) == 2 * self.degree - 1

    def enough(self):
        """Return True when the node holds at least the minimum ``degree - 1`` keys."""
        return len(self.keys) >= self.degree - 1

    def can_borrow(self):
        """Return True when the node can give away a key and still have ``degree - 1``."""
        return len(self.keys) >= self.degree

    def show(self, count):
        """Print this subtree, indenting each level by ``count`` ``'---- '`` markers."""
        indent = '---- ' * count
        keys = ','.join([str(k) for k in self.keys])
        log(f'{indent}key:{keys}')
        if not self.leaf:
            for v in self.children:
                v.show(count + 1)
        else:
            # str() so that non-string values can be displayed as well.
            values = ','.join(str(v) for v in self.children)
            log(f'{indent}values:{values}')

    def check(self, is_root, low, high):
        """Validate this subtree, raising ValueError on the first violation found.

        :param is_root: True for the tree root, which is exempt from the degree check
        :param low: inclusive lower bound for keys in this subtree
        :param high: exclusive upper bound for keys in this subtree
        """
        self.check_order(low, high)

        if self.leaf:
            if is_root:
                return
            self.check_degree()
        else:
            bounds = [low, *self.keys, high]
            for i, child in enumerate(self.children):
                child.check(False, bounds[i], bounds[i + 1])
            if not is_root:
                self.check_degree()

    def check_order(self, low, high):
        """Raise ValueError unless keys are strictly increasing and lie in ``[low, high)``."""
        if len(self.keys) == 0:
            return
        if self.keys[0] < low or self.keys[-1] >= high:
            log(f'check_order error, keys: {self.keys}')
            raise ValueError('check_order')
        for current, following in zip(self.keys, self.keys[1:]):
            if current >= following:
                log(f'check_order error, keys: {self.keys}')
                raise ValueError('check_order')

    def check_degree(self):
        """Raise ValueError if the key count or the key/child count relation is violated."""
        n = len(self.keys)
        if n < self.degree - 1 or n > 2 * self.degree - 1:
            log(f'check_degree find error, keys: {self.keys}')
            raise ValueError("check_degree")
        if self.leaf:
            # Leaves store exactly one value per key.
            if len(self.keys) != len(self.children):
                log(f'check_degree leaf not equal, keys: {self.keys}, values: {self.children}')
                raise ValueError("check_degree")
        else:
            # Internal nodes have one more child than keys.
            if len(self.keys) + 1 != len(self.children):
                log(f'check_degree not equal, keys: {self.keys}, values: {self.children}')
                raise ValueError("check_degree")

    def __repr__(self):
        return repr(self.keys)


# noinspection DuplicatedCode
class BPlusTree:
    """B+ tree facade that owns the root node.

    Node-level work is delegated to ``Node``. This class handles the two places
    where the root itself changes: splitting a full root before an insert, and
    replacing a key-less internal root after a delete.
    """

    def __init__(self, degree):
        # Minimum degree passed to every node of the tree.
        self.degree = degree
        # An empty tree is a single empty leaf.
        self.root = Node(degree, True)

    def split_root(self):
        """Grow the tree by one level by splitting the full root under a new parent.

        The new root starts with no keys and the old root as its only child.
        ``split_child(0)`` then gives it one separator and two children.
        """
        old_root = self.root
        new_root = Node(self.degree, False)
        new_root.children = [old_root]
        self.root = new_root
        log('split_root', new_root.children)
        new_root.split_child(0)

    def insert(self, key, value):
        """Insert ``key``/``value``, splitting a full root first so the descent can proceed."""
        log('insert', key, value)
        root_is_full = self.root.full()
        if root_is_full:
            self.split_root()
        self.root.insert_non_full(key, value)

    def delete(self, key):
        """Delete ``key`` (leaf-only strategy), then shrink the height if the root emptied."""
        log(f'tree delete key {key}')
        root = self.root
        root.remove_key_leaf(key)
        # A merge directly under the root can leave it with no keys and one child.
        if not root.leaf and not root.keys:
            self.root = root.children[0]

    def search(self, key):
        """Return the value stored under ``key``, or None when absent."""
        root = self.root
        return root.search_child(key)

    def update(self, key, value):
        """Replace the value stored under ``key``; the node layer logs a missing key."""
        log('update', key, value)
        root = self.root
        root.update(key, value)

    def show(self):
        """Print the whole tree, level by level, preceded by a header with the degree."""
        log(f'+++++++++ degree: {self.degree} +++++++++ ')
        root = self.root
        return root.show(0)

    def check(self):
        """Validate every structural invariant of the tree; raises ValueError on failure."""
        lower, upper = float('-inf'), float('inf')
        self.root.check(True, lower, upper)


def test_tree_case(degree, insert_count, delete_count, seed=None):
    """Insert keys 1..insert_count, then delete a random subset, validating as it goes.

    After every insertion and deletion the tree is printed and checked with
    ``BPlusTree.check``. After all deletions, every key that was not deleted
    must still map to the value it was inserted with.

    :param degree: minimum degree of the tree under test
    :param insert_count: number of keys to insert (keys are 1..insert_count)
    :param delete_count: how many of those keys to delete, chosen at random
    :param seed: optional seed for a private RNG. When None, the module-level
        ``random`` functions are used, so ``random.seed`` still controls them.
    :raises AssertionError: if a search returns the wrong result
    :raises ValueError: if ``delete_count`` exceeds ``insert_count`` (from
        ``random.sample``) or the tree fails ``check``
    """
    log(f'test_tree_case {degree}, {insert_count}, {delete_count}')
    tree = BPlusTree(degree)
    # range end is inclusive of insert_count so exactly insert_count keys are used.
    case = list(range(1, insert_count + 1))
    value_length = len(str(insert_count))
    expected = {key: str(key).rjust(value_length, "#") for key in case}
    for key in case:
        assert tree.search(key) is None
    for key in case:
        tree.insert(key, expected[key])
        log(f"tree show {key}")
        tree.show()
        tree.check()
    for key in case:
        assert tree.search(key) == expected[key]
    sampler = random if seed is None else random.Random(seed)
    delete_case = sampler.sample(case, delete_count)
    log(f'delete case {delete_case}')
    for key in delete_case:
        tree.delete(key)
        tree.show()
        tree.check()
        assert tree.search(key) is None
    # Deleting must not lose or corrupt any key that was not deleted.
    deleted = set(delete_case)
    for key in case:
        if key not in deleted:
            assert tree.search(key) == expected[key], f'lost key {key}'


def test_tree():
    """Run test_tree_case for degrees 2-9 and insert_count 10, 20, ..., 490.

    Each case deletes half (rounded down) of insert_count.
    """
    log('test_tree')
    degrees = range(2, 10)
    sizes = range(10, 500, 10)
    for degree in degrees:
        for size in sizes:
            test_tree_case(degree, size, size // 2)


def main():
    """Script entry point: run the full randomized self-test."""
    run_all = test_tree
    run_all()


if '__main__' == __name__:
    # Run the self-test only when executed as a script, not on import.
    main()

参考

B+ 树可视化(注意:其实现细节与我的代码不同)

你可能感兴趣的:(b+树删除操作)