图形化输出二叉树(Python实现)

目录

  • 目录
  • 1. 功能介绍
  • 2. 参数说明
  • 3. 应用实例
  • 4. 原理简述
  • 5. 源码

1. 功能介绍

以图形化的形式输出二叉树,可以清晰地呈现出二叉树的内部结构。原创的PrintBT模块提供了相应的接口,可以轻松实现上述效果。
假设已经建立了一棵二叉树:

T=BinaryTree()
T.create('ABD###CEG###F##')

通过调用PrintBT模块:

from PrintBT import PrintBT
PrintBT(T)

可以输出:

    A
    |
  -----
  |   |
  B   C
  |   |
--- -----
|   |   |
D   E   F
    |
  ---
  |
  G

2. 参数说明

目前,PrintBT函数共有4个参数:

def __init__(self,T,margin=0,padding=4,data_field='data'):
    pass
  • T:已建立的一棵二叉树。节点的左右指针分别为lcrc,至少存在一个数据域。注:如果左右指针不是lcrc,程序会出错。为了增加适用性,可以参照data_field 的实现方式,增加参数,提供更多配置的选择。
  • margin:输出树时,整棵树与屏幕左侧的距离。默认为0,即紧靠左侧。
  • padding:控制输出时左右子树之间的间距,最好大于1。默认为4。
  • data_field:控制数据域的名称。默认为data。如果数据域不是data,或者需要输出其他的数据域,则可以调整这个参数。

3. 应用实例

下面以AVL树为例,动态展现AVL树在构建过程中节点以及相应平衡因子的变化情况。节点的插入顺序为:25,27,30,12,11,18,14,20,15,22

from PrintBT import PrintBT

class TreeNode():
    def __init__(self,data,parent=None):
        self.data=data
        self.parent=parent
        self.lc=None
        self.rc=None
        self.bf=0 #平衡因子

class BinaryTree():
    def __init__(self,L):
        self.root=None
        if not L:
            return
        cnt=0
        for x in L:
            self.insert(x) #将元素逐个插入到AVL中,以此建立AVL
            cnt+=1
            #输出AVL树和相应的BF
            print('-'*6+"Step:"+str(cnt)+": Insert "+str(x)+'-'*6+'\n')
            PrintBT(self,5)
            print('\n')
            PrintBT(self,5,data_field='bf')
            print('\n',end='')

    def insert(self,x):
        if self.root is None:
            self.root=TreeNode(x) #单独处理根节点
        else:
            p=self.root #当前节点
            pr=None #当前节点的父节点
            tag=None #标记当前节点是其父节点的左孩子还是右孩子
            while p:
                if p.data==x:
                    return False #如果AVL中存在该数据,则不插入
                pr=p
                if p.data>x:
                    tag='l'
                    p=p.lc
                else:
                    tag='r'
                    p=p.rc
            p=TreeNode(x,pr)
            if tag=='l':  #插入新的节点
                pr.lc=p
            else:
                pr.rc=p
            ubn,mode=self.__modify_bf(pr,tag) #ubn:unbalanced node,自下而上第一个不平衡的点
            if ubn:
                # print(ubn.data,mode[-2:][::-1])
                self.__reBalanced(ubn,mode[-2:][::-1])
        return True

    def __modify_bf(self,q,tag):
        ubn=None #第一个不平衡的节点
        mode=''
        while q:
            if tag=='l':
                q.bf+=1
            else:
                q.bf-=1
            if q.bf==0:
                break
            elif q.bf==1 or q.bf==-1:
                    mode+=tag
                    if q.parent is None:
                        break #如果是根节点则直接退出
                    if q.parent.lc==q:
                        tag='l'
                    else:
                        tag='r'
                    q=q.parent
            else:
                ubn=q
                mode+=tag
                break
        return ubn,mode

    def __reBalanced(self,p,mode):
        if mode=='ll':
            self.__R_rotate(p.parent,p,p.lc)
            p.bf=p.parent.bf=0  #修改平衡因子
        elif mode=='rr':
            self.__L_rotate(p.parent,p,p.rc)
            p.bf=p.parent.bf=0
        elif mode=='lr':
            self.__L_rotate(p,p.lc,p.lc.rc)
            self.__R_rotate(p.parent,p,p.lc)
            if p.parent.bf==1: #说明新插入的节点插在上一步AVL树叶节点的左子树
                p.parent.lc.bf=0
                p.bf=-1
            else:
                p.parent.lc.bf=1
                p.bf=0
            p.parent.bf=0
        else:
            self.__R_rotate(p,p.rc,p.rc.lc)
            self.__L_rotate(p.parent,p,p.rc)
            if p.parent.bf==1:
                p.bf=0
                p.parent.rc.bf=-1
            else:
                p.bf=1
                p.parent.rc.bf=0
            p.parent.bf=0

    def __R_rotate(self,x,p,c):
        '''p是当前节点,x是p的父节点,c是p的孩子节点'''
        p.lc=c.rc
        if c.rc:
            c.rc.parent=p #修改父节点
        c.rc=p
        p.parent=c #修改父指针的指向
        if x:
            if x.lc==p:
                x.lc=c
            else:
                x.rc=c
            c.parent=x
        else:
            self.root=c
            c.parent=None

    def __L_rotate(self,x,p,c):
        p.rc=c.lc
        if c.lc:
            c.lc.parent=p
        c.lc=p
        p.parent=c
        if x:
            if x.lc==p:
                x.lc=c
            else:
                x.rc=c
            c.parent=x
        else:
            self.root=c
            c.parent=None

if __name__=='__main__':
    L=[25,27,30,12,11,18,14,20,15,22]
    T=BinaryTree(L)

输出结果为:

------Step:1: Insert 25------

    25


     0

------Step:2: Insert 27------

    25
     |
     ---
       |
      27


    -1
     |
     ---
       |
       0

------Step:3: Insert 30------

      27
       |
     -----
     |   |
    25  30


       0
       |
     -----
     |   |
     0   0

------Step:4: Insert 12------

        27
         |
       -----
       |   |
      25  30
       |
     ---
     |
    12


         1
         |
       -----
       |   |
       1   0
       |
     ---
     |
     0

------Step:5: Insert 11------

        27
         |
       -----
       |   |
      12  30
       |
     -----
     |   |
    11  25


         1
         |
       -----
       |   |
       0   0
       |
     -----
     |   |
     0   0

------Step:6: Insert 18------

        25
         |
       -----
       |   |
      12  27
       |   |
     ----- ---
     |   |   |
    11  18  30


         0
         |
       -----
       |   |
       0  -1
       |   |
     ----- ---
     |   |   |
     0   0   0

------Step:7: Insert 14------

        25
         |
       -----
       |   |
      12  27
       |   |
     ----- ---
     |   |   |
    11  18  30
         |
       ---
       |
      14


         1
         |
       -----
       |   |
      -1  -1
       |   |
     ----- ---
     |   |   |
     0   1   0
         |
       ---
       |
       0

------Step:8: Insert 20------

        25
         |
       -----
       |   |
      12  27
       |   |
     ----- ---
     |   |   |
    11  18  30
         |
       -----
       |   |
      14  20


         1
         |
       -----
       |   |
      -1  -1
       |   |
     ----- ---
     |   |   |
     0   0   0
         |
       -----
       |   |
       0   0

------Step:9: Insert 15------

          25
           |
         -----
         |   |
        14  27
         |   |
       ----- ---
       |   |   |
      12  18  30
       |   |
     --- -----
     |   |   |
    11  15  20


           1
           |
         -----
         |   |
         0  -1
         |   |
       ----- ---
       |   |   |
       1   0   0
       |   |
     --- -----
     |   |   |
     0   0   0

------Step:10: Insert 22------

            18
             |
         ---------
         |       |
        14      25
         |       |
       -----   -----
       |   |   |   |
      12  15  20  27
       |       |   |
     ---       --- ---
     |           |   |
    11          22  30


             0
             |
         ---------
         |       |
         1       0
         |       |
       -----   -----
       |   |   |   |
       1   0  -1  -1
       |       |   |
     ---       --- ---
     |           |   |
     0           0   0

4. 原理简述

PrintBT模块中的PrintBT是一个类,其中主要有4个方法。
__init_tree

本质上,PrintBT也是一棵二叉树,只不过比一般的二叉树多了parent 域和pos 域。该函数将作为参数传入的二叉树T复制给PrintBT二叉树,同时计算节点的相对位置pos。根节点的pos初始化为0,此后,假设父节点的pos为p,则其左孩子的pos为p-2,右孩子的pos为p+2,以此类推。

__pos_adjust

然而,上一步计算的pos会存在冲突。例如:

      A
      |
  ---------
  |       |
  B       E
  |       |
-----   -----
|   |   |   |
C   D   F   G

节点D和F的pos都是0,显然两者不能占据同一个位置,因此要增加F的pos值,这就与参数padding有关。仅仅增加F的pos是不够的,必须把以E为根节点的子树同步右移。__pos_adjust 可以实现上述pos值的调整,从而消除冲突。

__adjust_parent

在上述调整的过程中,可能右子树向右移动了一定位置,而父节点还停留在原来的位置。为了更加美观,__adjust_parent 可以将父节点移到左右子树中间的位置。

__find_min_pos

__find_min_pos 可以找到各节点最小的pos值。其他节点的pos减去这个最小的pos值就可以得出节点距最左边一个节点的相对位置。

__printTree

最后依据各节点的pos值计算每一行空格和"-"的个数,输出整棵二叉树。

5. 源码

#文件命名为:PrintBT.py 
from collections import deque

class TreeNode():
    def __init__(self,data):
        self.data=data
        self.lc=None
        self.rc=None
        self.parent=None
        self.pos=0 #节点的相对位置

class PrintBT():
    '''
    #引入方式:
    from PrintBT import PrintBT
    PrintBT(T)

    #参数含义:
    margin:图形整体距屏幕左边的距离;
    padding:两株子树之间的距离,最好大于1
    data_field:数据域的名称,默认从"data"中获取数据。
                如果希望获取其他域的数据,如AVL的bf域,则可传入参数。
    '''
    def __init__(self,T,margin=0,padding=4,data_field='data'):
        self.root=None
        padding=2 if padding<2 else padding
        if T.root is None:  #T中必须保证根节点命名为root,且允许外部访问。
            return          #否则可以设置getRoot()函数
        max_data_length=self.__init_tree(T,data_field) #复制T中的数据,同时增加parent域和pos域
        max_data_length=max_data_length//2
        margin=max_data_length if marginelse margin
        self.__pos_adjust(padding) #修正节点的位置
        self.__adjust_parent() #修正父节点的位置
        p=self.__find_min_pos() #找最左边的节点
        self.__printTree(p-margin) #以图形方式打印节点

    def __init_tree(self,T,data_field):
        #单独处理根节点
        self.root=TreeNode(T.root.__dict__[data_field])
        Q=deque() #被复制的T的根节点
        Q.append(T.root)
        max_data_length=1 #数据域中最长的字符长度
        P=deque() #当前树的根节点
        P.append(self.root)
        while Q:
            x=Q.popleft()
            y=P.popleft()
            if len(str(x.__dict__[data_field]))>max_data_length:
                max_data_length=len(str(x.__dict__[data_field]))
            if x.lc:
                Q.append(x.lc)
                y.lc=TreeNode(x.lc.__dict__[data_field])
                y.lc.parent=y
                y.lc.pos=y.pos-2
                P.append(y.lc)
            if x.rc:
                Q.append(x.rc)
                y.rc=TreeNode(x.rc.__dict__[data_field])
                y.rc.parent=y
                y.rc.pos=y.pos+2
                P.append(y.rc)
        return max_data_length

    def __find_parent(self,node):  #沿着右斜上方的方向找父节点
        N=node
        while N.parent and N.parent.lc==N:
            N=N.parent
        return N

    def __move_tree(self,node,num):
        if not node:
            return
        Q=deque()
        Q.append(node)
        while Q:
            N=Q.popleft()
            N.pos+=num
            if N.lc:
                Q.append(N.lc)
            if N.rc:
                Q.append(N.rc)

    def __pos_adjust(self,padding):
        L=[]
        if self.root.lc:
            L.append(self.root.lc)
        if self.root.rc:
            L.append(self.root.rc)
        while L:
            temp=[]
            i=1
            if len(L)>1:
                while i1].data))
                    offset=L[i-1].pos+(data_len-data_len//2)
                    if paddingif L[i].pos<=offset:
                        parentNode=self.__find_parent(L[i])  #找到父节点
                        self.__move_tree(parentNode,L[i-1].pos-L[i].pos+padding) #取大于1的值即可
                    i+=1
            for node in L:
                if node.lc:
                    temp.append(node.lc)
                if node.rc:
                    temp.append(node.rc)
            L=temp

    def __find_min_pos(self):
        min_pos=0
        Q=deque()
        Q.append(self.root)
        while Q:
            node=Q.popleft()
            if node.posif node.lc:
                Q.append(node.lc)
            if node.rc:
                Q.append(node.rc)
        return min_pos

    def __printTree(self,p):
        L=[self.root]
        line=''
        while L:
            temp=[]
            cur=p
            if L[0]!=self.root:  #排除根节点
                for node in L:
                    line+=' '*(node.pos-cur)
                    line+='|'
                    cur=node.pos+1
                line+='\n'
                cur=p
            offset_r=0 #上一个节点右边的偏移量     
            for node in L:
                offset_l=len(str(node.data))//2 #左边的偏移量     
                line+=' '*(node.pos-cur-offset_l-offset_r)
                line+=str(node.data)
                offset_r=len(str(node.data))-offset_l-1 #本节点右边的偏移量
                cur=node.pos+1
            line+='\n'
            cur=p
            for node in L:
                if node.lc or node.rc:
                    line+=' '*(node.pos-cur)  
                    line+='|'
                    cur=node.pos+1
            line+='\n'
            cur=p
            for node in L:
                if node.lc and node.rc:
                    line+=' '*(node.lc.pos-cur)
                    line+='-'*(node.rc.pos-node.lc.pos+1)
                    cur=node.rc.pos+1
                    temp.append(node.lc)
                    temp.append(node.rc)
                elif node.lc and not node.rc:
                    line+=' '*(node.lc.pos-cur)
                    line+='-'*(node.pos-node.lc.pos+1)
                    cur=node.pos+1
                    temp.append(node.lc)
                elif node.rc and not node.lc:
                    line+=' '*(node.pos-cur)
                    line+='-'*(node.rc.pos-node.pos+1)
                    cur=node.rc.pos+1
                    temp.append(node.rc)            
            L=temp
            line+='\n'
        print(line.rstrip())

    def __adjust_parent(self):
        L=[self.root]
        P=deque() #栈
        P.append(L)
        while L:
            temp=[]
            for node in L:
                if node.lc:
                    temp.append(node.lc)
                if node.rc:
                    temp.append(node.rc)
            L=temp
            P.append(L)  #把节点一层一层的存入栈中
        while P:
            L=P.pop()
            i=0
            while i1:
                if L[i].parent==L[i+1].parent:
                    L[i].parent.pos=(L[i].pos+L[i+1].pos)//2
                    i+=2
                else:
                    i+=1

你可能感兴趣的:(Python,数据结构)