pytorch 网络可视化(二):graphviz + torchviz

引导

    • 1. 安装 graphviz 和 torchviz
    • 2. 测试是否安装成功
    • 3. 输出网络结构

1. 安装 graphviz 和 torchviz

首先打开 Anaconda prompt 进入自己的 pytorch 环境(图中 pt 是我自己的 pytorch 环境),运行如下代码安装依赖包。

pip install graphviz torchviz

具体过程如下图所示,其中 pt 是我自己的 pytorch 环境:

pytorch 网络可视化(二):graphviz + torchviz_第1张图片

2. 测试是否安装成功

运行 python 进入交互式环境,导入两个包看是否报错,不报错则安装成功,如下:

pytorch 网络可视化(二):graphviz + torchviz_第2张图片

3. 输出网络结构

打开自己的 python 编辑器,解释器选择自己的 pytorch 环境中的 python.exe ,运行如下代码:

import torch
from torchviz import make_dot
from torchvision.models import vgg16  # 以 vgg16 为例

x = torch.randn(4, 3, 32, 32)  # 随机生成一个张量
model = vgg16()  # 实例化 vgg16,网络可以改成自己的网络
out = model(x)   # 将 x 输入网络
g = make_dot(out)  # 实例化 make_dot
g.view()  # 直接在当前路径下保存 pdf 并打开
# g.render(filename='netStructure/myNetModel', view=False, format='pdf')  # 保存 pdf 到指定路径不打开

此时会报错,报错内容如下:

pytorch 网络可视化(二):graphviz + torchviz_第3张图片
报错内容是说 graphviz 中的一些可执行命令不在系统路径里,分析原因是环境变量里没有与 graphviz 相关的即只安装了包没安装软件,因此继续下载安装 graphviz.exe,下载地址:https://www2.graphviz.org/Packages/stable/windows/10/cmake/Release/x64/

网页内容如下图,点击箭头处直接下载:

在这里插入图片描述
下载下来是一个可执行文件,如下,双击运行进入安装向导:

在这里插入图片描述
安装非常简单,顺序操作就行。中间注意勾选添加环境变量就行,不用勾选桌面快捷方式选项,如下:

在这里插入图片描述
安装完成后检查一下环境变量里是否已经有了,没有的话要自己加上,如图所示我的环境里已经出现了 E:\Graphviz\bin :

在这里插入图片描述
以上步骤进行完后,再执行输出网络结构的代码就会正常执行了,可以自己选择是要直接打开还是只保存一下,代码与前述相同。

import torch
from torchviz import make_dot
from torchvision.models import vgg16  # 以 vgg16 为例

x = torch.randn(4, 3, 32, 32)  # 随机生成一个张量
model = vgg16()  # 实例化 vgg16,网络可以改成自己的网络
out = model(x)   # 将 x 输入网络
g = make_dot(out)  # 实例化 make_dot
g.view()  # 直接在当前路径下保存 pdf 并打开
# g.render(filename='netStructure/myNetModel', view=False, format='pdf')  # 保存 pdf 到指定路径不打开

最后输出的 vgg16 网络结构如下所示,图片有点大哈哈:

pytorch 网络可视化(二):graphviz + torchviz_第4张图片
可以看出,在输出的网络结构里不但有计算路径,网络各层的权重和偏移量等都体现出来了,非常利于学习和理解网络的整体结构和功能。

你可能感兴趣的:(pytorch之网络可视化,深度学习,pytorch,神经网络,机器学习)