折腾了好几天一直没搞定，最后用下面的代码成功保存并读取了模型：
# Export `model` as a SavedModel directory named "save_test", then load it back.
# NOTE: tf.saved_model.load returns a generic trackable object, not a Keras
# model; use tf.keras.models.load_model if Keras methods are needed afterwards.
tf.saved_model.save(obj=model, export_dir="save_test")
model = tf.saved_model.load(export_dir="save_test")
上面的代码可以保存为 pb 文件并读取，但这种方式保存的是一个 SavedModel 目录：模型结构（saved_model.pb）和权重（variables/ 目录）是分开存放的。
2020.3.1 更新：
下面是新的保存方法，可以把模型结构和权重一起冻结（freeze）到同一个 .pb 文件中。
2020.5.6 更新：
若使用 TensorFlow 1.x，保存为 pb 模型前必须在程序最开始处调用：
# Switch to eager mode; this must run before any other TensorFlow op is created.
# `tf.enable_eager_execution` was removed in TF 2.x (eager is already the
# default there), so the tf.compat.v1 spelling is used to work on both.
tf.compat.v1.enable_eager_execution()
使其进入 eager 模式（TensorFlow 2.x 默认已是 eager 模式，无需再调用）。
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# Wrap the Keras model in a tf.function and trace it once for the model's
# input signature to obtain a ConcreteFunction.
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction: variables are replaced by constants holding
# their current values, so the written graph needs no separate weights.
frozen_func = convert_variables_to_constants_v2(full_model)

# The loader looks tensors up by name ("x:0" / "Identity:0"); print the actual
# names so they can be checked against what the loader uses.
print("Frozen model inputs:", frozen_func.inputs)
print("Frozen model outputs:", frozen_func.outputs)

# Write the frozen graph as a binary protobuf into the current directory.
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir=".",
                  name="frozen_graph.pb",
                  as_text=False)
# Load the frozen graph written to "frozen_graph.pb" and run one image through
# it using the TF1-style Session API (tf.compat.v1).
with tf.Graph().as_default():
    output_graph_def = tf.compat.v1.GraphDef()
    # Open the .pb model.
    with open("frozen_graph.pb", "rb") as f:
        output_graph_def.ParseFromString(f.read())
    # name='' keeps the original node names (no "import/" prefix). Without
    # return_elements, import_graph_def returns None, so nothing is bound.
    tf.import_graph_def(output_graph_def, name='')
    with tf.compat.v1.Session() as sess:
        # Print every op so the input/output tensor names can be looked up.
        for i, m in enumerate(sess.graph.get_operations()):
            print('op{}:'.format(i), m.values())
        # "x:0" presumably comes from the traced lambda's argument name;
        # check the first ops printed above if this lookup fails.
        input_x = sess.graph.get_tensor_by_name("x:0")
        print("input_X:", input_x)
        # Expected output tensor of the frozen function; check the last ops
        # printed above if this lookup fails.
        out_softmax = sess.graph.get_tensor_by_name("Identity:0")
        print("Output:", out_softmax)

        # Read the image. cv2.imread returns None (it does not raise) when the
        # file is missing or unreadable, so fail early with a clear message.
        img = cv2.imread("1.jpg")
        if img is None:
            raise IOError("cannot read image: 1.jpg")
        img = cv2.resize(img, (128, 128))
        img = img.astype(np.float32)
        # If the model was trained on normalized input, apply the same
        # preprocessing here (e.g. img = 1 - img / 255).
        print("img data type:", img.dtype)
        img_out_softmax = sess.run(
            out_softmax,
            feed_dict={input_x: np.reshape(img, (1, 128, 128, 3))})
        print("img_out_softmax:", img_out_softmax)
        for i, prob in enumerate(img_out_softmax[0]):
            print('class {} prob:{}'.format(i, prob))
        prediction_labels = np.argmax(img_out_softmax, axis=1)
        print("Final class if:", prediction_labels)
        print("prob of label:", img_out_softmax[0, prediction_labels])
# Legacy TF1 export: save the graph and variables of the current Keras session
# as a SavedModel in the "my_model" directory, tagged ["my_model"]. The loader
# must pass the same tag list to find this MetaGraph.
model_name = 'my_model'
session = tf.keras.backend.get_session()
builder = tf.saved_model.builder.SavedModelBuilder(export_dir=model_name)
builder.add_meta_graph_and_variables(sess=session, tags=["my_model"])
builder.save()
# Load the TF1 SavedModel from the "my_model" directory into a fresh graph and
# run one image through it.
model_name = 'my_model'
with tf.Session(graph=tf.Graph()) as sess:
    # Restore graph and variable values; the tag list must match the one used
    # with add_meta_graph_and_variables when saving.
    tf.saved_model.loader.load(sess, ["my_model"], model_name)
    # Do NOT run tf.global_variables_initializer() here: it would reset the
    # weights that loader.load just restored.

    # Print every op in the graph so the input/output tensor names can be found.
    for i, m in enumerate(sess.graph.get_operations()):
        print('op{}:'.format(i), m.values())
    input_x = sess.graph.get_tensor_by_name("input_1:0")
    print("input_X:", input_x)
    out_softmax = sess.graph.get_tensor_by_name(
        "MobileNetV3_Small/LastStage/Squeeze/Squeeze_1:0")
    print("Output:", out_softmax)

    # Read the image. cv2.imread returns None (it does not raise) on failure,
    # so fail early instead of inside cv2.resize.
    img = cv2.imread("1.jpg")
    if img is None:
        raise IOError("cannot read image: 1.jpg")
    img = cv2.resize(img, (128, 128))
    img = img.astype(np.float32)
    # If the model was trained on normalized input, apply the same
    # preprocessing here (e.g. img = 1 - img / 255).
    print("img data type:", img.dtype)
    img_out_softmax = sess.run(
        out_softmax,
        feed_dict={input_x: np.reshape(img, (1, 128, 128, 3))})
    print("img_out_softmax:", img_out_softmax)
    for i, prob in enumerate(img_out_softmax[0]):
        print('class {} prob:{}'.format(i, prob))
    prediction_labels = np.argmax(img_out_softmax, axis=1)
    print("Final class if:", prediction_labels)
    print("prob of label:", img_out_softmax[0, prediction_labels])
我的解决方法：https://blog.csdn.net/a362682954/article/details/104611325
def freeze_graph(graph, session, output_node_names, model_name):
    """Freeze a TF1.x graph and write it to a binary ``.pb`` file.

    Nodes not needed for inference are pruned, then the variables reachable
    from ``output_node_names`` are replaced by constants holding their current
    values in ``session``.

    Args:
        graph: The ``tf.Graph`` to freeze (typically ``session.graph``).
        session: Session whose variable values are baked into the graph.
        output_node_names: Names of the output ops, without the ``:0``
            suffix (e.g. ``[out.op.name for out in model.outputs]``).
        model_name: Output file name or path. ``".pb"`` is appended only if
            the name does not already end with it. Previously, ``"my_model2.pb"``
            was written as ``"my_model2.pb.pb"``. A directory part, if given,
            is kept; the current directory is used otherwise.

    Returns:
        The frozen ``GraphDef`` that was written.
    """
    import os

    with graph.as_default():
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(
            session, graphdef_inf, output_node_names)

    out_dir, file_name = os.path.split(model_name)
    if not file_name.endswith(".pb"):
        file_name += ".pb"
    # tf.io.write_graph is the public alias of graph_io.write_graph.
    tf.io.write_graph(graphdef_frozen, out_dir or ".", file_name, as_text=False)
    return graphdef_frozen
# Put Keras into inference mode (learning phase 0) before freezing; the
# original author marks this as the most important line.
# NOTE(review): this usually has to run before the model is built/loaded to
# take effect on layers such as Dropout/BatchNorm — confirm for this script.
tf.keras.backend.set_learning_phase(0)
model_name = 'my_model2.pb'
session = tf.keras.backend.get_session()
# Output op names (tensor name without the ":0" suffix) of every model output.
output_node_names = [output.op.name for output in model.outputs]
freeze_graph(session.graph, session, output_node_names, model_name)
def recognize(jpg_path, pb_file_path,
              input_tensor_name="input_1:0",
              output_tensor_name="MobileNetV3_Small/LastStage/Squeeze/Squeeze_1:0",
              image_size=(128, 128)):
    """Classify one image with a frozen TF1.x graph stored in a ``.pb`` file.

    Prints every op of the graph, the per-class outputs and the predicted class.

    Args:
        jpg_path: Path of the image to classify; read as a 3-channel image.
        pb_file_path: Path of the frozen GraphDef (binary protobuf).
        input_tensor_name: Name of the input tensor in the frozen graph.
        output_tensor_name: Name of the output tensor in the frozen graph.
        image_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``np.argmax`` of the model output over axis 1 (predicted class ids).

    Raises:
        IOError: if OpenCV cannot read ``jpg_path``.
    """
    output_graph_def = tf.compat.v1.GraphDef()
    # Open the .pb model.
    with open(pb_file_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # name='' keeps the node names unchanged. The previous name='1'
        # prefixed every node with "1/", so the lookups below failed.
        tf.import_graph_def(output_graph_def, name='')
        # Run a single forward pass in a session bound to this graph. No
        # variable initializer is needed: a frozen graph holds constants.
        with tf.compat.v1.Session(graph=graph) as sess:
            # Print all ops so the input/output tensor names can be checked.
            for i, m in enumerate(sess.graph.get_operations()):
                print('op{}:'.format(i), m.values())
            input_x = sess.graph.get_tensor_by_name(input_tensor_name)
            print("input_X:", input_x)
            out_softmax = sess.graph.get_tensor_by_name(output_tensor_name)
            print("Output:", out_softmax)

            # Read as a 3-channel image. The old flag 0 loaded grayscale, which
            # cannot be reshaped to (1, H, W, 3). cv2.imread returns None on
            # failure instead of raising.
            img = cv2.imread(jpg_path)
            if img is None:
                raise IOError("cannot read image: {}".format(jpg_path))
            # dsize must be a 2-tuple (width, height); (128, 128, 3) was invalid.
            img = cv2.resize(img, tuple(image_size))
            img = img.astype(np.float32)
            # NOTE(review): this normalization is kept from the original, but
            # the other inference snippets feed raw 0-255 pixels — confirm
            # which preprocessing the model was trained with.
            img = 1 - img / 255
            print("img data type:", img.dtype)
            img_out_softmax = sess.run(out_softmax,
                                       feed_dict={input_x: img[np.newaxis, ...]})
            print("img_out_softmax:", img_out_softmax)
            for i, prob in enumerate(img_out_softmax[0]):
                print('class {} prob:{}'.format(i, prob))
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("Final class if:", prediction_labels)
            print("prob of label:", img_out_softmax[0, prediction_labels])
    return prediction_labels
参考文献:
https://zhuanlan.zhihu.com/p/55600911
https://blog.csdn.net/qq_25109263/article/details/81285952
https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/
https://github.com/leimao/Frozen_Graph_TensorFlow