TensorFlow Lite(Keras的.H5模型)

1.使用tensorflow提供的API完成转换

  • 转换前模型:keras的.h5模型
  • 转换后模型:.tflite模型
  • 相关环境:
    • windows10
    • numpy 1.19.5
    • tensorflow 2.4.1

转换代码如下:

import tensorflow as tf

model=load_model("../models/emotion.h5")  # 加载h5模型

converter =tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

open('emotion75agsadg.tflite', 'wb').write(tflite_model)  #保存转换后的模型

如果转换成功,则返回保存文件的字节数,我的如下:

[]: 11149640

2. 现在来测试一下tflite模型

2.1 导库和加载模型

import tensorflow as tf
import cv2
import numpy as np

tflife_model = tf.lite.Interpreter(model_path="../models/emotion75.tflite")  #加载tflite模型
tflife_model.allocate_tensors() # 创建tensors

2.2 获取输入和输出节点的信息

从下面代码我们可以获取模型的输入输出节点的细节,包括节点序号、维度、数据类型等等,方便后面数据预处理和传值时位置

tflife_input_details = tflife_model.get_input_details()  #获取输入的细节信息,list类型,可打印查看
tflife_output_details = tflife_model.get_output_details()

输入节点:
TensorFlow Lite(Keras的.H5模型)_第1张图片
输出节点:
TensorFlow Lite(Keras的.H5模型)_第2张图片

2.3 数据处理

我的模型从2.2看到输入节点信息:

  • ‘shape’: array([ 1, 48, 48, 1]),
  • ‘dtype’: numpy.float32

故我们要对图像进行相应的处理
我的测试原图:
TensorFlow Lite(Keras的.H5模型)_第3张图片

frame = cv2.imread("oneface.jpg",cv2.IMREAD_GRAYSCALE) # 以灰度模式读图
small_frame = cv2.resize(frame, (48, 48), cv2.INTER_AREA) # 改变尺寸
small_frame = np.expand_dims(small_frame, 0)	# 扩展dim=0维度
small_frame = np.expand_dims(small_frame, 3)	# 扩展dim=3维度
tflife_input_data = np.float32(small_frame)		# 类型转为float32

2.4 传入、推理、得到结果

#从节点信息看,输入节点编号为0,通过tflife_input_details[0]['index']得到节点编号,将tflife_input_data数据填喂给该节点
tflife_model.set_tensor(tflife_input_details[0]['index'], tflife_input_data)  

tflife_model.invoke()  # 运行推理

# 同理获取输出节点的值,即我们的结果,返回值是array类型
output_tflite = tflife_model.get_tensor(tflife_output_details[0]['index']) 

我的是表情分类,故结果是7维:
在这里插入图片描述

2.5 可视化

import matplotlib.pyplot as plt
label=['anger','disgust','fear','happy','sad','surprised','normal']  # 定义横坐标
plt.bar(label,output_tflite[0])

TensorFlow Lite(Keras的.H5模型)_第4张图片

你可能感兴趣的:(keras)