Problems Encountered While Generating .tflite Files, and Their Solutions

Problem 1:

Official example 1:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md#example-2-export-with-variables

import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
    tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
    open("converteds_model.tflite", "wb").write(tflite_model)

Running it raises the following error:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-1-003f6775a35f> in <module>()
      4 out = tf.identity(val, name="out")
      5 with tf.Session() as sess:
----> 6     tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
      7     open("converteds_model.tflite", "wb").write(tflite_model)

~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py in toco_convert(input_data, input_tensors, output_tensors, inference_type, input_format, output_format, quantized_input_stats, drop_control_dependency)
    210   data = toco_convert_protos(model.SerializeToString(),
    211                              toco.SerializeToString(),
--> 212                              input_data.SerializeToString())
    213   return data
    214 

~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py in toco_convert_protos(model_flags_str, toco_flags_str, input_data_str)
    103         model_flags_str, toco_flags_str, input_data_str)
    104 
--> 105   with tempfile.NamedTemporaryFile() as fp_toco, \
    106            tempfile.NamedTemporaryFile() as fp_model, \
    107            tempfile.NamedTemporaryFile() as fp_input, \

NameError: name 'tempfile' is not defined

Cause: it seems to be related to interface sealing in TensorFlow (which tries to prevent exposing symbols unrelated to the interfaces). As a result, the name tempfile is no longer defined when toco_convert_protos runs.
Solution: manually put back the imported modules:

import tensorflow as tf

#----------------Solution Begin----------------
# manually put back imported modules
import tempfile
import subprocess
tf.contrib.lite.tempfile = tempfile
tf.contrib.lite.subprocess = subprocess
#-----------------Solution End-----------------

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
    tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
    open("converteds_model.tflite", "wb").write(tflite_model)

After running this, the generated converteds_model.tflite appears in the working directory.
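
The mechanism behind this NameError can be reproduced without TensorFlow. The sketch below is only an illustration under stated assumptions: the module name demo_lite and the function make_temp_name are hypothetical, and deleting the attribute merely simulates what interface sealing appears to do. It shows that a function looks up its module-level names at call time, so removing tempfile from the module breaks the function, and putting it back repairs it.

```python
import tempfile
import types

# Build a throwaway module (hypothetical name) that imports tempfile
# and uses it inside a function, similar in spirit to lite.py.
demo = types.ModuleType("demo_lite")
exec(
    "import tempfile\n"
    "def make_temp_name():\n"
    "    with tempfile.NamedTemporaryFile() as fp:\n"
    "        return fp.name\n",
    demo.__dict__,
)

# Simulate sealing: remove a symbol that is not part of the public interface.
del demo.tempfile

try:
    demo.make_temp_name()
    sealed_error = None
except NameError as err:
    # The function resolves `tempfile` in its module globals when it is called.
    sealed_error = str(err)

# The workaround: manually put the module back, then call again.
demo.tempfile = tempfile
restored_name = demo.make_temp_name()

print("while sealed:", sealed_error)
print("after restoring:", restored_name)
```

This is the same pattern as assigning tf.contrib.lite.tempfile and tf.contrib.lite.subprocess in the workaround above.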

Problem 2:

Official example 2:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md#example-2-export-with-variables

import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
var = tf.get_variable("weights", dtype=tf.float32, shape=(1,64,64,3))
val = img + var

def canonical_name(x):
  return x.name.split(":")[0]

out = tf.identity(val, name="out")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  out_tensors = [out]
  frozen_graphdef = tf.graph_util.convert_variables_to_constants(
      sess, sess.graph_def, map(canonical_name, out_tensors))
  tflite_model = tf.contrib.lite.toco_convert(
      frozen_graphdef, [img], out_tensors)
  open("converted_model.tflite", "wb").write(tflite_model)

Running it raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-f81c3f0d99d5> in <module>()
     13     out_tensors = [out]
     14     frozen_graphdef = tf.graph_util.convert_variables_to_constants(
---> 15       sess, sess.graph_def, map(canonical_name, out_tensors))
     16     tflite_model = tf.contrib.lite.toco_convert(
     17       frozen_graphdef, [img], out_tensors)

~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/graph_util_impl.py in convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist, variable_names_blacklist)
    230   # This graph only includes the nodes needed to evaluate the output nodes, and
    231   # removes unneeded nodes like those involved in saving and assignment.
--> 232   inference_graph = extract_sub_graph(input_graph_def, output_node_names)
    233 
    234   found_variables = {}

~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
    174   _assert_nodes_are_present(name_to_node, dest_nodes)
    175 
--> 176   nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)
    177 
    178   nodes_to_keep_list = sorted(

~/.conda/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/graph_util_impl.py in _bfs_for_reachable_nodes(target_nodes, name_to_input_name)
    138   nodes_to_keep = set()
    139   # Breadth first search to find all the nodes that we should keep.
--> 140   next_to_visit = target_nodes[:]
    141   while next_to_visit:
    142     n = next_to_visit[0]

TypeError: 'map' object is not subscriptable

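Cause: the author is using Python 3.6, where map() returns an iterator rather than a list. Passing map(canonical_name, out_tensors) directly therefore fails when _bfs_for_reachable_nodes slices it with target_nodes[:].
Solution: change map(canonical_name, out_tensors) to list(map(canonical_name, out_tensors)).
Running it then raises the same error as in example 1, and the fix is the same: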

import tensorflow as tf

#----------------Solution Begin----------------
# manually put back imported modules
import tempfile
import subprocess
tf.contrib.lite.tempfile = tempfile
tf.contrib.lite.subprocess = subprocess
#-----------------Solution End-----------------

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
var = tf.get_variable("weights", dtype=tf.float32, shape=(1,64,64,3))
val = img + var

def canonical_name(x):
    return x.name.split(":")[0]

out = tf.identity(val, name="out")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out_tensors = [out]
    frozen_graphdef = tf.graph_util.convert_variables_to_constants(
    #----------------Solution Begin----------------
      sess, sess.graph_def, list(map(canonical_name, out_tensors)))
    #-----------------Solution End-----------------
    tflite_model = tf.contrib.lite.toco_convert(
      frozen_graphdef, [img], out_tensors)
    open("converted_model.tflite", "wb").write(tflite_model)
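
The TypeError itself is plain Python 3 behavior and can be checked without TensorFlow. In this minimal sketch, the tensor name "out:0" is an assumed stand-in for out.name, and canonical_name takes a string instead of a tensor so the example stays self-contained.

```python
def canonical_name(name):
    # Strip the ":0" output index, as in the article's helper.
    return name.split(":")[0]

tensor_names = ["out:0"]  # assumed stand-in for [t.name for t in out_tensors]

# In Python 3, map() returns a lazy iterator that cannot be sliced.
lazy = map(canonical_name, tensor_names)
try:
    lazy[:]
    slice_error = None
except TypeError as err:
    slice_error = str(err)

# Wrapping it in list() gives a real list, which supports target_nodes[:].
node_names = list(map(canonical_name, tensor_names))
copied = node_names[:]

print(slice_error)
print(copied)
```

A list comprehension such as [canonical_name(n) for n in tensor_names] would work equally well.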

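Problem 3: conversion to .tflite fails

Converting a trained model to the .tflite format takes two steps: first freeze_graph produces a .pb file, then toco converts that .pb file to .tflite. The freeze step succeeds whether it is given the training graph.pbtxt or the evaluation eval.pbtxt, so freeze_graph normally reports no error. The problem only surfaces in the second step, when toco converts the .pb model to .tflite.
The errors that toco may report when converting the .pb model to .tflite take the forms shown in Figures 1-3: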

Figure 1: toco conversion failure, error form 1
Figure 2: toco conversion failure, error form 2
Figure 3: toco conversion failure, error form 3

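Cause:
The conversion workflow for the trained model was wrong (TensorFlow currently saves trained checkpoints in the V2 format, which looks like Figure 4): the training graph.pbtxt was used directly. The correct procedure is to first generate the eval.pbtxt used for evaluation, then run bazel-bin/tensorflow/python/tools/freeze_graph, and finally run bazel-bin/tensorflow/contrib/lite/toco/toco. For a walkthrough of the whole process, from data preparation through training and model conversion to deployment on a phone, see the article "Tensorflow部署到移动端" (Deploying TensorFlow to Mobile).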

Figure 4
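
Before freezing, it can help to check which graph file is at hand by listing its nodes and looking for a Placeholder input. The following is a rough sketch, not a real protobuf parser: it assumes the usual text-format layout with one field per line and no braces inside string values, and the sample graph text (including the input name MobilenetV1/Predictions/Softmax) is made up for illustration.

```python
def list_nodes(pbtxt_text):
    """Return (name, op) pairs for top-level `node { ... }` blocks."""
    nodes, current, depth = [], None, 0
    for raw in pbtxt_text.splitlines():
        line = raw.strip()
        if depth == 0 and line.startswith("node {"):
            current = {}
        elif depth == 1 and current is not None:
            # Only fields directly inside the node block are of interest.
            for field in ("name", "op"):
                if line.startswith(field + ":"):
                    current[field] = line.split(":", 1)[1].strip().strip('"')
        depth += line.count("{") - line.count("}")
        if depth == 0 and current is not None:
            nodes.append((current.get("name"), current.get("op")))
            current = None
    return nodes

# Hypothetical fragment in the style of an eval.pbtxt file.
SAMPLE = """
node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "MobilenetV1/Predictions/Reshape_1"
  op: "Reshape"
  input: "MobilenetV1/Predictions/Softmax"
}
"""

nodes = list_nodes(SAMPLE)
placeholders = [name for name, op in nodes if op == "Placeholder"]
print(placeholders)
```

For a real file, the text would come from open(path).read(); parsing it with TensorFlow's own GraphDef and protobuf text_format would be more robust than this line scanner.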

A more detailed look at how the errors arise

mobilenet_v1

freeze_graph: the graph.pbtxt below was generated during training. Freezing with it raises no error, as the upper half of Figure 5 shows:

bazel-bin/tensorflow/python/tools/freeze_graph  \
--input_graph=/home/lg/Desktop/mobilenet_v1/checkpoint/graph.pbtxt \
--input_checkpoint=/home/lg/Desktop/mobilenet_v1/checkpoint/model.ckpt-20000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_244.pb  \
--output_node_names=MobilenetV1/Predictions/Reshape_1  \
--checkpoint_version=2

toco: attempting to convert the .pb file frozen from graph.pbtxt into a .tflite file.
Attempting to convert to a FLOAT .tflite fails, as the lower half of Figure 5 shows:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_244.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1.tflite \
--inference_type=FLOAT  \
--input_type=FLOAT \
--input_arrays='shuffle_batch/(shuffle_batch)'  \
--output_arrays=MobilenetV1/Predictions/Reshape_1 \
--input_shapes=1,244,244,3

Attempting to convert to a QUANTIZED_UINT8 .tflite fails, as Figure 6 shows:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_244.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1.tflite \
--inference_type=QUANTIZED_UINT8  \
--input_type=QUANTIZED_UINT8 \
--input_arrays='shuffle_batch/(shuffle_batch)'  \
--output_arrays=MobilenetV1/Predictions/Reshape_1  \
--input_shapes=1,244,244,3

Figure 5: incorrectly using the training graph.pbtxt for freeze_graph and toco
Figure 6: incorrectly using the training graph.pbtxt for freeze_graph and toco

The correct approach

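Generating eval.pbtxt first is what makes the subsequent freeze_graph and toco steps work. For how to generate eval.pbtxt, see "Tensorflow部署到移动端" (Deploying TensorFlow to Mobile).
Freeze the graph correctly using eval.pbtxt:
freeze_graph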

bazel-bin/tensorflow/python/tools/freeze_graph  \
--input_graph=/home/lg/Desktop/mobilenet_v1/mobilenet_v1_eval.pbtxt \
--input_checkpoint=/home/lg/Desktop/mobilenet_v1/checkpoint/model.ckpt-20000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_224.pb  \
--output_node_names=MobilenetV1/Predictions/Reshape_1  \
--checkpoint_version=2
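
When the same freeze step is repeated for several models, the command can be assembled in Python instead of retyped. This is a minimal sketch: build_freeze_graph_cmd is a hypothetical helper, it only uses the flags that appear in the command above, and it prints the command rather than running it.

```python
import shlex

def build_freeze_graph_cmd(input_graph, checkpoint, output_graph, output_node,
                           binary="bazel-bin/tensorflow/python/tools/freeze_graph"):
    """Return the freeze_graph argument list with the flags used in this article."""
    return [
        binary,
        "--input_graph=" + input_graph,        # eval.pbtxt, not the training graph
        "--input_checkpoint=" + checkpoint,
        "--input_binary=false",                # the graph is a text .pbtxt file
        "--output_graph=" + output_graph,
        "--output_node_names=" + output_node,
        "--checkpoint_version=2",              # V2 checkpoint format
    ]

freeze_cmd = build_freeze_graph_cmd(
    input_graph="/home/lg/Desktop/mobilenet_v1/mobilenet_v1_eval.pbtxt",
    checkpoint="/home/lg/Desktop/mobilenet_v1/checkpoint/model.ckpt-20000",
    output_graph="/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_224.pb",
    output_node="MobilenetV1/Predictions/Reshape_1",
)

# Print a copy-pasteable command line; quoting protects unusual characters.
print(" ".join(shlex.quote(arg) for arg in freeze_cmd))
```

To execute it from the TensorFlow source root, the list could be passed to subprocess.run(freeze_cmd, check=True).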

toco
Freezing with eval.pbtxt produces a correct .pb file. Then run toco on that .pb file.
toco conversion keeping the FLOAT type:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_224.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1-FLOAT.tflite \
--inference_type=FLOAT  \
--input_type=FLOAT \
--input_arrays=Placeholder  \
--output_arrays=MobilenetV1/Predictions/Reshape_1  \
--input_shapes=1,224,224,3

toco conversion quantizing to QUANTIZED_UINT8:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_224.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1-QUANTIZED_UINT8.tflite \
--inference_type=QUANTIZED_UINT8  \
--input_type=QUANTIZED_UINT8 \
--input_arrays=Placeholder  \
--output_arrays=MobilenetV1/Predictions/Reshape_1  \
--input_shapes=1,224,224,3 \
--default_ranges_min=0.0 \
--default_ranges_max=255.0
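
The FLOAT and QUANTIZED_UINT8 commands differ only in the type flags and the two default range flags, which a small helper makes explicit. This is a sketch under assumptions: build_toco_cmd and its parameter names are hypothetical, and only flags already shown in the commands above are emitted.

```python
def build_toco_cmd(input_file, output_file, input_array, output_array,
                   input_shape, quantized=False, default_range=(0.0, 255.0),
                   binary="bazel-bin/tensorflow/contrib/lite/toco/toco"):
    """Return the toco argument list for a FLOAT or QUANTIZED_UINT8 conversion."""
    dtype = "QUANTIZED_UINT8" if quantized else "FLOAT"
    cmd = [
        binary,
        "--input_file=" + input_file,
        "--input_format=TENSORFLOW_GRAPHDEF",
        "--output_format=TFLITE",
        "--output_file=" + output_file,
        "--inference_type=" + dtype,
        "--input_type=" + dtype,
        "--input_arrays=" + input_array,
        "--output_arrays=" + output_array,
        "--input_shapes=" + ",".join(str(dim) for dim in input_shape),
    ]
    if quantized:
        # Only the quantized command carries the default range flags.
        cmd.append("--default_ranges_min={}".format(default_range[0]))
        cmd.append("--default_ranges_max={}".format(default_range[1]))
    return cmd

common = dict(
    input_file="/home/lg/Desktop/mobilenet_v1/frozen_mobilenet_v1_224.pb",
    input_array="Placeholder",
    output_array="MobilenetV1/Predictions/Reshape_1",
    input_shape=(1, 224, 224, 3),
)
float_cmd = build_toco_cmd(
    output_file="/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1-FLOAT.tflite",
    **common)
quant_cmd = build_toco_cmd(
    output_file="/home/lg/Desktop/mobilenet_v1/frozen_graph_mobilenet_v1-QUANTIZED_UINT8.tflite",
    quantized=True, **common)

print(" \\\n".join(quant_cmd))
```

Building the flags as ASCII strings also avoids pasting a typographic dash in place of the double hyphen, which the shell would not treat as a flag prefix.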

Inception_v3

Likewise, the graph.pbtxt below was generated during training, and freezing with it raises no error either:
freeze_graph

bazel-bin/tensorflow/python/tools/freeze_graph  \
--input_graph=/home/lg/Desktop/inception_v3/checkpoint/graph.pbtxt \
--input_checkpoint=/home/lg/Desktop/inception_v3/checkpoint/model.ckpt-20000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb  \
--output_node_names=InceptionV3/Predictions/Reshape_1  \
--checkpoint_version=2

toco: attempting to convert the .pb file frozen from graph.pbtxt into a .tflite file.
Attempting to convert to a FLOAT .tflite fails, as Figure 7 shows:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/frozen_inception_v3_299.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/frozen_graph_inception_v3.tflite \
--inference_type=FLOAT  \
--input_type=FLOAT \
--input_arrays=input  \
--output_arrays=InceptionV3/Predictions/Reshape_1  \
--input_shapes=1,299,299,3

Figure 7: incorrectly using the training graph.pbtxt for freeze_graph and toco

Attempting to convert to a QUANTIZED_UINT8 .tflite fails, as Figure 8 shows:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/inception_v3/frozen_graph_inception_v3.tflite \
--inference_type=QUANTIZED_UINT8   \
--input_type=QUANTIZED_UINT8  \
--input_arrays=input  \
--output_arrays=InceptionV3/Predictions/Reshape_1  \
--input_shapes=1,299,299,3

Figure 8: incorrectly using the training graph.pbtxt for freeze_graph and toco

The correct approach

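Generating eval.pbtxt first is what makes the subsequent freeze_graph and toco steps work. For how to generate eval.pbtxt, see "Tensorflow部署到移动端" (Deploying TensorFlow to Mobile).
Freeze the graph correctly using eval.pbtxt:
freeze_graph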

bazel-bin/tensorflow/python/tools/freeze_graph  \
--input_graph=/home/lg/Desktop/inception_v3/inception_v3_eval.pbtxt \
--input_checkpoint=/home/lg/Desktop/inception_v3/checkpoint/model.ckpt-20000 \
--input_binary=false \
--output_graph=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb  \
--output_node_names=InceptionV3/Predictions/Reshape_1  \
--checkpoint_version=2

toco
Freezing with eval.pbtxt produces a correct .pb file. Then run toco on that .pb file.
toco conversion keeping the FLOAT type:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/inception_v3/frozen_graph_inception_v3.tflite \
--inference_type=FLOAT  \
--input_type=FLOAT \
--input_arrays=Placeholder  \
--output_arrays=InceptionV3/Predictions/Reshape_1  \
--input_shapes=1,299,299,3

toco conversion quantizing to QUANTIZED_UINT8:

bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=/home/lg/Desktop/inception_v3/frozen_inception_v3_299.pb \
--input_format=TENSORFLOW_GRAPHDEF  \
--output_format=TFLITE  \
--output_file=/home/lg/Desktop/inception_v3/frozen_graph_inception_v3.tflite \
--inference_type=QUANTIZED_UINT8  \
--input_type=QUANTIZED_UINT8 \
--input_arrays=Placeholder  \
--output_arrays=InceptionV3/Predictions/Reshape_1  \
--input_shapes=1,299,299,3 \
--default_ranges_min=0.0 \
--default_ranges_max=255.0
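
To get a feel for what a (min, max) range of 0.0 to 255.0 means on a uint8 grid, the sketch below applies a generic affine quantization formula. It is an assumption-laden illustration only: the function names are hypothetical, and how toco applies the default range flags internally is not covered here.

```python
def quant_params(range_min, range_max, levels=256):
    """Generic affine mapping: real = scale * (q - zero_point)."""
    scale = (range_max - range_min) / (levels - 1)
    zero_point = int(round(-range_min / scale))
    return scale, zero_point

def quantize(x, scale, zero_point):
    # Round to the nearest step, then clamp to the uint8 range.
    q = int(round(x / scale)) + zero_point
    return max(0, min(255, q))

def dequantize(q, scale, zero_point):
    return scale * (q - zero_point)

scale, zero_point = quant_params(0.0, 255.0)
q = quantize(127.6, scale, zero_point)
print(scale, zero_point, q, dequantize(q, scale, zero_point))
```

With this range each uint8 step corresponds to exactly 1.0 in real units, so values are simply rounded to the nearest integer and clamped to 0..255.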

References

  1. https://github.com/tensorflow/tensorflow/issues/15410
  2. https://blog.csdn.net/liugan528/article/details/80044455

For more material, see the author's GitHub:
https://github.com/GarryLau