tf_lite的配置

注:别看网上的教程,各种坑。其实 TensorFlow 在很早的版本(本文代码使用的 1.x)就已经提供了 lite 的 Python 接口;如果用网上的 C++ 编译版本,光是环境配置就要花很久!

 

 

将 Keras 的 .h5 模型转化为 tflite 模型有两种方式:

1)先将 .h5 模型转化为 TensorFlow 的 .pb 模型,再用 tf.contrib.lite.TFLiteConverter.from_frozen_graph 将 .pb 转化为 tflite 模型

2)直接使用 tf.contrib.lite.TFLiteConverter.from_keras_model_file 接口,由 .h5 文件一步转化为 tflite 模型

 

这里,我用第一种方法转换成功了,代码如下:

#-*- coding: utf-8 -*-
import tensorflow as tf
from keras import backend as K


# ------------------ Device configuration ------------------
# Build one TF 1.x session and register it as the Keras backend session,
# so that models loaded by Keras afterwards live in this session's graph.
gpu_list = ""
gpu_options = tf.GPUOptions(allow_growth=True, visible_device_list=gpu_list)
config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    gpu_options=gpu_options,
    intra_op_parallelism_threads=40,
    inter_op_parallelism_threads=40,
    # Request 0 GPU devices and up to 40 CPU devices.
    device_count={'GPU': 0, 'CPU': 40},
)
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
K.set_session(sess)

from keras.models import load_model
import os

# Keras .h5 files (expected under ./models/) to convert to .pb and .tflite.
modelnames = ["east_model_3T512_tiny.h5", "idcard_corner.h5", "idcard_numberx_nornn.h5", "idcard_chinese_nornn.h5"]

# Names of the input and output layers as stored in each .h5 model:
#   model file -> (input layer name, output layer name)
# To find the names for a new model, inspect
#   [layer.name for layer in model.layers]
layer_names_by_model = {
    "idcard_nation.h5": ("the_input", "softmax"),
    "idcard_numberx_nornn.h5": ("the_input", "softmax"),
    "idcard_chinese_nornn.h5": ("the_input", "blstm2_out"),
    "idcard_corner.h5": ("the_input", "y_pred"),
    "east_model_3T512_tiny.h5": ("input_img", "east_detect"),
}

# Make sure the output folders exist before anything is written into them.
for out_dir in ("./pb", "./tflite"):
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

for m in modelnames:
    print(m)
    if m not in layer_names_by_model:
        # Fail early with a clear message instead of an opaque
        # "No such layer" error raised by get_layer("").
        raise ValueError("No input/output layer names configured for model: %s" % m)
    input_name, out_name = layer_names_by_model[m]

    # Strip only the final extension so names containing extra dots
    # do not collapse onto the same output file.
    modelname = os.path.splitext(m)[0]
    pb_path = "./pb/" + modelname + ".pb"
    tflite_path = "./tflite/" + modelname + ".tflite"

    model_path = os.path.join("./models/", m)
    model = load_model(model_path, compile=False)

    # Translate the Keras layer names into the op names used in the graph;
    # the .h5 layer names and the .pb op names are not necessarily identical.
    input_names = [model.get_layer(name=input_name).output.op.name]
    output_names = [model.get_layer(name=out_name).output.op.name]

    # Freeze variables into constants and save the result as a binary .pb.
    # NOTE(review): sess.graph_def contains every model loaded so far in this
    # loop; the freeze is expected to keep only what output_names needs.
    frozen_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_names)
    tf.train.write_graph(frozen_def, ".", pb_path, as_text=False)

    # Dump all op names so they can be compared with the chosen
    # input/output names when the conversion fails.
    for n in sess.graph_def.node:
        print("Name of the node - %s" % n.name)

    print('input_names',input_names)
    print('output_names',output_names)

    # Convert the frozen graph to TFLite with post-training quantization,
    # allowing custom ops and falling back to selected TF ops.
    converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(pb_path, input_names, output_names)
    converter.allow_custom_ops = True
    converter.post_training_quantize = True
    converter.target_ops = [tf.contrib.lite.OpsSet.TFLITE_BUILTINS, tf.contrib.lite.OpsSet.SELECT_TF_OPS]
    tflite_model = converter.convert()
    with open(tflite_path, "wb") as fw:
        fw.write(tflite_model)

 

 

测试解释器代码:

 

import tensorflow as tf
import numpy as np
import math
import time

# Number of random samples fed to the interpreter.
SIZE = 1000
# float32 array of shape (SIZE, 1), scaled from [0, 1) to [0, pi/2).
X = np.random.rand(SIZE, 1).astype(np.float32)
X = X * (math.pi / 2.0)

# The timed span covers loading the model, tensor allocation and one inference.
t_start = time.time()

interpreter = tf.lite.Interpreter(
    model_path="/home/alcht0/share/project/tensorflow-v1.12.0/converted_model.tflite"
)
interpreter.allocate_tensors()

# Descriptions of the model's input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# The same array is fed to the first two inputs; this script expects a
# model with (at least) two inputs that accept X.
first_input_index = input_details[0]['index']
interpreter.set_tensor(first_input_index, X)
second_input_index = input_details[1]['index']
interpreter.set_tensor(second_input_index, X)

interpreter.invoke()

first_output_index = output_details[0]['index']
output_data = interpreter.get_tensor(first_output_index)
t_end = time.time()
elapsed = t_end - t_start
print("1st ", str(elapsed))

 

 

tf_lite的配置_第1张图片

出现上图这种情况,说明模型是有问题的!

tf_lite的配置_第2张图片

出现上图这种情况,说明模型本身没有问题,仅仅是输入输出对不上而已。

 

你可能感兴趣的:(tf_lite的配置)