tf.keras 模型保存为 pb 文件

折腾了我几天,一直搞不定。最后用以下代码成功保存。

方法一:

TensorFlow 2.0 及以上版本可以使用:

# Export the model as a SavedModel directory, then load it back from that directory.
export_dir = "save_test"
tf.saved_model.save(model, export_dir)
model = tf.saved_model.load(export_dir)

来保存并读取 SavedModel(其中包含 saved_model.pb 文件)。不过这种方式会把模型结构(saved_model.pb)和权重(variables/ 目录)分开保存。

2020.3.1更新:

下面是新的保存方法,可以把模型结构和权重一起冻结到同一个 pb 文件中。

2020.5.6更新:

如果使用 TensorFlow 1.x,保存成 pb 模型前必须在程序最开始处调用:

# Switch TF 1.x into eager mode (must run before any graph/session is created).
# tf.compat.v1 spelling works on TF >= 1.13 and on TF 2.x, where the plain
# tf.enable_eager_execution attribute no longer exists and eager is already on.
tf.compat.v1.enable_eager_execution()

使其进入 eager 模式。(TensorFlow 2.x 默认就是 eager 模式,无需调用。)

    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    # --- Save: trace the Keras model into a concrete function and freeze it ---
    # Wrap the model call in a tf.function and trace it with a fixed input
    # signature taken from the model's first input.
    full_model = tf.function(lambda x: model(x))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Replace variable reads with constants holding their current values, so the
    # graph structure and the weights end up in a single GraphDef.
    frozen_func = convert_variables_to_constants_v2(full_model)
    # Show the tensor names the frozen graph exposes; these are the names to
    # pass to get_tensor_by_name when loading the .pb below.
    print("Frozen model inputs:", frozen_func.inputs)
    print("Frozen model outputs:", frozen_func.outputs)
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="",
                      name="frozen_graph.pb",
                      as_text=False)

    # --- Load: import the frozen GraphDef into a fresh graph and run it ---
    with tf.Graph().as_default():
        output_graph_def = tf.compat.v1.GraphDef()

        # Open the .pb model
        with open("frozen_graph.pb", "rb") as f:
            output_graph_def.ParseFromString(f.read())
        # name='' keeps the original node names (no "import/" prefix), so the
        # names printed above can be looked up directly.
        tf.import_graph_def(output_graph_def, name='')

        # The session uses the default graph, i.e. the one populated just above.
        with tf.compat.v1.Session() as sess:
            # List every op in the graph, useful for finding input/output names.
            for i, m in enumerate(sess.graph.get_operations()):
                print('op{}:'.format(i), m.values())

            # Presumably "x" comes from the lambda's argument name and "Identity"
            # is the traced function's output; confirm against the printed
            # frozen_func.inputs / frozen_func.outputs if the lookup fails.
            input_x = sess.graph.get_tensor_by_name("x:0")
            print("input_X:", input_x)

            out_softmax = sess.graph.get_tensor_by_name("Identity:0")
            print("Output:", out_softmax)

            # Read the image (OpenCV returns None instead of raising on failure).
            img = cv2.imread("1.jpg")
            if img is None:
                raise FileNotFoundError("cannot read image: 1.jpg")
            img = cv2.resize(img, (128, 128))
            img = img.astype(np.float32)
            # NOTE(review): no normalization is applied here; add whatever
            # preprocessing the model was trained with (e.g. img / 255).
            print("img data type:", img.dtype)

            img_out_softmax = sess.run(out_softmax,
                                       feed_dict={input_x: np.reshape(img, (1, 128, 128, 3))})
            print("img_out_softmax:", img_out_softmax)
            for i, prob in enumerate(img_out_softmax[0]):
                print('class {} prob:{}'.format(i, prob))
            prediction_labels = np.argmax(img_out_softmax, axis=1)
            print("Final class if:", prediction_labels)
            print("prob of label:", img_out_softmax[0, prediction_labels])

 

方法二

TensorFlow 1.x 版本可以使用如下代码保存和读取:

    # --- Save (TF 1.x): export the Keras session's graph + variables as a SavedModel ---
    session = tf.keras.backend.get_session()
    model_name = 'my_model'  # export directory, also used by the loader below
    # NOTE(review): SavedModelBuilder refuses to export into an existing,
    # non-empty directory; remove 'my_model' before re-running.
    builder = tf.saved_model.builder.SavedModelBuilder(model_name)
    # The tag list here must match the one passed to loader.load below.
    builder.add_meta_graph_and_variables(session, ["my_model"])
    builder.save()

    # --- Load: restore the SavedModel into a fresh graph and run it ---
    with tf.Session(graph=tf.Graph()) as sess:
        # loader.load restores the saved variable values. Do NOT run
        # tf.global_variables_initializer() afterwards: it would reset them.
        tf.saved_model.loader.load(sess, ["my_model"], model_name)

        # Print every op in the graph, useful for finding input/output names.
        for i, m in enumerate(sess.graph.get_operations()):
            print('op{}:'.format(i), m.values())

        input_x = sess.graph.get_tensor_by_name("input_1:0")
        print("input_X:", input_x)

        # Output tensor name is specific to the author's MobileNetV3_Small model;
        # look it up in the op listing above for other models.
        out_softmax = sess.graph.get_tensor_by_name(
            "MobileNetV3_Small/LastStage/Squeeze/Squeeze_1:0")
        print("Output:", out_softmax)

        # Read the image (OpenCV returns None instead of raising on failure).
        img = cv2.imread("1.jpg")
        if img is None:
            raise FileNotFoundError("cannot read image: 1.jpg")
        img = cv2.resize(img, (128, 128))
        img = img.astype(np.float32)
        # NOTE(review): no normalization is applied here; add whatever
        # preprocessing the model was trained with (e.g. img / 255).
        print("img data type:", img.dtype)

        img_out_softmax = sess.run(out_softmax,
                                   feed_dict={input_x: np.reshape(img, (1, 128, 128, 3))})
        print("img_out_softmax:", img_out_softmax)
        for i, prob in enumerate(img_out_softmax[0]):
            print('class {} prob:{}'.format(i, prob))
        prediction_labels = np.argmax(img_out_softmax, axis=1)
        print("Final class if:", prediction_labels)
        print("prob of label:", img_out_softmax[0, prediction_labels])

方法三:

还有一种方法可以保存,但有些模型在读取时会出错:在 TensorFlow 1.13 下,包含 tf.keras.layers.BatchNormalization 的模型在读取冻结图(frozen graph)时会报 ValueError,参见 https://stackoverflow.com/questions/56418877/tf-keras-layers-batchnormalization-valueerror-when-reading-frozen-graph

当时我无法解决这个问题,如果你有更好的解决方法,请联系我。

我后来找到的解决方法:https://blog.csdn.net/a362682954/article/details/104611325

    def freeze_graph(graph, session, output_node_names, model_name):
        """Freeze the variables of ``graph`` into constants and write a binary .pb file.

        Args:
            graph: the tf.Graph that holds the Keras model.
            session: a session bound to ``graph`` whose variables hold the weights.
            output_node_names: list of output op names to keep in the frozen graph.
            model_name: output file name. Only its basename is used, so the file is
                written to the current directory; ".pb" is appended unless the
                name already ends with it.
        """
        # write_graph lives in this internal module in TF 1.x; import it locally.
        from tensorflow.python.framework import graph_io

        with graph.as_default():
            graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
            graphdef_frozen = tf.graph_util.convert_variables_to_constants(
                session, graphdef_inf, output_node_names)
            file_name = os.path.basename(model_name)
            # Avoid producing "name.pb.pb" when the caller already passes ".pb".
            if not file_name.endswith(".pb"):
                file_name += ".pb"
            graph_io.write_graph(graphdef_frozen, "", file_name, as_text=False)

    # Put Keras into inference mode. NOTE(review): per the linked StackOverflow
    # thread, for BatchNormalization this likely needs to happen before the
    # model is built/loaded, not just before freezing — confirm.
    tf.keras.backend.set_learning_phase(0)
    model_name = 'my_model2.pb'
    session = tf.keras.backend.get_session()
    freeze_graph(session.graph, session, [out.op.name for out in model.outputs], model_name)
def recognize(jpg_path, pb_file_path):
    """Classify one image with a frozen .pb graph and print the class probabilities.

    Args:
        jpg_path: path of the image to classify; it is read as a 3-channel
            (OpenCV BGR) image and resized to 128x128.
        pb_file_path: path of a binary frozen GraphDef (e.g. from freeze_graph).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``jpg_path``.
    """
    # Import the GraphDef into its own graph and keep a handle to it, so the
    # session below runs against the graph that actually contains the ops.
    graph = tf.Graph()
    with graph.as_default():
        output_graph_def = tf.GraphDef()

        # Open the .pb model
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())

        # name='' keeps the original node names; a non-empty name (e.g. '1')
        # would prefix every tensor ("1/input_1:0") and break the lookups below.
        tf.import_graph_def(output_graph_def, name='')

    # A frozen graph stores weights as constants, so no initializer is run.
    with tf.Session(graph=graph) as sess:
        # Print every op in the graph, useful for finding input/output names.
        for i, m in enumerate(sess.graph.get_operations()):
            print('op{}:'.format(i), m.values())

        input_x = sess.graph.get_tensor_by_name("input_1:0")
        print("input_X:", input_x)

        # Output tensor name is specific to the author's MobileNetV3_Small model.
        out_softmax = sess.graph.get_tensor_by_name("MobileNetV3_Small/LastStage/Squeeze/Squeeze_1:0")
        print("Output:", out_softmax)

        # Read as a colour image: the input is reshaped to 3 channels below.
        img = cv2.imread(jpg_path)
        if img is None:
            raise FileNotFoundError("cannot read image: {}".format(jpg_path))
        # cv2.resize takes a 2-tuple (width, height) as dsize.
        img = cv2.resize(img, (128, 128))
        img = img.astype(np.float32)
        # NOTE(review): inverted [0, 1] scaling kept from the original; it must
        # match the preprocessing used at training time — confirm.
        img = 1 - img / 255
        print("img data type:", img.dtype)

        img_out_softmax = sess.run(out_softmax,
                                   feed_dict={input_x: np.reshape(img, (1, 128, 128, 3))})

        print("img_out_softmax:", img_out_softmax)
        for i, prob in enumerate(img_out_softmax[0]):
            print('class {} prob:{}'.format(i, prob))
        prediction_labels = np.argmax(img_out_softmax, axis=1)
        print("Final class if:", prediction_labels)
        print("prob of label:", img_out_softmax[0, prediction_labels])

 

 

参考文献:

https://zhuanlan.zhihu.com/p/55600911

https://blog.csdn.net/qq_25109263/article/details/81285952

https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/

https://github.com/leimao/Frozen_Graph_TensorFlow

 

你可能感兴趣的:(深度学习,keras)