https://github.com/onnx/tutorials/blob/master/tutorials/TensorflowToOnnx-1.ipynb
——Translated by Blssel
对Tensorflow和ONNX来说,虽然它们使用的是不同的计算图格式,但你可以使用Tensorflow-ONNX将一个Tensorflow模型转化为ONNX。本文将分为两个部分:第1部分介绍基本的转换方法,第2部分讨论更高级的话题。目录可以概括如下:
- 转换TensorFlow模型的步骤
-准备tensorflow模型
-转换为ONNX
-验证- 额外说明
步骤1:准备tensorflow模型
Tensorflow有好几种保存模型的文件格式,如检查点(checkpoint)文件、graph with weight(called frozen graph next) 以及saved_model,,你可以在训练模型时使用tensorflow提供的api来生成这些文件,可以参考脚本tensorflow_to_onnx_example.py
无论是这三种格式中的哪一种,Tensorflow-onnx都能够将它们转换成onnx格式不过更推荐使用“saved_model”格式,因为它不需要用户指定图形的输入和输出名称。本节将以它为例进行介绍,然后在第2部分(part2)中介绍其他两个。此外,你还可以从tensorflow-onnx的README文件中获得更多细节。
import os
import shutil
import tensorflow as tf
from assets.tensorflow_to_onnx_example import create_and_train_mnist
def save_model_to_saved_model(sess, input_tensor, output_tensor):
from tensorflow.saved_model import simple_save
save_path = r"./output/saved_model"
if os.path.exists(save_path):
shutil.rmtree(save_path)
simple_save(sess, save_path, {
input_tensor.name: input_tensor}, {
output_tensor.name: output_tensor})
print("please wait for a while, because the script will train MNIST from scratch")
tf.reset_default_graph()
sess_tf, saver, input_tensor, output_tensor = create_and_train_mnist()
print("save tensorflow in format \"saved_model\"")
save_model_to_saved_model(sess_tf, input_tensor, output_tensor)
please wait for a while, because the script will train MNIST from scratch
Extracting /tmp/tensorflow/mnist/input_data/train-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/train-labels-idx1-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-labels-idx1-ubyte.gz
step 0, training accuracy 0.18
step 1000, training accuracy 0.98
step 2000, training accuracy 0.94
step 3000, training accuracy 1
step 4000, training accuracy 1
test accuracy 0.976
save tensorflow in format "saved_model"
步骤2:转换为ONNX
tensorflow-onnx有几个条目用于转换不同的tensorflow格式的tensorflow模型,本节只讨论“saved_model”,“frozen graph”和“checkpoint”将在第2部分中介绍。
另外,tensorflow-onnx还导出了相关的python api,这样用户就可以直接在脚本中调用它们,而不是在命令行中调用它们,具体细节将在第2部分中介绍。
# generating mnist.onnx using saved_model
!python -m tf2onnx.convert \
--saved-model ./output/saved_model \
--output ./output/mnist1.onnx \
--opset 7
2019-06-17 07:22:03,871 - INFO - Using tensorflow=1.12.0, onnx=1.5.0, tf2onnx=1.5.1/0c735a
2019-06-17 07:22:03,871 - INFO - Using opset
2019-06-17 07:22:03,989 - INFO -
2019-06-17 07:22:04,012 - INFO - Optimizing ONNX model
2019-06-17 07:22:04,029 - INFO - After optimization: Add -2 (4->2), Identity -3 (3->0), Transpose -8 (9->1)
2019-06-17 07:22:04,031 - INFO -
2019-06-17 07:22:04,032 - INFO - Successfully converted TensorFlow model ./output/saved_model to ONNX
2019-06-17 07:22:04,044 - INFO - ONNX model is saved at ./output/mnist1.onnx
步骤3:验证
有好几种可以运行ONNX模型的方式,这里使用ONNXRuntime框架,由微软开源,可以确保生成的ONNX计算图正常运行。输入”image.npz”是一幅手写的“7”图像,因此模型的预期分类结果应为“7”。
import numpy as np
import onnxruntime as ort
img = np.load("./assets/image.npz").reshape([1, 784])
sess_ort = ort.InferenceSession("./output/mnist1.onnx")
res = sess_ort.run(output_names=[output_tensor.name], input_feed={
input_tensor.name: img})
print("the expected result is \"7\"")
print("the digit is classified as \"%s\" in ONNXRruntime"%np.argmax(res))
the expected result is "7"
the digit is classified as "7" in ONNXRruntime
以上的命令行应该适用于大多数tensorflow模型。在某些情况下,您可能会遇到需要额外选项的问题。
选项中最重要的概念是opset(操作集)选项,ONNX是一个不断发展的标准,它将添加更多的新操作并增强现有的操作,因此不同的opset版本将包含不同的操作,它们可能会有些不同 。默认版本“tensorflow-onnx”使用的是7,ONNX现在最高支持版本10,所以如果转换失败,您可以尝试不同的版本,通过命令行选项“——opset”,看看它是否工作。
继续第2部分,解释高级主题。