1. Convert a Keras model from h5 to pb
from keras.models import load_model
import tensorflow as tf
import os
from keras import backend as K
# Path parameters: Keras .h5 file to read and name of the .pb file to write
weight_file_path, output_graph_name = 'model.h5', 'model.pb'
# Conversion function: Keras .h5 model -> frozen TensorFlow .pb graph
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a loaded Keras model into a binary TensorFlow GraphDef file.

    Every model output gets a ``tf.identity`` alias named
    ``<out_prefix><k>`` (k = 1, 2, ...), so the frozen graph has
    predictable output node names to look up after loading.

    Args:
        h5_model: Keras model, e.g. the result of ``keras.models.load_model``.
        output_dir: Directory the .pb is written to; created (including
            missing parents) if it does not exist.
        model_name: File name of the frozen graph inside ``output_dir``.
        out_prefix: Prefix of the generated output node names.
        log_tensorboard: If true, also import the written .pb into a
            TensorBoard log under ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    out_nodes = []
    # Iterate ``outputs`` (always a list). ``output`` is a single tensor for
    # a one-output model, so ``output[i]`` would slice that tensor instead of
    # selecting the i-th model output.
    for i, output_tensor in enumerate(h5_model.outputs, start=1):
        node_name = out_prefix + str(i)
        out_nodes.append(node_name)
        tf.identity(output_tensor, node_name)

    # The Keras backend session owns the graph and the trained weights.
    sess = K.get_session()
    from tensorflow.python.framework import graph_util, graph_io
    init_graph = sess.graph.as_graph_def()
    # Freeze: replace variables with their current values as constants,
    # keeping the part of the graph needed to compute out_nodes.
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)
# Directory that receives the frozen graph
output_dir = 'model_tf_pb'
# Load the Keras model stored in the .h5 file
h5_model = load_model(weight_file_path)
# Same call as h5_to_pb(h5_model, output_dir=..., model_name=...), positional form
h5_to_pb(h5_model, output_dir, output_graph_name)
print('model saved')
2. tf.contrib.predictor
# Load a SavedModel from model_dir for prediction via tf.contrib.predictor
predictor.from_saved_model(model_dir, signature_def_key='predict', config=config)
其中有一个参数 tags(默认是 'serve'),用于在 SavedModel 中查找对应的 MetaGraphDef;而直接由 h5 转换得到的 pb 只是一个冻结的 GraphDef,并不包含 tags(参考:https://blog.csdn.net/weixin_43215867/article/details/85038313)。
因此需要先用 TensorFlow 加载 pb,添加 tags 后再以 SavedModel 格式保存:
def load_pb(pb_file_path, url, input_name='input_1:0', output_name='dense_2/Softmax:0',
            export_dir='./model_tags'):
    """Load a frozen .pb, run one prediction, and re-save it as a SavedModel.

    A frozen GraphDef carries no MetaGraph tags, so it cannot be loaded with
    ``tf.contrib.predictor.from_saved_model``. This re-exports the graph with
    the ``SERVING`` tag and a prediction signature stored under the key
    ``'predict'`` (inputs key ``'input'``, outputs key ``'output'``).

    Args:
        pb_file_path: Path of the frozen GraphDef file.
        url: Passed to ``read_image`` to build the test input.
        input_name: Name of the input tensor in the frozen graph.
        output_name: Name of the output tensor in the frozen graph.
        export_dir: Target directory of the SavedModel. SavedModelBuilder
            refuses to write into an existing non-empty directory.
    """
    import numpy as np

    # Import into a dedicated graph. A bare ``sess.graph.as_default()`` call
    # never enters the context, and nodes already present in the global
    # default graph (e.g. from Keras) could clash with the imported names.
    graph = tf.Graph()
    with graph.as_default():
        with tf.gfile.GFile(pb_file_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    # Using the session as a context manager also makes `graph` the default
    # graph, which the SavedModelBuilder calls below rely on.
    with tf.Session(graph=graph) as sess:
        input_tensor = graph.get_tensor_by_name(input_name)
        output_tensor = graph.get_tensor_by_name(output_name)

        # NOTE(review): read_image and labels are not defined in this snippet;
        # presumably they are provided elsewhere in the project.
        datas = read_image(url)
        # Prediction result
        ret = sess.run(output_tensor, {input_tensor: datas})
        idx = np.argmax(ret[0])
        print(labels[idx], ret)

        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        signature = tf.saved_model.signature_def_utils.predict_signature_def(
            inputs={'input': input_tensor}, outputs={'output': output_tensor})
        builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map={'predict': signature})
        builder.save()
    print('save meta graph to', export_dir)
def estimator_predictor(model_dir, url):
    """Load a SavedModel with tf.contrib.predictor and predict on one input.

    ``signature_def_key='predict'`` has to match the key used in
    ``signature_def_map`` when the SavedModel was saved (``load_pb`` saves
    its signature under ``'predict'``).

    Args:
        model_dir: SavedModel export directory.
        url: Passed to ``read_image`` to build the input.

    Returns:
        Whatever ``predict_fn`` returns; it is also printed.
    """
    # `predictor` is not imported anywhere in this file; import it here so the
    # function is self-contained.
    from tensorflow.contrib import predictor

    config = tf.ConfigProto(allow_soft_placement=True)
    # Allocate GPU memory on demand and cap this process at 30% of it.
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    predict_fn = predictor.from_saved_model(model_dir, config=config, signature_def_key='predict')
    # NOTE(review): read_image is not defined in this snippet; presumably provided elsewhere.
    datas = read_image(url)
    # The key 'input' matches the inputs key of the saved signature.
    predictions = predict_fn({'input': datas})
    print(predictions)
    return predictions
上面代码中的这一行:predict_fn = predictor.from_saved_model(model_dir, config = config, signature_def_key='predict')
很关键:signature_def_key 必须与保存时 signature_def_map 中使用的 key('predict')一致,否则会报如下错误(expected 为空集 set()):
ValueError: Got unexpected keys in input_dict: {'input:0'}
expected: set()
根据pb文件查看网络中的变量及tensor
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
output_graph_path = "model/test.pb"
# Load a frozen graph and print the name of every operation in it. This is
# how the input/output tensor names needed for serving are found.
graph = tf.Graph()
with graph.as_default():
    output_graph_def = tf.GraphDef()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    tf.import_graph_def(output_graph_def, name="")
# No variables initializer is created: imported GraphDef nodes need no
# initialization, and an initializer would add a spurious "init" op to the list.
for op in graph.get_operations():
    print(op.name)  # print() call: the Python 2 print statement is a SyntaxError on Python 3
输出如下:
image_tensor
map/Shape
map/strided_slice/stack
map/strided_slice/stack_1
map/strided_slice/stack_2
map/strided_slice
map/TensorArray
map/TensorArrayUnstack/Shape
map/TensorArrayUnstack/strided_slice/stack
........
resnet_v2_50/logits/weights/read
resnet_v2_50/logits/biases
resnet_v2_50/logits/biases/read
resnet_v2_50/logits/Conv2D
resnet_v2_50/logits/BiasAdd
Squeeze
Softmax
score
# Re-export a frozen graph as a SavedModel; simple_save adds the SERVING tag.
# Import into a fresh graph so nodes already imported into the default graph
# earlier in the same process cannot clash with these names.
graph = tf.Graph()
with graph.as_default():
    # GFile replaces the deprecated FastGFile.
    with tf.gfile.GFile('model/model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

# The session context also makes `graph` the default graph, which simple_save
# relies on; it closes the session when done.
with tf.Session(graph=graph) as sess:
    # Print every op name to locate the input/output tensors.
    for op in graph.get_operations():
        print(op.name)
    tf.saved_model.simple_save(
        sess,
        "./save_model_file/",
        inputs={"image": graph.get_tensor_by_name('image_tensor:0')},
        outputs={"scores": graph.get_tensor_by_name('score:0')}
    )
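用此方法(tf.saved_model.simple_save)保存时,会自动添加 tf.saved_model.tag_constants.SERVING 标签,
这样加载时就不会报错:RuntimeError: MetaGraphDef associated with tags 'serve' could not be found in SavedModel
从上面打印的节点列表可以看到,输入 tensor 是 image_tensor,输出是 score。注意:simple_save 保存的签名 key 是 'serving_default'(而不是 'predict'),用 predictor.from_saved_model 加载时 signature_def_key 要相应修改或省略。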
其它文章:
https://blog.csdn.net/weixin_43215867/article/details/85038313 解析了tags问题
首先创建 SavedModelBuilder 类的对象:调用 tf.saved_model.builder.SavedModelBuilder,参数是用于保存模型的目录名,该目录不用预先创建(若目录已存在且非空会报错)。
然后生成签名,签名描述了与图相关的一组输入和输出。使用 predict_signature_def 方法,传入的参数是输入和输出 tensor 的字典,字典的 key 就是它们在签名中的名字。
接着传入graph(图)和Variables(变量)给add_meta_graph_and_variables方法。
第一个参数传入的是 Session,它包含了当前的 graph(图)和 Variables(变量)。
第二个参数是给当前需要保存的MetaGraph 一个标签,标签名可以自定义,在之后载入模型的时候,需要根据这个标签名去查找对应的MetaGraphDef,找不到就会报如 RuntimeError: MetaGraphDef associated with tags ‘foo’ could not be found in SavedModel这样的错。
标签也可以选用系统定义好的参数,
tf.saved_model.tag_constants.SERVING与
tf.saved_model.tag_constants.TRAINING等。
https://github.com/amir-abdi/keras_to_tensorflow
https://github.com/tensorflow/tensorflow/issues/35806