在利用tensorflow训练时,通常我们会利用tf.train.Saver()函数将训练得到的模型保存成pd或者cpkt模型,如下图所示。但是当模型需要在C++程序或者其他平台调用时,通常需要将其冻结成pb模型。
关于如何将pd或cpkt模型冻结成pb模型的代码,已经有许多的博文介绍,但是在使用时总是会出一些问题。本人在冻结过程中也出现了一些问题,现在将如何冻结做一个总结。pd或cpkt模型冻结成pb模型的代码是一样的,就输入名称不同而已,通用的代码如下:
import tensorflow as tf
meta_path = 'model/Module.pd.meta' # Your .pd.meta file
#meta_path = 'model/Module.cpkt.meta' # Your .cpkt.meta file
output_node_names = ["output/Sigmoid"] # Output nodes
with tf.Session() as sess:
# Restore the graph
saver = tf.train.import_meta_graph(meta_path)
# Load weights
saver.restore(sess,tf.train.latest_checkpoint('model/'))
# Freeze the graph
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph_def,
output_node_names)
# Save the frozen graph
with open('Module.pb', 'wb') as f:
f.write(frozen_graph_def.SerializeToString())
代码比较简单的,主要就是确定模型的位置路径(meta_path)和最后一层网络输出的节点名称(output_node_names)。但是往往容易在输出节点名称的确定中出问题。节点名称错误,会有如下报错提示:
因此,需要准确地确定模型的最后一层网络输出节点名称。这里需要注意的是:当你设置的名称不是模型最后一层输出节点名称而是其他节点的名称时,冻结代码是不会报错的,但最后冻结的pb模型是不成功的,只有1k大小,如下:
查看网络的输出节点总共有3种方法,包括(1)查看代码确定;(2)通过代码输出相应的tensor名称查找;(3)利用tensorboard查看图结构确定节点名称。这里做一些阐述。参考文章:https://blog.csdn.net/weixin_43815222/article/details/108094512?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_title-0&spm=1001.2101.3001.4242.
(1)查看代码确定
这种方法就是根据自己写的代码所设定的输出节点名称来确定的。但是,如果自己在代码中忘记设计节点名称,网络就会使用默认的名称,这时候就不好用这种方法。还有一种情况是自己填上了设置的名称还是报错,这种情况是没有考虑名称的嵌套。如这文章https://www.jianshu.com/p/06548e3e8f4b所介绍的一个例子代码如下:
with tf.name_scope("score"):
# 全连接层,后面接dropout以及relu激活
fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
fc = tf.contrib.layers.dropout(fc, self.keep_prob)
fc = tf.nn.relu(fc)
# 分类器
self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1,name='output') #预测类别
分类器设置的输出节点名称为“output”,但是因为前面with tf.name_scope函数中设置了名称为“score”,则它的输出节点名称应该为“score/ output”,类似于路径一样。因此,在查看代码节点名称时应该注意此情况。
(2)通过代码输出相应的tensor名称查找
通过下面的代码,读取pd或cpkt模型可以将模型中所以节点的tensor名称输出出来,以供查找。
#查看输出节点代码
from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path = os.path.join('model/Module.pd')
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print('tensor_name: ', key)
尝试了下,输出结果如下图,可以输出所有tensor名称,然后再一个个查找。但是当网络比较复杂,设置节点较多时,这种方法并不是太适用。不过此代码可以作为检查模型tensor名称或者获取所有模型tensor名称时使用。
参考文章: https://blog.csdn.net/helei001/article/details/56489658
(3)利用tensorboard查看图结构确定节点名称
利用tensorboard可以查看所训练模型的图结构,通过这个方法来确定最后一层网络输出节点名称,是最好的方法。我是在anaconda中搭建tensorflow的,因此可以通过anaconda来启动tensorboard。这里是widows版本的anaconda+tensorflow。首先在anaconda中切换到Environments模式,选择你所安装的tensorflow,单击三角号打开Open Terminal,可以打开tensorflow的终端。
在终端中输入tensorboard --logdir=模型路径,如下图所示
按回车运行,可以获得打开tensorboard的路径:http://DESKTOP-QOOI0L9:6006/。
复制该路径放入浏览器中,即可打开tensorboard。如果使用其他浏览器不能打开或者乱码,建议使用谷歌浏览器。如果谷歌浏览器出现无法访问此网站,可以断开网络再打开。打开tensorboard,切换到GRAPHS,可以看到模型的图结构,如下:
要注意,在图中如果是椭圆形框,则表示的是可以识别的tensor,如果是矩形框,则表示的是组合操作,要继续点击到里面去查找。通常最后一层网络输出节点是在“output”中,这是个矩形框,点击进去如下图所示,找到最后输出的椭圆形框。由于我代码中未设置名称,因此这里是默认的“Sigmoid”。但这里也需要注意嵌套关系,是从output组合中找到的,因此输出节点类似于第一种的情况,为“output/Sigmoid”。
将该最后一层网络输出节点名称写入前面代码中,运行后即可以将pd或cpkt模型冻结成pb模型,如下图:
参考文章:https://blog.csdn.net/weixin_43815222/article/details/108075416