网络可视化的重要性不言而喻,我们可以利用图的形式更好的观察网络的结构。记得以前写过一篇关于Caffe下的网络的可视化的博客点击打开链接,其实有一个在线的Caffe网络结构工具:
http://ethereon.github.io/netscope/#/editor。
那么在Mxnet下怎么画出网络结构图呢?其实过程还是很简单的,因为Mxnet已经为我们写好了函数,我们只要简单调用即可。废话不说了,上代码:
import mxnet as mx
import numpy as np
import cv2
import matplotlib.pyplot as plt
import logging
#启动日志
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
#定义一个网络
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data,name='fc1',num_hidden=128)
act1 = mx.symbol.Activation(data=fc1,name='relu1',act_type='relu')
fc2 = mx.symbol.FullyConnected(data=act1,name='fc2',num_hidden=64)
act2 = mx.symbol.Activation(data=fc2,name='relu2',act_type='relu')
fc3 = mx.symbol.FullyConnected(data=act2,name='fc3',num_hidden=10)
mlp = mx.symbol.SoftmaxOutput(data=fc3,name='softmax')
#可视化网络
mx.viz.plot_network(mlp).view()
友情提示:
1、mx.viz.plot_network(mlp).view(),必须要调用view()函数,否则无法得到正确结果
2、要确保你已经安装了graphviz和pydot,如果没有安装可以使用pip install进行安装
3、如果您遇到了下面的错误提示:"RuntimeError: failed to execute ['dot', '-Tpdf', '-O', 'test-output/round-table.gv'],make sure the Graphviz executables are on your systems' path"
是因为你没有将graphviz加入系统环境变量,我的解决方法是,我首先在windows下安装了
graphviz,注意不是用pip intall安装的python版本,然后将
D:\Program Files (x86)\Graphviz2.39\bin,加入系统环境变量。重新启动系统即可!
可参考:
http://stackoverflow.com/questions/28312534/graphvizs-executables-are-not-found-python-3-4
补充:
原来mx.viz.plot_network()函数里还有其他参数,下面分别贴出代码和效果图
batch_size = 100
data_shape = (batch_size, 784)
mx.viz.plot_network(softmax, shape={"data":data_shape}, node_attrs={"shape":'oval',"fixedsize":'false'})
效果如左图
batch_size = 100
data_shape = (batch_size, 784)
mx.viz.plot_network(softmax)#没有shape参数了
效果如右图
看出不同了吗?我觉得这应该和graphviz有关,shape参数可以在箭头上显示输出个数,
node_attrs参数可以设定节点的属性,比如节点形状等。
祝各位好运!