pytorchviz进行pytorch执行过程的可视化

torchviz:https://github.com/szagoruyko/pytorchviz
转载于:使用pytorchviz进行pytorch执行过程的可视化 - pytorch中文网

1. 安装

pip install graphviz
pip install git+https://github.com/szagoruyko/pytorchviz
import torch
from torch.autograd import Variable
from torch import nn
from torchviz import make_dot, make_dot_from_trace


model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

x = Variable(torch.randn(1,8))
y = model(x)

make_dot(y.mean(), params=dict(model.named_parameters()))
  • 主要有两个函数,make_dot可以从任何PyTorch函数(要求至少有一个输入变量requires_grad)中生成图形,并make_dot_from_trace使用输出torch.jit.trace(并不总是有效)。参见examples.ipynb。

你可能感兴趣的:(pytorchviz进行pytorch执行过程的可视化)