TF-Lite Minimal Reference: Model Conversion


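  TensorFlow Lite makes it easy to convert a model trained with TensorFlow and then run inference on it. In TensorFlow 2.0, Keras is fully integrated, so tf.keras can be used to build models more efficiently. Although a major bug in TensorFlow 2.0 was reported a few days ago, and the framework has long been criticized as hard to use, it still has a wide user base, and it is quite stable as long as you do not rely on many custom layers. I started using TensorFlow around version 0.10, have followed it for a long time, and have successfully deployed many OCR projects on the server side with it.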

Key Words: TF Lite, model conversion


Beijing, 2020

Author: RaySue

Code: https://github.com/RaySue/TF-Lite-Demo.git

Agile Pioneer  

TensorFlow Lite Architecture

[Figure: TF-Lite overall architecture]

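  Deploying a TensorFlow Lite model file involves:

  • Java API: a convenience wrapper around the C++ API on Android.
  • C++ API: loads the TensorFlow Lite model file and invokes the interpreter. The same library is available on both Android and iOS.
  • Interpreter: executes the model using a set of kernels and supports selective kernel loading. It is about 100 KB without kernels and about 300 KB with all kernels loaded, a significant reduction from the 1.5 MB required by TensorFlow Mobile.
  • On select Android devices, the interpreter uses the Android Neural Networks API for hardware acceleration and falls back to CPU execution when that API is unavailable.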

[Figure: TF-Lite usage workflow]

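TensorFlow Lite provides the following three model conversion methods:

  • tf.lite.TFLiteConverter.from_keras_model(): converts an instantiated Keras model (TF 2.0)
  • tf.lite.TFLiteConverter.from_saved_model(): converts a SavedModel directory (saved_model.pb plus its variables)
  • tf.lite.TFLiteConverter.from_concrete_functions(): converts concrete functions (a Python function traced into a graph)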
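
As a minimal sketch of the third option, the snippet below traces a toy function into a concrete function, converts it, and runs the result once. The `double_it` function and its [1, 4] input signature are hypothetical stand-ins for a real model, and the example assumes a TensorFlow 2.x install that still ships `tf.lite.Interpreter`.

```python
import numpy as np
import tensorflow as tf


# Hypothetical toy "model": doubles a float32 tensor of shape [1, 4].
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 4], dtype=tf.float32)])
def double_it(x):
    return x * 2.0


# Trace the Python function into a graph (a ConcreteFunction).
concrete_fn = double_it.get_concrete_function()

# Convert the traced graph into a TFLite flatbuffer (a bytes object).
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
tflite_bytes = converter.convert()

# Load the flatbuffer straight from memory and run it once.
interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

sample = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
interpreter.set_tensor(input_index, sample)
interpreter.invoke()
result = interpreter.get_tensor(output_index)
print(result)
```

Recent TensorFlow releases may log a warning suggesting that a trackable object be passed as a second argument; for a stateless function like this one the conversion should still go through.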

Convert ckpt to pb


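# Written against the TensorFlow 1.x API. `model` is assumed to be a TFLearn-style
# model, since it exposes load(path, weights_only=True) and save().
import glob
import os

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.tools.graph_transforms import TransformGraph


def convert_to_pb(model, path, input_layer_name, output_layer_name, pbfilename, verbose=False):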
    model.load(path, weights_only=True)
    print("[INFO] Loaded CNN network weights from " + path + " ...")

    print("[INFO] Re-export model ...")
    # Empty the TRAIN_OPS collection so training ops are not saved with the graph.
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
    model.save("model-tmp.tfl")

    # taken from: https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow

    print("[INFO] Re-import model ...")

    input_checkpoint = "model-tmp.tfl"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    sess = tf.Session()
    saver.restore(sess, input_checkpoint)

    # Print the output tensors of every operation to help find the output layer name.

    if verbose:
        for op in sess.graph.get_operations():
            print(op.values())

    print("[INFO] Freeze model to " + pbfilename + " ...")

    # freeze and removes nodes which are not related to feedforward prediction

    minimal_graph = convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_layer_name])

    graph_def = optimize_for_inference_lib.optimize_for_inference(minimal_graph, [input_layer_name],
                                                                  [output_layer_name], tf.float32.as_datatype_enum)
    graph_def = TransformGraph(graph_def, [input_layer_name], [output_layer_name], ["sort_by_execution_order"])

    with tf.gfile.GFile(pbfilename, 'wb') as f:
        f.write(graph_def.SerializeToString())

    # write model to logs dir so we can visualize it as:
    # tensorboard --logdir="logs"

    if verbose:
        writer = tf.summary.FileWriter('logs', graph_def)
        writer.close()

    # Tidy up the temporary checkpoint files.

    for f in glob.glob("model-tmp.tfl*"):
        os.remove(f)

    if os.path.exists('checkpoint'):
        os.remove('checkpoint')

Convert pb to tflite

import os

import tensorflow as tf


def convert_to_tflite(pbfilename, input_layer_name, output_layer_name,
                      input_tensor_dim_x, input_tensor_dim_y, input_tensor_channels=3):
    # Fixed input shape for the frozen graph: [batch, dim_x, dim_y, channels].
    input_shapes = {input_layer_name: [1, input_tensor_dim_x, input_tensor_dim_y, input_tensor_channels]}
    tflite_filename = os.path.splitext(pbfilename)[0] + ".tflite"

    print("[INFO] tflite model to " + tflite_filename + " ...")

    # from_frozen_graph belongs to the TF 1.x converter; under TF 2.x it is reached via tf.compat.v1.
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        pbfilename, [input_layer_name], [output_layer_name], input_shapes)
    tflite_model = converter.convert()
    with open(tflite_filename, "wb") as f:
        f.write(tflite_model)

Convert keras model to tflite

import tensorflow as tf
# `model` is an already instantiated tf.keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('/home/ai/converted_model.tflite', 'wb') as f:
    f.write(tflite_model)
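
Once the bytes are on disk, a cheap sanity check can catch a truncated or mis-saved file before it reaches a device. The sketch below relies on the fact that a TFLite model is a FlatBuffer whose file identifier, stored at byte offset 4, is `TFL3`. The helper names `looks_like_tflite` and `check_tflite_file` are made up for illustration, and a passing check only shows that the header looks right, not that the model is valid.

```python
def looks_like_tflite(data: bytes) -> bool:
    """Return True if `data` carries the TFLite FlatBuffer file identifier.

    A FlatBuffer starts with a 4-byte root offset, followed by an optional
    4-byte file identifier; the TFLite schema sets it to "TFL3".
    """
    return len(data) >= 8 and data[4:8] == b"TFL3"


def check_tflite_file(path: str) -> bool:
    """Read only the first 8 bytes of `path` and test the identifier."""
    with open(path, "rb") as f:
        return looks_like_tflite(f.read(8))
```

For the file written above, this would be called as `check_tflite_file('/home/ai/converted_model.tflite')`.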

Validation

tflite inference


import cv2
import numpy as np
import tensorflow as tf

tflite_model_path = "xxx.tflite"
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# cv2.imread returns a BGR uint8 image; apply the same preprocessing that was used in training.
frame = cv2.imread("xxx.jpg")
# interpolation must be passed by keyword: the third positional argument of cv2.resize is dst.
small_frame = cv2.resize(frame, (224, 224), interpolation=cv2.INTER_AREA)

# Add the batch dimension and cast to float32, giving shape (1, 224, 224, 3).
input_data = np.expand_dims(small_frame.astype(np.float32), 0)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

output_tflite = interpreter.get_tensor(output_details[0]['index'])
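
Validation usually means comparing `output_tflite` with the output of the original model on the same input. A minimal NumPy-only sketch of that comparison follows. The helper `compare_outputs`, the tolerance values, and the sample arrays are illustrative assumptions, and a quantized model would likely need looser tolerances.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-5, atol=1e-5):
    """Compare a reference output (e.g. from the original model) with a
    candidate output (e.g. from the TFLite interpreter).

    Returns (ok, max_abs_diff). Raises ValueError if the shapes differ.
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate))) if reference.size else 0.0
    ok = bool(np.allclose(candidate, reference, rtol=rtol, atol=atol))
    return ok, max_abs_diff


# Made-up numbers standing in for the outputs of the two models.
output_reference = np.array([[0.10, 0.70, 0.20]], dtype=np.float32)
output_candidate = np.array([[0.10, 0.70, 0.20]], dtype=np.float32) + 1e-7
ok, diff = compare_outputs(output_reference, output_candidate)
print(ok, diff)
```

In practice `output_reference` would come from the original model's prediction on the same preprocessed image, and `output_candidate` would be `output_tflite`.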
