yolov3实战 超简单上手 飞机与油桶数据集之tflite预测

Tensorflow Lite

Tensorflow Lite(tf lite) 针对移动设备(安卓、ios)和嵌入式设备的轻量化解决方案,占用空间小,低延迟。tf lite在android8.1以上的设备上可以通过ANNA启用硬件加速。

tf lite 主要流程:

yolov3实战 超简单上手 飞机与油桶数据集之tflite预测_第1张图片

加载、转换模型

在前几篇 我用yolo v3 训练了一个keras模型,本次操作用这个keras模型。注意:之前的操作是只保存了权重,但是在使用tf lite转化模型,被转换的模型需要有完整的模型结构和权重参数。

import tensorflow as tf

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
# keras模型  
import tensorflow as tf

# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
# 在前几篇训练的yolo v3模型,对它进行转换
def convert2tflite(model_weightandstruct_path, save_tflite_path):
 '''
 将keras模型文件转化为tflite
 :param model_weightandstruct_path: 带有权重和模型结构的模型路径   save_weight_only = True 需要加载模型、权重再进行保存
 :param save_tflite_path: 保存tflite路径
 '''
 # 如果你的模型是save_weight_only,只保存权重,没有保存模型结构,需要先加载模型结构,再加载模型权重,重新用.save保存。如果只有模型权重,没有结构,则会报错 -》 "No model found in config file."
 # yolo_model = yolo_body(Input(shape=(None, None, 3)), 9 // 3, 2)
 # yolo_model.load_weights(keras_model_path)
 # yolo_model.save(model_weightandstruct_path)

 # 输入变量 - 数据的数量/图像大小/通道(保持跟训练模型时一致就好),如果你的tf lite模型打印output_details与预期的不对。则需要在这里加入input_shapes
 input_shapes = {"input_1": [1, 416, 416, 3]}  # 输入必须为32的倍数,原因是yolov3 会对模型进行8,16,32共3次降采样
 converter = TFLiteConverter.from_keras_model_file(model_weightandstruct_path, input_shapes=input_shapes)
 #------------------------
 #	float16
 #------------------------
 # converter.optimizations = [tf.lite.Optimize.DEFAULT]
 # converter.target_spec.supported_types = [tf.compat.v1.lite.constants.FLOAT16]
 
 #------------------------
 #	int8
 #------------------------
 converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
 converter.optimizations = [tf.lite.Optimize.DEFAULT]
 converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
 converter.allow_custom_ops = True
 converter.experimental_enable_resource_variables = True
 tflite_quant_model = converter.convert()
 open(save_tflite_path, "wb").write(tflite_quant_model)
# 之后没有报错即转换成功
# int8 与 float16 大小有明显的区别。具体的速度与性能消耗等之后有嵌入式端与安卓端部署再说

在这里插入图片描述

模型输入输出信息打印

def get_tflite_message(save_tflite_path):
    # Load TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=save_tflite_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(input_details)
    print(output_details)
def tf_lite_predict(save_tflite_path, ):
    '''
    yolo_v3 tflite 保存的输入输出信息
    [{'name': 'input_1', 'index': 0, 'shape': array([  1, 416, 416,   3], dtype=int32), 'shape_signature': array([  1, 416, 416,   3], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
    [{'name': 'Identity', 'index': 298, 'shape': array([ 1, 13, 13, 21], dtype=int32), 'shape_signature': array([ 1, 13, 13, 21], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'Identity_1', 'index': 315, 'shape': array([ 1, 26, 26, 21], dtype=int32), 'shape_signature': array([ 1, 26, 26, 21], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'Identity_2', 'index': 332, 'shape': array([ 1, 52, 52, 21], dtype=int32), 'shape_signature': array([ 1, 52, 52, 21], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
    '''
    # 加载模型
    predictor = tf.lite.Interpreter(model_path=save_tflite_path)
    predictor.allocate_tensors()
    nput_details = predictor.get_input_details()
    # 模型预测
    predictor.set_tensor(input_details[0]["index"], image_tensor)  # 传入的数据必须为ndarray类型
    predictor.invoke()
    output_details = predictor.get_output_details()
	# preds即我们tflite模型的输出
    preds = [tf.convert_to_tensor(predictor.get_tensor(output_details[i]['index'])) for i in range(len(output_details))]

Int8 模型结果

yolov3实战 超简单上手 飞机与油桶数据集之tflite预测_第2张图片

float16 模型结果

yolov3实战 超简单上手 飞机与油桶数据集之tflite预测_第3张图片
从结果上看 两者并没有差别,性能未测

你可能感兴趣的:(#,目标检测,深度学习,tensorflow,keras,神经网络,目标跟踪)