TFlite 的简单使用

背景

TensorFlow Lite 转换器可根据输入的 TensorFlow 模型生成 TensorFlow Lite 模型(一种优化的 FlatBuffer 格式,以 .tflite 为文件扩展名). 作用是进一步缩短模型延迟时间和减小模型大小,同时最大限度降低准确率损失和添加元数据,从而在设备上部署模型时可以更轻松地创建平台专用封装容器代码。

TFlite 的简单使用_第1张图片

环境

tensorflow=2.4.1

实践例子

把Tensorflow的模型转换成tflite

import tensorflow as tf
def convert_to_tflite(model):
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tfmodel = converter.convert()
    file = open('yourmodel.tflite', 'wb')
    file.write(tfmodel)
    file.close()

运行Tflite模型

import tensorflow as tf
def run_reference_by_tflite(input):
    interpreter = tf.lite.Interpreter(model_path="yourmodel.tflite")
    interpreter.allocate_tensors()
    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # input details
    print(input_details, len(input_details))

    # output details
    print(output_details)

    # input_details[0]['index'] = the index which accepts the input
    interpreter.set_tensor(input_details[0]['index'], input)


    # run the inference
    interpreter.invoke()

    # output_details[0]['index'] = the index which provides the input
    output_data = interpreter.get_tensor(output_details[0]['index'])

    print('interpreter: ', output_data)

不同版本的tensorflow或不同的格式的模型对应的转换方法

  • 使用 tf.lite.TFLiteConverter 转换 TensorFlow 2.x 模型。TensorFlow 2.x 模型是使用 SavedModel 格式存储的,并通过高阶 tf.keras.* API(Keras 模型)或低阶 tf.* API(用于生成具体函数)生成。因此,您有以下三个选项(示例包含在接下来的几节中):

    • tf.lite.TFLiteConverter.from_saved_model()(推荐):转换 SavedModel。

    • tf.lite.TFLiteConverter.from_keras_model():转换 Keras 模型。

    • tf.lite.TFLiteConverter.from_concrete_functions():转换具体函数。

  • 使用 tf.compat.v1.lite.TFLiteConverter 转换 TensorFlow 1.x 模型(示例位于 GitHub 上):

    • tf.compat.v1.lite.TFLiteConverter.from_saved_model():转换 SavedModel。
    • tf.compat.v1.lite.TFLiteConverter.from_keras_model_file():转换 Keras 模型。
    • tf.compat.v1.lite.TFLiteConverter.from_session():从会话转换 GraphDef。
    • tf.compat.v1.lite.TFLiteConverter.from_frozen_graph():从文件转换 Frozen GraphDef。如果您有检查点,请先将其转换为 Frozen GraphDef 文件,然后使用此 API(如此处所示)。


相关错误以及解决方法

1. ValueError: Cannot set tensor: Got value of type NOTYPE but expected type FLOAT32 for input 0,

ValueError: Cannot set tensor: Got value of type INT32 but expected type FLOAT32 for input 0,

ValueError: Cannot set tensor: Got value of type UINT8 but expected type FLOAT32 for input 0, name: input_1 

解决方法:

under_exp = np.array(under_exp, dtype=np.float32)

参考资料

python - How to convert keras(h5) file to a tflite file? - Stack Overflowhttps://www.tensorflow.org/lite/convert

你可能感兴趣的:(tensorflow,python,开发语言)