pip install graphviz
pip install tochviz (或pip install git+https://github.com/szagoruyko/pytorchviz)
Graphviz 是 AT&T 开发的一款开源的图形可视化软件,可以根据dot脚本语言中绘制的无向图(显示了对象间最简单的关系)画出直观的树形图。
Graphviz在Windows中的安装需要下载Release包,并配置环境变量,否则会报错:
graphviz.backend.ExecutableNotFound: failed to execute [‘dot’, ‘-Tpng’, ‘-O’, ‘tmp’], make sure the Graphviz executables are on your systems’ PATH
Graphviz下载地址 https://graphviz.gitlab.io/_pages/Download/Download_windows.html
下载之后解压出来是一个“release”文件夹,把“release\bin”目录添加到系统环境变量,之后在终端中输入“dot -V”,显示以下信息表示Graphviz配置成功:
# Created by 牧野 CSDN
import torch
from torch import nn
from torchviz import make_dot, make_dot_from_trace
model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))
x = torch.randn(1,8)
vis_graph = make_dot(model(x), params=dict(model.named_parameters()))
vis_graph.view() # 会在当前目录下保存一个“Digraph.gv.pdf”文件,并在默认浏览器中打开
with torch.onnx.set_training(model, False):
trace, _ = torch.jit.get_trace_graph(model, args=(x,))
make_dot_from_trace(trace)
调用“make_dot”方法创建一个dot对象,使用“view”方法显示出来。
pytorch1.2和1.3版本中使用“torch.jit.get_trace_graph”可能会报错,1.1版本ok。
AttributeError: 'torch._C.Value' object has no attribute 'uniqueName'
可视化结果:
Netron开源地址: https://github.com/lutzroeder/Netron
Netron的开发者是Lutz Roeder,一位来自微软Visual Studio团队的帅哥:
Netron是一款支持离线查看“各种”神经网络框架的模型可视化神器,其中的“各种”包括:
- ONNX (.onnx, .pb, .pbtxt)
- Keras (.h5, .keras)
- Core ML (.mlmodel)
- Caffe (.caffemodel, .prototxt)
- Caffe2 (predict_net.pb, predict_net.pbtxt)
- MXNet (.model, -symbol.json)
- NCNN (.param)
- TensorFlow Lite (.tflite)
- TorchScript (.pt, .pth)
- PyTorch (.pt, .pth)
- Torch (.t7)
- Arm NN (.armnn)
- BigDL (.bigdl, .model)
- Chainer, (.npz, .h5)
- CNTK (.model, .cntk)
- Deeplearning4j (.zip)
- Darknet (.cfg)
- ML.NET (.zip)
- MNN (.mnn)
- OpenVINO (.xml)
- PaddlePaddle (.zip, __model__)
- scikit-learn (.pkl)
- TensorFlow.js (model.json, .pb)
- TensorFlow (.pb, .meta, .pbtxt)
嗯,够多了。
Netron使用很简单,作者提供了各个平台的安装包,安装之后打开,把保存的模型文件拖入就可以了。
还以上边的模型为例,先把pytorch模型保存出来:
import torch
from torch import nn
from torchviz import make_dot, make_dot_from_trace
model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))
torch.save(model, 'model.pth') # 保存模型
之后用Netron打开保存的“model.pth”:
网络结构很清晰,一目了然,右侧还能显示操作的进一步信息。
如果你懒得安装,还可以使用作者提供的在线Netron查看器,地址:https://lutzroeder.github.io/netron/