tensorflow的freeze graph及inference graph_transforms

tensorflow模型优化

一般利用tensorflow训练出来的模型包含4个文件:

  • check_point文件,指定模型存放路径的文件
  • .data文件,存放模型的权重文件
  • .index文件,存放网络节点的索引
  • .meta文件,存放网络图结构

当我们在自己的电脑上实验时,固然可以先加载meta文件获得网络图结构,再加载.data文件加载权重,然后进行推理(inference)。但是在生产环境下,这样做就有些麻烦,况且有些模型还需要放在移动端,这就必须要优化模型。

1. 固化模型(将权重和图整合在一起,并去除与inference无关的节点)

源码位置:tensorflow/python/tools/freeze_graph.py
首先利用bazel编译tensorflow,bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
然后再编译freeze_graph,bazel build tensorflow/python/tools:freeze_graph
用法:

bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \    # 注意:这里的pb文件是用tf.train.write_graph方法保存的图结构
--input_checkpoint=model.ckpt \
--output_graph=frozen_graph.pb
--output_node_names=output_node

所以我们在训练的时候需要用saver.save保存ckpt,然后再用tf.train.write_graph保存图结构,代码如下:

with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.save(session, "model.ckpt")
    tf.train.write_graph(session.graph_def, '', 'graph.pb')

如果保存的图结构是.meta文件怎么办呢,也就是说用Saver.save方法和checkpoint一起生成的元模型文件,freeze_graph.py不适用,但可以改源码,然后重新用bazel编译。
修改freeze_graph.py里的_parse_input_graph_proto()函数

# 在freeze_graph.py最前面添加 from tensorflow.python.framework import meta_graph
# 将
input_graph_def = graph_pb2.GraphDef()
mode = "rb" if input_binary else "r"
with gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
 改为:
 input_graph_def = meta_graph.read_meta_graph_file(input_graph).graph_def

2.利用tensorflow自带优化工具实现优化(graph_transforms)

源码位置:tensorflow/tools/graph_transforms/
github地址
首先编译graph_transforms
bazel build tensorflow/tools/graph_transforms:transform_graph
查看graph_transforms可知它分为四个大类:

  • Optimizing for Deployment
  • Fixing Missing Kernel Errors on Mobile
  • Shrinking File Size
  • Eight-bit Calculations

2.1 Optimizing for Deployment

当我们在训练模型后,希望将其部署到服务器或移动设备上,并且希望它尽可能快地运行。该方法删除了推理过程中没有调用的所有节点,将始终不变的表达式缩减为单个节点,并通过对卷积的权值进行预乘,优化了batchnorm过程中使用的一些乘法操作。用法如下:

bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \     #已经固化后的训练模型
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
  strip_unused_nodes(type=float, shape="1,299,299,3")
  remove_nodes(op=Identity, op=CheckNumerics)
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms'

2.2 Fixing Missing Kernel Errors on Mobile

由于TensorFlow在Mobile中使用的时候,默认情况下是只能推理预测,可是在build so文件、jar文件或者.a文件的时候,依赖文件是写入在tensorflow / contrib / makefile / tf_op_files.txt中的,里面还包含了一些training相关的ops,这些可能会导致我们在加载PB文件的时候报错(No OpKernel was registered to support Op),这时候我们可以通过这个脚本来修复。

bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
  strip_unused_nodes(type=float, shape="1,299,299,3")
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms'

2.3 Shrinking File Size

如果我们需要将模型部署为移动应用程序的一部分,那么我们就需要减小模型的大小啦。对于大多数TensorFlow模型,文件大小的最大贡献者是传递给卷积和全连接层的权重,所以我们要减小模型大小,就得改变权重的存储方式;默认情况下,权值存储为32位浮点值,我们可以用四舍五入的方式存储权重。

bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
  strip_unused_nodes(type=float, shape="1,299,299,3")
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms
  round_weights(num_steps=256)'

我们也可以直接把权重量化为8位来存储,与round_weights相比,这种方法的缺点是插入了额外的解压缩操作,将8位值转换回浮点值,但是TensorFlow运行时中的优化应该确保缓存了这些结果,因此您不应该看到图形运行得更慢。优化后的模型大小应该约为原来模型的1/4。

bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
  strip_unused_nodes(type=float, shape="1,299,299,3")
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms
  quantize_weights'

2.4 Eight-bit Calculations

将推理的计算过程转换为8bit定点运算(还在实验开发阶段)

bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
  add_default_attributes
  strip_unused_nodes(type=float, shape="1,299,299,3")
  remove_nodes(op=Identity, op=CheckNumerics)
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms
  quantize_weights
  quantize_nodes
  strip_unused_nodes
  sort_by_execution_order'

该过程转换图中所有具有8位量化的操作,其余的操作保留在浮点数中。它只支持部分ops,而且在许多平台上,量化代码实际上可能比浮点代码慢,但是当所有条件都合适时,这是一种大幅度提高性能的方法。我管这个叫后量化推理操作,后面我会讲到在训练中加入量化节点的方式完成量化(Quantization-aware training,也叫fake quantization),见我的另一篇文章。

你可能感兴趣的:(Tensorflow)