ckpt为savermodel模型并TRT调用的问题

因为任务需要,需要将TensorFlow训练的得到的ckpt文件转换成savemodel,并用TRT调用,这里记下自己从中遇到的坑。

1、ckpt转savemodel,可以参考https://blog.csdn.net/zmlovelx/article/details/100511406这个博客,写的很详细,按照使用没有任何问题。保存好以后到save目录下,会有一个saved_model.pb文件以及variables文件夹。其中variables保存所有变量,saved_model.pb用于保存模型结构等信息。;
2、TRT下调用savemodel(这里不赘述TRT安装,后续如果需要可以再写一篇),官网给出了如下样例代码:
If you have a SavedModel representation of your TensorFlow model, you can create a TensorRT inference graph directly from your SavedModel, for example:

import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt

converter = trt.TrtGraphConverter(input_saved_model_dir=input_saved_model_dir)
converter.convert()
converter.save(output_saved_model_dir)

with tf.Session() as sess:
    # First load the SavedModel into the session    
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], output_saved_model_dir)
    output = sess.run([output_tensor], feed_dict={input_tensor: input_data})

这里找到自己之前savemodel的存放路径就可以,也就是output_saved_model_dir,但是在调用中发现有了错误“RuntimeError: MetaGraphDef associated with tags ‘serve’ could not be found in SavedModel. To inspect available tag-sets in the1/ SavedModel, please use the SavedModel CLI: saved_model_cli
available_tags: [{‘serve’, ‘train’}]”,一开始百思不得其解,后来查找写资料之后慢慢了解到其中问题;
3、这里想深入了解的可以参考这个博客。我就针对我自己的问题说明以下,转换成savemodel过程中的builder.add_meta_graph_and_variables 中的add_meta_graph_and_variables函数其第一个参数传入当前的session,包含了graph的结构与所有变量;而第二个参数是给当前需要保存的meta graph一个标签,标签名可以自定义,在之后载入模型的时候,需要根据这个标签名去查找对应的MetaGraphDef;
而tf.saved_model.loader.load此函数第一个参数就是当前的session,第二个参数是在保存的时候定义的meta graph的标签,标签一致才能找到对应的meta graph。第三个参数就是模型保存的目录。所以报错原因就是处在保存的savemodel与要载入的不一致,因此,保持与官方一致做法;
4、builder.add_meta_graph_and_variables(sess, [tf.saved_model.TRAINING, tf.saved_model.SERVING], strip_default_attrs=True)改成:

builder.add_meta_graph_and_variables(sess, tf.saved_model.SERVING], strip_default_attrs=True)
5、重新操作一遍,问题解决。

你可能感兴趣的:(savemodel调用)