from keras.utils import plot_model
plot_model(model, './model.bmp', show_shapes=True)
使用keras中的plot_model模块可以绘制网络模型图,但是可能报pydot缺失的错。
pip安装完又报另一个错误`pydot` failed to call GraphViz.
根据提示到相关网站下载对应系统的安装包吧,这里下的是window版本的msi安装包
安装完后发现报错还是没有解决,仔细检查报错,会发现这里是dot;
此时应该去site-packages路径下的pydot.py中将self.prog = 'dot'修改为self.prog = 'dot.exe'(大概1710行数)
另外还得为刚刚安装的GraphViz添加环境变量,可以在系统设置,也可以代码中添加
import os
os.environ["PATH"] += ";D:/Program/Graphviz2.38/bin/"
再次执行如下完整代码
from __future__ import absolute_import, division, print_function
import os
import tensorflow as tf
from tensorflow import keras
from keras.utils import plot_model
print('tf version: {}'.format(tf.__version__))
# Returns a short sequential model
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
return model
# Create a basic model instance
model = create_model()
os.environ["PATH"] += ";D:/Program/Graphviz2.38/bin/"
plot_model(model, './model.bmp', show_shapes=True)
就可以得到网络的模型图