TF-Lite Minimal Reference: Model Conversion
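TensorFlow Lite makes it easy to convert a model trained with TensorFlow and then run inference with it. In TensorFlow 2.0, Keras is fully integrated, so you can use tf.keras to build models more efficiently. A major TensorFlow 2.0 bug was reported a few days ago, and the framework has long been criticized as hard to use. Even so, it still has a very large user base, and it is quite stable as long as you do not rely on many custom layers. I started using TensorFlow around version 0.10 and have followed it ever since. I have also used it to ship many server-side OCR projects.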
Key Words: TF Lite, model conversion
Beijing, 2020
Author: RaySue
Code:https://github.com/RaySue/TF-Lite-Demo.git
Agile Pioneer
Deploying a TensorFlow Lite model file:
TensorFlow Lite provides the following three ways to convert a model:
Convert a trained checkpoint into a frozen .pb graph (TensorFlow 1.x API). Judging from the `model.load(path, weights_only=True)` call and the `.tfl` files, `model` is a TFLearn model:
import glob
import os

import tensorflow as tf  # TensorFlow 1.x
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.tools.graph_transforms import TransformGraph


def convert_to_pb(model, path, input_layer_name, output_layer_name, pbfilename, verbose=False):
    # Load the trained weights into the in-memory model.
    model.load(path, weights_only=True)
    print("[INFO] Loaded CNN network weights from " + path + " ...")
    print("[INFO] Re-export model ...")
    # Drop the training ops so they are not written into the exported meta graph.
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
    model.save("model-tmp.tfl")
    # taken from: https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow
    print("[INFO] Re-import model ...")
    input_checkpoint = "model-tmp.tfl"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    sess = tf.Session()
    saver.restore(sess, input_checkpoint)
    # Print all operations to find the name of the output layer.
    if verbose:
        for op in sess.graph.get_operations():
            print(op.values())
    print("[INFO] Freeze model to " + pbfilename + " ...")
    # Freeze: turn variables into constants and remove nodes not needed for feed-forward prediction.
    minimal_graph = convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_layer_name])
    graph_def = optimize_for_inference_lib.optimize_for_inference(
        minimal_graph, [input_layer_name], [output_layer_name], tf.float32.as_datatype_enum)
    graph_def = TransformGraph(graph_def, [input_layer_name], [output_layer_name], ["sort_by_execution_order"])
    with tf.gfile.GFile(pbfilename, 'wb') as f:
        f.write(graph_def.SerializeToString())
    # Write the graph to the logs dir so it can be visualized with:
    # tensorboard --logdir="logs"
    if verbose:
        writer = tf.summary.FileWriter('logs', graph_def)
        writer.close()
    sess.close()
    # Tidy up the temporary files.
    for f in glob.glob("model-tmp.tfl*"):
        os.remove(f)
    os.remove('checkpoint')
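In TF 1.x, `convert_variables_to_constants` does two things. It extracts the subgraph that the requested output nodes depend on, and it bakes variable values into constants. The first step is why training-only nodes such as the loss disappear from the .pb file. That pruning is a backwards reachability walk over input edges. The sketch below models it with a plain dict of hypothetical node names instead of a real GraphDef. It also ignores details that the real implementation has to strip, such as control-dependency (`^name`) and output-index (`name:1`) suffixes.

```python
from collections import deque
from typing import Dict, List, Set


def prune_to_outputs(graph: Dict[str, List[str]], outputs: List[str]) -> Set[str]:
    """Return the names of all nodes that the given outputs transitively depend on.

    `graph` maps each node name to the names of its input nodes; it is a toy
    stand-in for the `input` field of each NodeDef in a GraphDef.
    """
    keep: Set[str] = set()
    queue = deque(outputs)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue  # already visited (also makes cycles safe)
        keep.add(name)
        # Walk backwards along the input edges.
        queue.extend(graph.get(name, []))
    return keep


if __name__ == "__main__":
    # Hypothetical toy graph: the loss and train_op branch is not needed for prediction.
    toy_graph = {
        "input": [],
        "conv/weights": [],
        "conv": ["input", "conv/weights"],
        "softmax": ["conv"],
        "labels": [],
        "loss": ["softmax", "labels"],
        "train_op": ["loss"],
    }
    print(sorted(prune_to_outputs(toy_graph, ["softmax"])))
```

Running it keeps only `conv`, `conv/weights`, `input` and `softmax`. The `labels`, `loss` and `train_op` nodes are dropped because the chosen output never reads from them.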
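Convert the frozen .pb graph into a .tflite file. `tf.lite.TFLiteConverter.from_frozen_graph` is the TF 1.x API. Under TF 2.x, use `tf.compat.v1.lite.TFLiteConverter.from_frozen_graph` instead:
import tensorflow as tf


def convert_to_tflite(pbfilename, input_layer_name, output_layer_name,
                      input_tensor_dim_x, input_tensor_dim_y, input_tensor_channels=3):
    # Fixed input shape in NHWC order: [batch, dim_x, dim_y, channels].
    input_tensor = {input_layer_name: [1, input_tensor_dim_x, input_tensor_dim_y, input_tensor_channels]}
    tflite_filename = pbfilename.replace(".pb", ".tflite")
    print("[INFO] tflite model to " + tflite_filename + " ...")
    converter = tf.lite.TFLiteConverter.from_frozen_graph(pbfilename, [input_layer_name], [output_layer_name],
                                                          input_tensor)
    tflite_model = converter.convert()
    with open(tflite_filename, "wb") as f:
        f.write(tflite_model)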
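`convert_to_tflite` derives the output name with `pbfilename.replace(".pb", ".tflite")`. That call rewrites every occurrence of ".pb" in the path, not only the extension. The input shape is passed as a dict of NHWC lists. Below is a small standard-library sketch of both pieces. It uses `os.path.splitext` so that only the final extension changes, and it assumes the usual TensorFlow image layout (batch, height, width, channels). The helper names `tflite_path_for` and `nhwc_input_shapes` are hypothetical and not part of TensorFlow.

```python
import os
from typing import Dict, List


def tflite_path_for(pb_path: str) -> str:
    """Swap only the final extension, e.g. models/net.pb -> models/net.tflite."""
    root, ext = os.path.splitext(pb_path)
    if ext != ".pb":
        raise ValueError(f"expected a .pb file, got {pb_path!r}")
    return root + ".tflite"


def nhwc_input_shapes(input_name: str, height: int, width: int,
                      channels: int = 3, batch: int = 1) -> Dict[str, List[int]]:
    """Build an input_shapes dict for from_frozen_graph in NHWC order (assumed layout)."""
    return {input_name: [batch, height, width, channels]}


if __name__ == "__main__":
    # str.replace also rewrites the ".pb" inside the directory name:
    print("exports.pb_v1/net.pb".replace(".pb", ".tflite"))  # exports.tflite_v1/net.tflite
    print(tflite_path_for("exports.pb_v1/net.pb"))  # exports.pb_v1/net.tflite
    print(nhwc_input_shapes("input", 224, 224))  # {'input': [1, 224, 224, 3]}
```

The two helpers can stand in for the `replace` call and the `input_tensor` dict inside `convert_to_tflite`. The explicit `ValueError` catches paths that do not end in `.pb`, where `replace` would silently change nothing.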
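Convert a tf.keras model directly (TensorFlow 2.x; `model` is a trained tf.keras.Model):
import tensorflow as tf  # TensorFlow 2.x
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('/home/ai/converted_model.tflite', 'wb') as f:
    f.write(tflite_model)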
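Run inference with the converted .tflite model:
import cv2
import numpy as np
import tensorflow as tf

tflite_model_path = "xxx.tflite"
tflite_model = tf.lite.Interpreter(model_path=tflite_model_path)
tflite_model.allocate_tensors()
# Get input and output tensors.
tflite_input_details = tflite_model.get_input_details()
tflite_output_details = tflite_model.get_output_details()
frame = cv2.imread("xxx.jpg")
# Pass the interpolation flag by keyword: the third positional argument of cv2.resize is `dst`.
small_frame = cv2.resize(frame, (224, 224), interpolation=cv2.INTER_AREA)
# Add the batch dimension and cast to float32, giving shape (1, 224, 224, 3).
# Note: cv2 loads images as BGR; convert channels or normalize here if the model was trained on other inputs.
tflite_input_data = np.expand_dims(small_frame, 0).astype(np.float32)
tflite_model.set_tensor(tflite_input_details[0]['index'], tflite_input_data)
tflite_model.invoke()
output_tflite = tflite_model.get_tensor(tflite_output_details[0]['index'])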
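The inference snippet hard-codes a 224×224 resize and leaves `output_tflite` as raw scores. Below is a sketch of two small helpers that may be useful around it.

- The first reads height and width from the interpreter's input shape (`tflite_input_details[0]['shape']`, assumed to be NHWC). It returns them in the (width, height) order that `cv2.resize` expects, which matters for non-square inputs.
- The second ranks the output scores. It assumes the model is a classifier whose output row holds one score per class; the source does not say what the model predicts.

Both function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def cv2_dsize_from_input_shape(shape: Sequence[int]) -> Tuple[int, int]:
    """Map an NHWC input shape [1, H, W, C] to the (width, height) order cv2.resize expects."""
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D NHWC shape, got {list(shape)}")
    _, height, width, _ = shape
    return int(width), int(height)


def top_k(scores: Sequence[float], k: int = 5,
          apply_softmax: bool = False) -> List[Tuple[int, float]]:
    """Return the k (class_index, score) pairs with the highest scores."""
    values = [float(v) for v in scores]
    if apply_softmax:
        m = max(values)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in values]
        total = sum(exps)
        values = [e / total for e in exps]
    ranked = sorted(enumerate(values), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    print(cv2_dsize_from_input_shape([1, 224, 320, 3]))  # (320, 224)
    print(top_k([0.1, 0.7, 0.2], k=2))  # [(1, 0.7), (2, 0.2)]
```

In the pipeline above, these could be used as `cv2.resize(frame, cv2_dsize_from_input_shape(tflite_input_details[0]['shape']), interpolation=cv2.INTER_AREA)` and `top_k(output_tflite[0], k=5)`. Set `apply_softmax=True` only if the model outputs logits rather than probabilities.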