try:
    from collections.abc import Iterable
except ImportError:  # Python 2 has no collections.abc
    from collections import Iterable

import matplotlib.pyplot as plt
import networkx as nx
class Node(object):
    """A single binary-tree node.

    :param elem: value stored in the node (defaults to -1).
    :param left: left child, or None.
    :param right: right child, or None.
    """

    def __init__(self, elem=-1, left=None, right=None):
        self.elem, self.left, self.right = elem, left, right
class Tree(object):
    """Unbalanced binary search tree built by inserting values in order.

    A value greater than a node's ``elem`` goes to its right subtree; a value
    less than or equal to it (including duplicates) goes to its left subtree.
    No rebalancing is done, so the traversal/height methods, which recurse
    once per level, can hit the recursion limit on very deep (e.g.
    sorted-input) trees.
    """

    def __init__(self, seq=()):
        """Create a tree and insert every element of *seq*, in order.

        :param seq: iterable of mutually comparable values (default: empty).
        :raises TypeError: if *seq* is not iterable.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not isinstance(seq, Iterable):
            raise TypeError(
                "seq must be iterable, got %s" % type(seq).__name__)
        self.root = None
        self.add(*seq)

    def add(self, *args):
        """Insert each positional argument into the tree, left to right.

        If the tree is empty, the first value becomes the root. Each remaining
        value is placed by walking down from the root iteratively.

        :param args: values to insert; calling with no arguments is a no-op.
        """
        if not args:
            return
        if self.root is None:
            self.root = Node(args[0])
            args = args[1:]
        for value in args:
            node = Node(value)
            current = self.root
            while True:
                if value > current.elem:
                    if current.right is None:
                        current.right = node
                        break
                    current = current.right
                else:
                    # Values <= current.elem (duplicates too) go left.
                    if current.left is None:
                        current.left = node
                        break
                    current = current.left

    def front_digui(self, root):
        """Print the subtree rooted at *root* in pre-order (recursive).

        Pre-order: the root first, then the left subtree, then the right
        subtree.

        :param root: subtree root node, or None for an empty subtree.
        :return: None; each ``elem`` is printed on its own line.
        """
        if root is None:
            return
        print(root.elem)
        self.front_digui(root.left)
        self.front_digui(root.right)

    def middle_digui(self, root):
        """Print the subtree rooted at *root* in in-order (recursive).

        In-order: the left subtree first, then the root, then the right
        subtree. For this BST that yields the values in non-decreasing order.

        :param root: subtree root node, or None for an empty subtree.
        :return: None; each ``elem`` is printed on its own line.
        """
        if root is None:
            return
        self.middle_digui(root.left)
        print(root.elem)
        self.middle_digui(root.right)

    def later_digui(self, root):
        """Print the subtree rooted at *root* in post-order (recursive).

        Post-order: the left subtree first, then the right subtree, then the
        root.

        :param root: subtree root node, or None for an empty subtree.
        :return: None; each ``elem`` is printed on its own line.
        """
        if root is None:
            return
        self.later_digui(root.left)
        self.later_digui(root.right)
        print(root.elem)

    def height(self, root):
        """Return the number of nodes on the longest path down from *root*.

        :param root: subtree root node, or None.
        :return: 0 for an empty subtree, 1 for a single node, and so on.
        """
        if root is None:
            return 0
        return 1 + max(self.height(root.left), self.height(root.right))
# Tree visualization
def create_graph(G, node, pos=None, x=0, y=0, layer=1):
    """Add the subtree rooted at *node* to graph *G* and compute a layout.

    *node* is placed at ``(x, y)``. Each child is placed one row lower and
    shifted horizontally by ``1 / 2**layer``, so sibling subtrees get
    narrower with depth.

    :param G: graph providing ``add_node`` and ``add_edge`` (e.g.
        ``networkx.DiGraph``); mutated in place.
    :param node: subtree root; must not be None.
    :param pos: dict to fill with ``{elem: (x, y)}``; a fresh dict is
        created when omitted.
    :param x: horizontal coordinate of *node*.
    :param y: vertical coordinate of *node*.
    :param layer: depth of *node* (1 for the root); controls child spacing.
    :return: ``(G, pos)``.

    NOTE: graph vertices and ``pos`` keys are the ``elem`` values, so nodes
    with equal ``elem`` collapse into a single vertex.
    """
    # A shared ``{}`` default would leak positions between calls.
    if pos is None:
        pos = {}
    # Register the vertex explicitly: a root with no children has no edges
    # and would otherwise be missing from G.
    G.add_node(node.elem)
    pos[node.elem] = (x, y)
    # Float division keeps the spacing non-zero even under Python 2.
    dx = 1.0 / 2 ** layer
    if node.left:
        G.add_edge(node.elem, node.left.elem)
        create_graph(G, node.left, pos=pos, x=x - dx, y=y - 1,
                     layer=layer + 1)
    if node.right:
        G.add_edge(node.elem, node.right.elem)
        create_graph(G, node.right, pos=pos, x=x + dx, y=y - 1,
                     layer=layer + 1)
    return (G, pos)
def draw(node, figsize=(8, 10), node_size=300):
    """Draw the subtree rooted at *node* with networkx and show it.

    :param node: subtree root; when None (empty tree), nothing is drawn.
    :param figsize: matplotlib figure size in inches; tune it to the tree's
        depth and width.
    :param node_size: marker size passed to ``nx.draw_networkx``.
    """
    if node is None:
        return
    graph = nx.DiGraph()
    # A root with no children creates no edges, so add it explicitly.
    graph.add_node(node.elem)
    # Pass a fresh dict explicitly so no layout state is shared between
    # draw() calls.
    graph, pos = create_graph(graph, node, pos={})
    _, ax = plt.subplots(figsize=figsize)
    nx.draw_networkx(graph, pos, ax=ax, node_size=node_size)
    plt.show()
if __name__ == '__main__':
    values = [30, 40, 50, 60, 80, 90, 100, 32, 65, 42, 12, 20]
    tree = Tree(values)
    print(tree)
    print(tree.root)
    # Pre-order traversal.
    tree.front_digui(tree.root)
    print("*" * 200)
    # In-order traversal (non-decreasing order for this BST).
    tree.middle_digui(tree.root)
    print("*" * 200)
    print(tree.height(tree.root))
    print("*" * 200)
    # Post-order traversal.
    tree.later_digui(tree.root)
    draw(tree.root)