Notes on plotting a model graph in Keras: how to plot the LSTM model graph and save it to a file

In Keras, you can draw a model with the Model visualization utilities. The following is the description from the Keras documentation.

Model visualization

Keras provides utility functions to plot a Keras model (using graphviz).

This will plot a graph of the model and save it to a file:

from keras.utils import plot_model
plot_model(model, to_file='model.png')

plot_model takes four optional arguments:

  • show_shapes (defaults to False) controls whether output shapes are shown in the graph.
  • show_layer_names (defaults to True) controls whether layer names are shown in the graph.
  • expand_nested (defaults to False) controls whether to expand nested models into clusters in the graph.
  • dpi (defaults to 96) controls image dpi.
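To make the effect of show_shapes and show_layer_names concrete, here is a simplified, pure-Python sketch that emits Graphviz DOT source for a linear stack of layers. It is not how Keras implements plot_model. The function name layers_to_dot, the (name, class, shape) tuples and the layer names are hypothetical, chosen only to mirror the stacked LSTM model used later in this post.

```python
def layers_to_dot(layers, show_shapes=False, show_layer_names=True):
    """Return DOT source for a linear stack of (name, class_name, output_shape) tuples."""
    lines = ["digraph model {", "  node [shape=record];"]
    for i, (name, class_name, shape) in enumerate(layers):
        # show_layer_names toggles the "name: " prefix on each box
        label = f"{name}: {class_name}" if show_layer_names else class_name
        # show_shapes appends the output shape as a second record field
        if show_shapes:
            label += f" | output: {shape}"
        lines.append(f'  n{i} [label="{label}"];')
    # consecutive layers are connected by edges
    for i in range(len(layers) - 1):
        lines.append(f"  n{i} -> n{i + 1};")
    lines.append("}")
    return "\n".join(lines)


# Hypothetical layer names and shapes mirroring the stacked LSTM model
stack = [
    ("lstm_1", "LSTM", "(None, 50, 20)"),
    ("lstm_2", "LSTM", "(None, 20)"),
    ("dense_1", "Dense", "(None, 5)"),
]

if __name__ == "__main__":
    print(layers_to_dot(stack, show_shapes=True))
```

Feeding this text to the Graphviz dot program would draw three boxes joined by arrows; plot_model does the equivalent from the real model, via pydot.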

You can also directly obtain the pydot.Graph object and render it yourself, for example to show it in an IPython notebook:

from IPython.display import SVG
from keras.utils import model_to_dot

SVG(model_to_dot(model).create(prog='dot', format='svg'))
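model_to_dot returns a pydot graph, and .create(prog='dot', format='svg') passes its DOT source to the Graphviz dot program. Roughly the same rendering step can be done with the standard library alone, which also shows why Graphviz has to be reachable on PATH. This is a sketch; render_dot is a hypothetical helper, not part of Keras or pydot.

```python
import shutil
import subprocess


def render_dot(dot_source, fmt="svg"):
    """Render Graphviz DOT source to bytes by piping it through the `dot` program."""
    dot = shutil.which("dot")
    if dot is None:
        # Nothing can be rendered if the Graphviz `dot` executable is not on PATH
        raise FileNotFoundError("Graphviz 'dot' executable not found on PATH")
    result = subprocess.run(
        [dot, f"-T{fmt}"],
        input=dot_source.encode("utf-8"),
        capture_output=True,
        check=True,
    )
    return result.stdout


if __name__ == "__main__":
    if shutil.which("dot"):
        print(render_dot("digraph { input -> lstm -> dense }")[:60])
    else:
        print("Graphviz 'dot' not found on PATH")
```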

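The biggest hassle in this whole process is installing the two packages pydot and graphviz.

Below is a summary of the pitfalls I ran into, for everyone's benefit.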

1. Installing pydot

First, open the Anaconda Prompt (in administrator mode), as shown below:

[Figure 1: opening the Anaconda Prompt in administrator mode]

Then enter the following command:
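The command that originally followed is not preserved here. As an assumption, pydot is commonly installed with pip install pydot (or conda install pydot inside an Anaconda environment). Either way, a quick way to confirm that the package is visible to the Python interpreter you run Keras with is a standard-library check like this (module_available is a hypothetical helper):

```python
import importlib.util


def module_available(name):
    """Return True if the top-level module `name` can be imported by this interpreter."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    print("pydot installed:", module_available("pydot"))
```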

2. Installing graphviz

Next, install the graphviz package; this is also easy to do.
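Note that the Python package alone is not enough: pydot works by calling the Graphviz dot program, so the Graphviz binaries must also be installed and reachable through PATH. The following standard-library sketch (graphviz_version is a hypothetical helper) reports whether dot can be found, and its version if so:

```python
import shutil
import subprocess


def graphviz_version():
    """Return the version banner of `dot -V`, or None if `dot` is not on PATH."""
    dot = shutil.which("dot")
    if dot is None:
        return None
    result = subprocess.run([dot, "-V"], capture_output=True, text=True)
    # `dot -V` writes its version banner to stderr; fall back to stdout just in case
    return result.stderr.strip() or result.stdout.strip()


if __name__ == "__main__":
    print(graphviz_version() or "dot not found on PATH")
```

If this prints "dot not found on PATH" even though Graphviz is installed, the PATH fix in step 3 is the likely remedy.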

Now run the code from the documentation. If it fails with the following error:

ImportError: Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.

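In principle this error should not occur, since both packages are installed, but I searched for a long time without solving it. In the end a post on Stack Overflow helped a lot: the cause was simply that the package's environment variable (PATH) had not been set correctly.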

3. Solution

The fix is to add the following two lines to your program:

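import os
# Append the Graphviz bin directory to PATH so that pydot can find the `dot` program.
# This is the folder where Graphviz was extracted after downloading; change it to match your own installation.
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz-2.38/release/bin/'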
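A slightly more defensive variant of these two lines, as a sketch: the helper name add_to_path and the idea of skipping missing or duplicate directories are additions for illustration, not part of the original fix. The change to os.environ only affects the current Python process, so it must run before plot_model is called.

```python
import os


def add_to_path(directory):
    """Append `directory` to PATH for this process if it exists and is not already listed.

    Returns True if PATH was changed.
    """
    if not os.path.isdir(directory):
        return False
    current = os.environ.get("PATH", "")
    wanted = os.path.normcase(os.path.normpath(directory))
    for entry in current.split(os.pathsep):
        if entry and os.path.normcase(os.path.normpath(entry)) == wanted:
            return False  # already on PATH, avoid duplicates
    os.environ["PATH"] = current + os.pathsep + directory if current else directory
    return True


if __name__ == "__main__":
    # Adjust to wherever Graphviz was extracted on your machine
    add_to_path("C:/Program Files (x86)/Graphviz-2.38/release/bin")
```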

4. Results

Here is the complete code:

from keras.utils import plot_model
import os
# Make the Graphviz binaries visible to pydot (adjust the path to your installation)
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz-2.38/release/bin/'
from math import sin, pi, exp
from random import randint, uniform
from numpy import array
from keras.models import Sequential
from keras.layers import LSTM, Dense

# generate a damped sine wave of the given length
def generate_sequence(length, period, decay):
    return [0.5 + 0.5 * sin(2 * pi * i / period) * exp(-decay * i) for i in range(length)]

# generate input and output pairs of damped sine waves
def generate_examples(length, n_patterns, output):
    X, y = list(), list()
    for _ in range(n_patterns):
        p = randint(10, 20)
        d = uniform(0.01, 0.1)
        sequence = generate_sequence(length + output, p, d)
        X.append(sequence[:-output])
        y.append(sequence[-output:])
    X = array(X).reshape(n_patterns, length, 1)
    y = array(y).reshape(n_patterns, output)
    return X, y

# configure problem
length = 50
output = 5

# define model: two stacked LSTM layers followed by a dense output layer
model = Sequential()
model.add(LSTM(20, return_sequences=True, input_shape=(length, 1)))
model.add(LSTM(20))
model.add(Dense(output))
model.compile(loss='mae', optimizer='adam')
model.summary()  # summary() prints by itself; print(model.summary()) would also print "None"
# save the graph, annotated with each layer's input/output shapes, to model3.png
plot_model(model, to_file='model3.png', show_shapes=True)
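For reference, generate_sequence produces a damped sine wave. Written as a formula, with period p drawn by generate_examples via randint(10, 20) (both ends inclusive) and decay rate d via uniform(0.01, 0.1), and n = length + output:

```latex
x_i = 0.5 + 0.5\,\sin\!\left(\frac{2\pi i}{p}\right)\,e^{-d\,i},
\qquad i = 0, 1, \dots, n-1,
\qquad p \in \{10, \dots, 20\},\; d \in [0.01, 0.1]
```

The first length values of each sequence become the input X (reshaped to (n_patterns, length, 1)) and the last output values become the target y.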

Finally, it produces an image like this:

[Figure 2: the resulting model graph saved to model3.png]
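With show_shapes=True, each box in the graph is annotated with its shapes. As a cross-check, the output shapes and the parameter counts that model.summary() reports can be worked out by hand. This sketch assumes the standard Keras LSTM parameter formula 4 × (units × (input_dim + units) + units) (four gates, each with input weights, recurrent weights and a bias) and input_dim × units + units for Dense; None stands for the batch dimension.

```python
def lstm_params(input_dim, units):
    # 4 gates, each with an input kernel, a recurrent kernel and a bias
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim, units):
    # weight matrix plus bias vector
    return input_dim * units + units


length, output = 50, 5
layers = [
    ("LSTM(20, return_sequences=True)", (None, length, 20), lstm_params(1, 20)),
    ("LSTM(20)", (None, 20), lstm_params(20, 20)),
    (f"Dense({output})", (None, output), dense_params(20, output)),
]

for name, shape, params in layers:
    print(f"{name:35} output={shape}  params={params}")

total = sum(params for _, _, params in layers)
print("total params:", total)  # 1760 + 3280 + 105 = 5145
```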

There you go!

Alternatively, to display the graph inline in an IPython/Jupyter notebook, render it as SVG:

from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
SVG(model_to_dot(model).create(prog='dot', format='svg'))
