python 可视化节点关系(一):networkx

前言

工作需要将各个类之间的关系用网络图描述出来。
查阅相关资料,主要有如下方式:

  • networkx
  • qtgraph
  • matplotlib

一、networkx

networkx是用Python语言开发的图论与复杂网络建模工具,内置了常用的图与复杂网络分析算法,可以方便的进行复杂网络数据分析、仿真建模等工作。

本文主要实现用networkx画有向图,检测是否有回环,每个节点的前节点、后节点。
本文这里已经封装好了相关的实现类。

# -*- coding:utf-8 -*-

import networkx as nx
import matplotlib.pyplot as plt
import copy
from networkx.algorithms.cycles import *


class GetGraph:

    def __init__(self):
        pass

    @staticmethod
    def create_directed_graph(data_dict):
        my_graph = nx.DiGraph()
        my_graph.clear()
        for front_node, back_node_list in data_dict.items():
            if back_node_list:
                for back_node in back_node_list:
                    my_graph.add_edge(front_node, back_node)
            else:
                my_graph.add_node(front_node)
        return my_graph

    @staticmethod
    def draw_directed_graph(my_graph, name='out'):
        nx.draw_networkx(my_graph, pos=nx.circular_layout(my_graph), vmin=10,
                         vmax=20, width=2, font_size=8, edge_color='black')
        picture_name = name + ".png"
        plt.savefig(picture_name)
        # print('save success: ', picture_name)
        # plt.show()

    @staticmethod
    def get_next_node(my_graph):
        nodes = my_graph.nodes
        next_node_dict = {}
        for n in nodes:
            value_list = list(my_graph.successors(n))
            next_node_dict[n] = value_list
        return copy.deepcopy(next_node_dict)

    @staticmethod
    def get_front_node(my_graph):
        nodes = my_graph.nodes
        front_node_dict = {}
        for n in nodes:
            value_list = list(my_graph.predecessors(n))
            front_node_dict[n] = value_list
        return copy.deepcopy(front_node_dict)

    @staticmethod
    def get_loop_node(my_graph):
        loop = (list(simple_cycles(my_graph)))
        return copy.deepcopy(loop)


if __name__ == '__main__':
    comp_graph_object = GetGraph()
    comp_statement = {'CT_UPDATE_POS': ['CT_STATE'], 'CT_STATE': ['CT_MOVE'], 'CT_MOVE': ['CT_STATE', 'CT_FLUSH_VISUAL'],
      'CT_VISUAL': [], 'CT_FLUSH_VISUAL': ['CT_MOVE'], 'CT_INPUT': []}

    print('self.comp_statement_ct_map:', comp_statement)
    graph = comp_graph_object.create_directed_graph(comp_statement)
    comp_next_node = comp_graph_object.get_next_node(graph)
    comp_front_node = comp_graph_object.get_front_node(graph)
    comp_loop_list = comp_graph_object.get_loop_node(graph)
    comp_graph_object.draw_directed_graph(graph)

如果有警告提示:

FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  if self._edgecolors == str('face'):

有博客说是源码问题,无需理会。
效果图为:


注意点

  • 节点的位置排列官方给了几种办法,选用合适的即可。

    本文这里使用的是将节点画在了同心圆上,这样节点之间交叉较少。
    查阅源码可知,nx.circular_layout(my_graph)虽然有设置高维(3维、4维...)的参数,但是还是只能实现画在2维平面。思路主要是将1等距离平分n份,然后变换成2π的角度,圆半径已知,求得圆周上各节点之间的位置。
    但是源码的所有节点只能分布在同一圆上,如果节点很多,便不再适用,因此本文在此基础上改为分布在同心圆上。
import  matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt


def get_node_pos(node_list, radius=1, step=1, step_num=8, center=(0, 0), dim=2):
    if dim < 2:
        raise ValueError('cannot handle dimensions < 2')
    paddims = max(0, (dim - 2))
    odd_all_num = len(node_list)
    node_pos_list = []
    while odd_all_num > 0:
        cur_lever_num = radius * step_num
        if odd_all_num < cur_lever_num:
            cur_lever_num = odd_all_num
        odd_all_num -= cur_lever_num

        theta = np.linspace(0, 1, cur_lever_num + 1)[:-1] * 2 * np.pi
        theta = theta.astype(np.float32)
        pos = np.column_stack([np.cos(theta) * radius, np.sin(theta) * radius,
                               np.zeros((cur_lever_num, paddims))])
        pos = pos.tolist()
        node_pos_list.extend(pos)
        radius += 1
    all_pos = dict(zip(node_list, node_pos_list))
    return all_pos


if __name__ == '__main__':
    node =range(1,30,1)
    print('node:', node)
    pos = get_node_pos(node)

    # fig, ax = plt.subplots(figsize=(10,10))
    for name, pos in pos.items():
        plt.scatter(pos[0], pos[1])

思路主要是:圆的半径是radius ,圆上最大节点数为step_num,当节点数超过step_num,
求出新圆的半径radius+=step,该圆的最大节点数为radius*step_num。也就是说,圆的半径和最大节点数成正比。默认参数为,从里到外圆上最大节点数依次为8个、16个、8n个。

这里用matplotlib来显示效果:


但是发现,明明等分的圆,这些点却明显是椭圆。纠结了好久,才发现原来matplotlib默认得到的宽、高是不等距离的。因此设置图像的宽和高相等即可。

fig, ax = plt.subplots(figsize=(10,10))

最终效果图为:



可以看出,我们将节点等分在了同心圆上。

问题:

  • 虽然nerworkx很方便也很强大,但是发现networkx的节点名字和节点遮挡了,官方api感觉没怎么细讲,时间关系也没有深究
  • 箭头有点丑
  • 关键是工作需要鼠标和图像交互,比如高亮节点等。networkx这时便有点不够用了,或者我没找到合适的方法。因此本文用networkx得到节点之间的关系,画图则改用其他库,实现鼠标交互。

参考

  • Tutorial — NetworkX 1.10 documentation
    https://networkx.github.io/documentation/networkx-1.10/tutorial/index.html

  • python networkx 包绘制复杂网络关系图 -
    https://www.jianshu.com/p/e543dc63454f

你可能感兴趣的:(python 可视化节点关系(一):networkx)