版本问题一定要注意
python = 3.6.5
pytorch = 1.2.0
torchvision = 0.4.0
tensorwatch = 0.8.7
pydot = 1.4.2
scikit-learn = 0.24.2
pandas = 1.1.5
首先使用conda 创建一个python3.6的新环境
conda create -n plotnet python=3.6.5
pip 安装ipykernel
pip install ipykernel
python -m ipykernel install --name plotnet --display-name "plotnet"
--name plotnet
后面跟的是刚刚创建的环境名称,--display-name "plotnet"
后面是jupyter里显示的核的名字
pip install graphviz
或者 conda install graphviz
安装graphviz
pip install torch==1.2.0+cpu torchvision==0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html -i https://pypi.douban.com/simple
pip install tensorwatch == 0.8.7
pip install pydot = 1.4.2
pip install scikit-learn = 0.24.2
pip install pandas = 1.1.5
from torchvision.models import vgg16 # 以 vgg16 为例
from tensorwatch import draw_model
mynet = vgg16()
draw_model(mynet, [1, 3, 128, 128]) # 输出网络结构图
此时会报错 'Dot' object has no attribute '_repr_svg_'
,解决方法,找到文件 pytorch_draw_model.py
,将其中第 13 行代码改为 return self.dot.create_svg().decode()
如果有报错FileNotFoundError: [WinError 2] “dot” not found in path.
,
解决方法如下:
将self.prog='dot'
改为self.prog='dot.exe'