【Pytorch】使用tensorwatch神经网络结构可视化

tensorwatch神经网络结构可视化

  • 提前说明
  • 环境安装
  • 测试
  • 结果

提前说明

版本问题一定要注意
python = 3.6.5
pytorch = 1.2.0
torchvision = 0.4.0
tensorwatch = 0.8.7
pydot = 1.4.2
scikit-learn = 0.24.2
pandas = 1.1.5

环境安装

首先使用conda 创建一个python3.6的新环境

conda create -n plotnet python=3.6.5

pip 安装ipykernel

pip install ipykernel
python -m ipykernel install  --name plotnet --display-name "plotnet"

--name plotnet 后面跟的是刚刚创建的环境名称,--display-name "plotnet"后面是jupyter里显示的核的名字
pip install graphviz 或者 conda install graphviz安装graphviz

pip install torch==1.2.0+cpu torchvision==0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html -i https://pypi.douban.com/simple
pip install tensorwatch == 0.8.7
pip install pydot = 1.4.2
pip install scikit-learn = 0.24.2
pip install pandas = 1.1.5

测试

from torchvision.models import vgg16  # 以 vgg16 为例
from tensorwatch import draw_model
mynet = vgg16()  
draw_model(mynet, [1, 3, 128, 128])  # 输出网络结构图

此时会报错 'Dot' object has no attribute '_repr_svg_',解决方法,找到文件 pytorch_draw_model.py,将其中第 13 行代码改为 return self.dot.create_svg().decode()
【Pytorch】使用tensorwatch神经网络结构可视化_第1张图片如果有报错FileNotFoundError: [WinError 2] “dot” not found in path.
解决方法如下:
self.prog='dot'改为self.prog='dot.exe'
【Pytorch】使用tensorwatch神经网络结构可视化_第2张图片

结果

【Pytorch】使用tensorwatch神经网络结构可视化_第3张图片

你可能感兴趣的:(python,pytorch,神经网络,python)