Python版B树

话说以前的树都用java写的,最近发现python有点生疏了,于是用python写了个B树实现,B树在索引领域用得还是蛮多了,如果没记错mysql的默认索引好像就是B树...

首先是数据实体对象,很简单,只存放key,value
class Entity(object):
    '''数据实体'''
    
    def __init__(self,key,value):
        self.key = key
        self.value = value



然后节点对象
class Node(object):
    '''B树的节点'''
    
    def __init__(self):
        self.parent = None
        self.entitys = []
        self.childs = []
        
    def find(self,key):
        '''通过key查找并返回一个数据实体'''
        
        for e in self.entitys:
            if key == e.key:
                return e
            
    
    def delete(self,key):
        '''通过key删除一个数据实体,并返回它和它的下标(下标,实体)'''
        for i,e in enumerate(self.entitys):
            if e.key == key:
                del self.entitys[i]
                return (i,e)
            
            
    def isLeaf(self):
        '''判断该节点是否是一个叶子节点'''
        
        return len(self.childs) == 0
    
    
    def addEntity(self,entity):
        '''添加一个数据实体'''
        
        self.entitys.append(entity)
        self.entitys.sort(key=lambda x:x.key)
        
        
    def addChild(self,node):
        '''添加一个子节点'''

        self.childs.append(node)
        node.parent = self
        self.childs.sort(key=lambda x:x.entitys[0].key)



最后是Tree类
class Tree(object):
    '''B树'''
    
    def __init__(self,size=6):
        self.size = size
        self.root = None
        self.length = 0
        
    
    def add(self,key,value=None):
        '''插入一条数据到B树'''
        
        self.length += 1
        
        if self.root:
            current = self.root
            
            while not current.isLeaf():
                for i,e in enumerate(current.entitys):
                    if e.key > key:
                        current = current.childs[i]
                        break
                    elif e.key == key:
                        e.value = value
                        self.length -= 1
                        return
                else:
                    current = current.childs[-1]
                    
            current.addEntity(Entity(key,value))
                
            if len(current.entitys) > self.size:
                self.__spilt(current)
        else:
            self.root = Node()
            self.root.addEntity(Entity(key,value))
    
    
    def get(self,key):
        '''通过key查询一个数据'''
        
        node = self.__findNode(key)
        
        if node:  
            return node.find(key).value
            
            
    def delete(self,key):
        '''通过key删除一个数据项并返回它'''
        
        node = self.__findNode(key)
        
        if node:
            i,e = node.delete(key)
            
            #在节点不是叶子节点时需要做修复(取对应下标的子节点的最大的一个数据项来补)
            if not node.isLeaf():
                child = node.childs[i]
                j,entity = child.delete(child.entitys[-1].key)
                node.addEntity(entity)
                
                while not child.isLeaf():
                    node = child
                    child = child.childs[j]
                    j,entity = child.delete(child.entitys[-1].key)
                    node.addEntity(entity)
            
            self.length -= 1
            return e.value
    
    
    def isEmpty(self):
        return self.length == 0
    
                
    def __findNode(self, key):
        '''通过key值查询一个数据在哪个节点,找到就返回该节点'''
        
        if self.root:
            current = self.root
            
            while not current.isLeaf():
                for i, e in enumerate(current.entitys):
                    if e.key > key:
                        current = current.childs[i]
                        break
                    elif e.key == key:
                        return current
                else:
                    current = current.childs[-1]
            
            if current.find(key):
                return current
            
            
    def __spilt(self,node):
        '''
        分裂一个节点,规则为:
        1、中间的数据项移到父节点
        2、新建一个右兄弟节点,将中间节点右边的数据项移到新节点
        '''
        
        middle = len(node.entitys) / 2
        
        top = node.entitys[middle]
        
        right = Node()
        
        for e in node.entitys[middle + 1:]:
            right.addEntity(e)
            
        for n in node.childs[middle + 1:]:
            right.addChild(n)
        
        node.entitys = node.entitys[:middle]
        node.childs = node.childs[:middle + 1]
        
        parent = node.parent
        
        if parent:
            parent.addEntity(top)
            parent.addChild(right)
            
            if len(parent.entitys) > self.size:
                self.__spilt(parent)
        else:
            self.root = Node()
            self.root.addEntity(top)
            self.root.addChild(node)
            self.root.addChild(right)



测试代码
if __name__ == '__main__':
    t = Tree(4)
    t.add(20)
    t.add(40)
    t.add(60)
    t.add(70,'c') 
    t.add(80)       
    t.add(10) 
    t.add(30)
    t.add(15,'python')
    t.add(75,'java')
    t.add(85)
    t.add(90)
    t.add(25)
    t.add(35,'c#')
    t.add(50)
    t.add(22,'c++')
    t.add(27)
    t.add(32)
    
    print t.get(15)
    print t.get(75)
    print t.delete(35)
    print t.delete(22)
    t.add(22,'lua')
    print t.get(22)
    print t.length    

你可能感兴趣的:(python)