PyTorch 靠谱的模型可视化教程

本文不是无脑的复制粘贴,谢谢大家支持

Netron GitHub官网

Netron 画 PyTorch 模型呢,需要用 torch.onnx.export 来导出.onnx 后缀的模型,然后用 Netron 可视化就会得到非常好看的模型架构图。直接用torch.save出来的模型,画出来模型图不对。
安装

pip install netron

使用
第一步,保存模型

import torch

model = Yourmodelclass() # 定义你的模型
x,y = next(iter(yourdataloader)) # 从你的dataloader取出一个batch的数据,或者你用torch.randn生成
model.eval()
pred = model(x)

# 导出为 .onnx 的模型
torch.onnx.export(model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "model.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

第二步,用Netron画图

~$ python3
>>> import netron
>>> netron.start('model.onnx')
Serving '/y/home/zdx/model.pth' at http://localhost:8080
('localhost', 8080)
>>> 

打开浏览器,把上面的 http://localhost:8080,用浏览器打开
左上角有设置,可以选择垂直或者水平展示模型
PyTorch 靠谱的模型可视化教程_第1张图片

PyTorchViz GitHub官网

安装

conda install -c anaconda graphviz python-graphviz
conda install pydot
pip install torchviz

使用

import torch

model = Yourmodelclass() # 定义你的模型
x,y = next(iter(yourdataloader)) # 从你的dataloader取出一个batch的数据,或者你用torch.randn生成

model.eval()
pred = model(x)

make_dot(pred.mean(), params=dict(model.named_parameters()))

PyTorch 靠谱的模型可视化教程_第2张图片

你可能感兴趣的:(Python,deep-learning,pytorch)