Netron 画 PyTorch 模型呢,需要用 torch.onnx.export 来导出.onnx 后缀的模型,然后用 Netron 可视化就会得到非常好看的模型架构图。直接用torch.save出来的模型,画出来模型图不对。
安装
pip install netron
使用
第一步,保存模型
import torch
model = Yourmodelclass() # 定义你的模型
x,y = next(iter(yourdataloader)) # 从你的dataloader取出一个batch的数据,或者你用torch.randn生成
model.eval()
pred = model(x)
# 导出为 .onnx 的模型
torch.onnx.export(model, # model being run
x, # model input (or a tuple for multiple inputs)
"model.onnx", # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
'output' : {0 : 'batch_size'}})
第二步,用Netron画图
~$ python3
>>> import netron
>>> netron.start('model.onnx')
Serving '/y/home/zdx/model.pth' at http://localhost:8080
('localhost', 8080)
>>>
打开浏览器,把上面的 http://localhost:8080,用浏览器打开
左上角有设置,可以选择垂直或者水平展示模型
安装
conda install -c anaconda graphviz python-graphviz
conda install pydot
pip install torchviz
使用
import torch
model = Yourmodelclass() # 定义你的模型
x,y = next(iter(yourdataloader)) # 从你的dataloader取出一个batch的数据,或者你用torch.randn生成
model.eval()
pred = model(x)
make_dot(pred.mean(), params=dict(model.named_parameters()))