一般利用tensorflow训练出来的模型包含4个文件:
当我们在自己的电脑上实验时,固然可以先加载meta文件获得网络图结构,再加载.data文件加载权重,然后进行推理(inference)。但是在生产环境下,这样做就有些麻烦,况且有些模型还需要放在移动端,这就必须要优化模型。
源码位置:tensorflow/python/tools/freeze_graph.py
首先利用bazel编译tensorflow,bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
然后再编译freeze_graph,bazel build tensorflow/python/tools:freeze_graph
用法:
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \ # 注意:这里的pb文件是用tf.train.write_graph方法保存的图结构
--input_checkpoint=model.ckpt \
--output_graph=frozen_graph.pb
--output_node_names=output_node
所以我们在训练的时候需要用saver.save保存ckpt,然后再用tf.train.write_graph保存图结构,代码如下:
with tf.Session() as sess:
saver = tf.train.Saver()
saver.save(session, "model.ckpt")
tf.train.write_graph(session.graph_def, '', 'graph.pb')
如果保存的图结构是.meta文件怎么办呢,也就是说用Saver.save方法和checkpoint一起生成的元模型文件,freeze_graph.py不适用,但可以改源码,然后重新用bazel编译。
修改freeze_graph.py里的_parse_input_graph_proto()函数
# 在freeze_graph.py最前面添加 from tensorflow.python.framework import meta_graph
# 将
input_graph_def = graph_pb2.GraphDef()
mode = "rb" if input_binary else "r"
with gfile.FastGFile(input_graph, mode) as f:
if input_binary:
input_graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read(), input_graph_def)
改为:
input_graph_def = meta_graph.read_meta_graph_file(input_graph).graph_def
源码位置:tensorflow/tools/graph_transforms/
github地址
首先编译graph_transforms
bazel build tensorflow/tools/graph_transforms:transform_graph
查看graph_transforms可知它分为四个大类:
当我们在训练模型后,希望将其部署到服务器或移动设备上,并且希望它尽可能快地运行。该方法删除了推理过程中没有调用的所有节点,将始终不变的表达式缩减为单个节点,并通过对卷积的权值进行预乘,优化了batchnorm过程中使用的一些乘法操作。用法如下:
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \ #已经固化后的训练模型
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
strip_unused_nodes(type=float, shape="1,299,299,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms'
由于TensorFlow在Mobile中使用的时候,默认情况下是只能推理预测,可是在build so文件、jar文件或者.a文件的时候,依赖文件是写入在tensorflow / contrib / makefile / tf_op_files.txt中的,里面还包含了一些training相关的ops,这些可能会导致我们在加载PB文件的时候报错(No OpKernel was registered to support Op),这时候我们可以通过这个脚本来修复。
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
strip_unused_nodes(type=float, shape="1,299,299,3")
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms'
如果我们需要将模型部署为移动应用程序的一部分,那么我们就需要减小模型的大小啦。对于大多数TensorFlow模型,文件大小的最大贡献者是传递给卷积和全连接层的权重,所以我们要减小模型大小,就得改变权重的存储方式;默认情况下,权值存储为32位浮点值,我们可以用四舍五入的方式存储权重。
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
strip_unused_nodes(type=float, shape="1,299,299,3")
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms
round_weights(num_steps=256)'
我们也可以直接把权重量化为8位来存储,与round_weights相比,这种方法的缺点是插入了额外的解压缩操作,将8位值转换回浮点值,但是TensorFlow运行时中的优化应该确保缓存了这些结果,因此您不应该看到图形运行得更慢。优化后的模型大小应该约为原来模型的1/4。
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
strip_unused_nodes(type=float, shape="1,299,299,3")
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms
quantize_weights'
将推理的计算过程转换为8bit定点运算(还在实验开发阶段)
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='input,phase_train' \
--outputs='softmax' \
--transforms='
add_default_attributes
strip_unused_nodes(type=float, shape="1,299,299,3")
remove_nodes(op=Identity, op=CheckNumerics)
fold_constants(ignore_errors=true)
fold_batch_norms
fold_old_batch_norms
quantize_weights
quantize_nodes
strip_unused_nodes
sort_by_execution_order'
该过程转换图中所有具有8位量化的操作,其余的操作保留在浮点数中。它只支持部分ops,而且在许多平台上,量化代码实际上可能比浮点代码慢,但是当所有条件都合适时,这是一种大幅度提高性能的方法。我管这个叫后量化推理操作,后面我会讲到在训练中加入量化节点的方式完成量化(Quantization-aware training,也叫fake quantization),见我的另一篇文章。