Official example 1:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md#example-2-export-with-variables
import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
    tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
    open("converteds_model.tflite", "wb").write(tflite_model)
Running it produces the following error:
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-1-003f6775a35f> in <module>()
4 out = tf.identity(val, name="out")
5 with tf.Session() as sess:
----> 6 tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
7 open("converteds_model.tflite", "wb").write(tflite_model)
~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py in toco_convert(input_data, input_tensors, output_tensors, inference_type, input_format, output_format, quantized_input_stats, drop_control_dependency)
210 data = toco_convert_protos(model.SerializeToString(),
211 toco.SerializeToString(),
--> 212 input_data.SerializeToString())
213 return data
214
~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py in toco_convert_protos(model_flags_str, toco_flags_str, input_data_str)
103 model_flags_str, toco_flags_str, input_data_str)
104
--> 105 with tempfile.NamedTemporaryFile() as fp_toco, \
106 tempfile.NamedTemporaryFile() as fp_model, \
107 tempfile.NamedTemporaryFile() as fp_input, \
NameError: name 'tempfile' is not defined
Cause: it seems to be related to interface sealing in TensorFlow (symbols unrelated to the public interface are hidden, so modules such as tempfile are no longer visible inside tf.contrib.lite).
Solution: manually put the missing modules back onto tf.contrib.lite:
import tensorflow as tf

#----------------Solution Begin----------------
# manually put back imported modules
import tempfile
import subprocess
tf.contrib.lite.tempfile = tempfile
tf.contrib.lite.subprocess = subprocess
#-----------------Solution End-----------------

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
    tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
    open("converteds_model.tflite", "wb").write(tflite_model)
After running this, converteds_model.tflite appears in the working directory.
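To sanity-check the converted file you can load it back with the TFLite Python interpreter. A minimal sketch, assuming a TensorFlow 1.x build where the interpreter is exposed as tf.contrib.lite.Interpreter (newer versions move it to tf.lite.Interpreter):

import numpy as np
import tensorflow as tf

# Load the converted model and allocate its tensors.
interpreter = tf.contrib.lite.Interpreter(model_path="converteds_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a dummy image matching the (1, 64, 64, 3) float input.
dummy = np.zeros(input_details[0]["shape"], dtype=np.float32)
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]["index"]))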
Official example 2:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md#example-2-export-with-variables
import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + var

def canonical_name(x):
    return x.name.split(":")[0]

out = tf.identity(val, name="out")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out_tensors = [out]
    frozen_graphdef = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, map(canonical_name, out_tensors))
    tflite_model = tf.contrib.lite.toco_convert(
        frozen_graphdef, [img], out_tensors)
    open("converted_model.tflite", "wb").write(tflite_model)
Running it produces the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-f81c3f0d99d5> in <module>()
13 out_tensors = [out]
14 frozen_graphdef = tf.graph_util.convert_variables_to_constants(
---> 15 sess, sess.graph_def, map(canonical_name, out_tensors))
16 tflite_model = tf.contrib.lite.toco_convert(
17 frozen_graphdef, [img], out_tensors)
~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/graph_util_impl.py in convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist, variable_names_blacklist)
230 # This graph only includes the nodes needed to evaluate the output nodes, and
231 # removes unneeded nodes like those involved in saving and assignment.
--> 232 inference_graph = extract_sub_graph(input_graph_def, output_node_names)
233
234 found_variables = {}
~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
174 _assert_nodes_are_present(name_to_node, dest_nodes)
175
--> 176 nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)
177
178 nodes_to_keep_list = sorted(
~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/graph_util_impl.py in _bfs_for_reachable_nodes(target_nodes, name_to_input_name)
138 nodes_to_keep = set()
139 # Breadth first search to find all the nodes that we should keep.
--> 140 next_to_visit = target_nodes[:]
141 while next_to_visit:
142 n = next_to_visit[0]
TypeError: 'map' object is not subscriptable
Cause: under Python 3.6, map(canonical_name, out_tensors) returns a lazy iterator rather than a list, so the slicing done inside convert_variables_to_constants fails.
Solution: change map(canonical_name, out_tensors) to list(map(canonical_name, out_tensors)).
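A quick illustration of the Python 3 behaviour behind this error (not part of the official example):

m = map(str, [1, 2, 3])
try:
    m[:]                                 # Python 3: map() returns a lazy iterator, not a list
except TypeError as e:
    print(e)                             # 'map' object is not subscriptable
print(list(map(str, [1, 2, 3]))[:])      # wrapping it in list() restores slicing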
After that change, running it hits the same NameError as example 1, and the fix is the same:
import tensorflow as tf

#----------------Solution Begin----------------
# manually put back imported modules
import tempfile
import subprocess
tf.contrib.lite.tempfile = tempfile
tf.contrib.lite.subprocess = subprocess
#-----------------Solution End-----------------

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + var

def canonical_name(x):
    return x.name.split(":")[0]

out = tf.identity(val, name="out")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out_tensors = [out]
    frozen_graphdef = tf.graph_util.convert_variables_to_constants(
        #----------------Solution Begin----------------
        sess, sess.graph_def, list(map(canonical_name, out_tensors)))
        #-----------------Solution End-----------------
    tflite_model = tf.contrib.lite.toco_convert(
        frozen_graphdef, [img], out_tensors)
    open("converted_model.tflite", "wb").write(tflite_model)
Converting a trained model to .tflite takes two steps: first freeze_graph produces a .pb file, then toco converts that .pb into .tflite. The freeze step succeeds whether it is given the training graph (graph.pbtxt) or the evaluation graph (eval.pbtxt), so freeze_graph itself rarely reports an error; the problems only surface in the second step, when toco converts the .pb model to .tflite.
The errors toco can raise when converting a .pb model to .tflite look like those shown in Figures 1-3:
Cause:
The conversion was done in the wrong order: the training graph graph.pbtxt was used directly (current TensorFlow checkpoints are saved in the V2 format shown in Figure 4). The correct procedure is to first generate the evaluation graph eval.pbtxt, then run bazel-bin/tensorflow/python/tools/freeze_graph, and finally run bazel-bin/tensorflow/contrib/lite/toco/toco. The whole pipeline, from data preparation through training, model conversion, and deployment to a phone, is covered in the separate article "Tensorflow部署到移动端" (deploying TensorFlow to mobile).
freeze_graph: the graph.pbtxt below was produced during training, and freezing with it does not raise any error, as shown in the upper half of Figure 5:
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/home/lg/Desktop/mobilenet_v1/checkpoint/graph.pbtxt \
--input_checkpoint=/home/lg/Desktop/mobilenet_v1/checkpoint/model.ckpt-20000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_244.pb \
--output_node_names=MobilenetV1/Predictions/Reshape_1 \
--checkpoint_version=2
toco: trying to convert the .pb file frozen from graph.pbtxt into a .tflite file:
Attempting a FLOAT .tflite fails, with the error shown in the lower half of Figure 5:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_244.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1.tflite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays='shuffle_batch/(shuffle_batch)' \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,244,244,3
Attempting a QUANTIZED_UINT8 .tflite fails, with the error shown in Figure 6:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_244.pb --input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1.tflite --inference_type=QUANTIZED_UINT8 \
--input_type=QUANTIZED_UINT8 \
--input_arrays='shuffle_batch/(shuffle_batch)' \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,244,244,3
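Part of the problem with the commands above is the input array name: a training graph's input is a shuffle_batch queue op rather than a placeholder, so there is no clean tensor to hand to --input_arrays. A quick way to check which node names a frozen .pb actually exposes is to list its Placeholder nodes; a minimal sketch (the path is the one used above and is only illustrative):

import tensorflow as tf

graph_def = tf.GraphDef()
with open("/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_244.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Placeholder ops are the usual candidates for --input_arrays;
# the final node in the GraphDef is a candidate for --output_arrays.
for node in graph_def.node:
    if node.op == "Placeholder":
        print("input candidate :", node.name)
print("output candidate:", graph_def.node[-1].name)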
Generating eval.pbtxt first is what makes the subsequent freeze_graph and toco steps work. How to produce eval.pbtxt is described in "Tensorflow部署到移动端".
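The rough idea behind eval.pbtxt is to rebuild the network in inference mode, with a plain placeholder input instead of the shuffle_batch input pipeline, and dump that GraphDef as text. A minimal sketch, where build_eval_graph() is a hypothetical stand-in for constructing the real network with is_training=False:

import tensorflow as tf

def build_eval_graph(images):
    # Hypothetical stand-in: in practice build the real network here,
    # e.g. mobilenet_v1(images, is_training=False), with no input queues.
    return tf.identity(images, name="MobilenetV1/Predictions/Reshape_1")

g = tf.Graph()
with g.as_default():
    # Inference-time input: a plain placeholder instead of shuffle_batch queues.
    images = tf.placeholder(tf.float32, shape=(1, 224, 224, 3), name="Placeholder")
    build_eval_graph(images)

# Dump the evaluation GraphDef as text; this file is the eval.pbtxt fed to freeze_graph.
tf.train.write_graph(g.as_graph_def(), "/home/lg/Desktop/mobilenet_v1",
                     "mobilenet_v1_eval.pbtxt", as_text=True)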
Freezing the graph correctly using eval.pbtxt:
freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/home/lg/Desktop/mobilenet_v1/mobilenet_v1_eval.pbtxt \
--input_checkpoint=/home/lg/Desktop/mobilenet_v1/checkpoint/model.ckpt-20000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_224.pb \
--output_node_names=MobilenetV1/Predictions/Reshape_1 \
--checkpoint_version=2
toco
After freezing correctly from eval.pbtxt, the resulting .pb file can be converted with toco:
toco keeping the model in FLOAT:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_224.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1-FLOAT.tflite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=Placeholder \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,224,224,3
toco quantizing the model to QUANTIZED_UINT8 (--default_ranges_min/--default_ranges_max supply fallback (min, max) ranges for tensors without quantization statistics, which this graph needs because it was not trained with fake-quantization ops):
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_224.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1-QUANTIZED_UINT8.tflite \
--inference_type=QUANTIZED_UINT8 \
--input_type=QUANTIZED_UINT8 \
--input_arrays=Placeholder \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,224,224,3 \
--default_ranges_min=0.0 \
--default_ranges_max=255.0
Similarly, the graph.pbtxt below was also produced during training (Inception V3 this time), and freezing with it does not raise an error either:
freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/home/lg/Desktop/inception_v3/checkpoint/graph.pbtxt \
--input_checkpoint=/home/lg/Desktop/inception_v3/checkpoint/model.ckpt-20000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb \
--output_node_names=InceptionV3/Predictions/Reshape_1 \
--checkpoint_version=2
toco: trying to convert the .pb file frozen from graph.pbtxt into a .tflite file:
Attempting a FLOAT .tflite fails, with the error shown in Figure 7:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/frozen_inception_v3_299.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/frozen_graph_inception_v3.tflite --inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=input \
--output_arrays=InceptionV3/Predictions/Reshape_1 \
--input_shapes=1,299,299,3
Attempting a QUANTIZED_UINT8 .tflite fails, with the error shown in Figure 8:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/inception_v3/frozen_graph_inception_v3.tflite \
--inference_type=QUANTIZED_UINT8 \
--input_type=QUANTIZED_UINT8 \
--input_arrays=input \
--output_arrays=InceptionV3/Predictions/Reshape_1 \
--input_shapes=1,299,299,3
Again, generate eval.pbtxt first so that the subsequent freeze_graph and toco steps work; see "Tensorflow部署到移动端" for how to produce it.
Freezing the graph correctly using eval.pbtxt:
freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/home/lg/Desktop/inception_v3/inception_v3_eval.pbtxt \
--input_checkpoint=/home/lg/Desktop/inception_v3/checkpoint/model.ckpt-20000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb \
--output_node_names=InceptionV3/Predictions/Reshape_1 \
--checkpoint_version=2
toco
After freezing correctly from eval.pbtxt, the resulting .pb file can be converted with toco:
toco keeping the model in FLOAT:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/inception_v3/frozen_graph_inception_v3.tflite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=Placeholder \
--output_arrays=InceptionV3/Predictions/Reshape_1 \
--input_shapes=1,299,299,3
toco quantizing the model to QUANTIZED_UINT8:
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/home/lg/Desktop/inception_v3/frozen_graph_inception_v3.tflite \
--inference_type=QUANTIZED_UINT8 \
--input_type=QUANTIZED_UINT8 \
--input_arrays=Placeholder \
--output_arrays=InceptionV3/Predictions/Reshape_1 \
--input_shapes=1,299,299,3 \
--default_ranges_min=0.0 \
--default_ranges_max=255.0
More material is available on GitHub:
https://github.com/GarryLau