前段时间利用tensorflow的量化工具做了量化训练,精度损失很小,有时甚至比浮点模型精度更好一点,确实强大。利用tflite框架在3536上相比浮点模型有了2X左右的速度提升,现在做一个总结记录。
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize,
这是tensorflow的量化教程,主要是tf.contrib.quantize.create_training_graph()和tf.contrib.quantize.create_eval_graph() 这个两个函数,在训练的过程中学习max和min等参数,然后转化为定点的tflite就行了。
基本对应论文https://arxiv.org/abs/1712.05877描述的内容
------------------------------------------------------------------------
1、直接量化(这种量化方式,现在(2018年12月)的tensorflow代码的发行版中已经移除了(时间有可能更早,好久没关注了 手动滑稽.jpg),现在主推的是tf.contrib.quantize.create_training_graph这种量化训练方式,不过作为学习量化方法,还是有一定价值的,不感兴趣可以直接略过这一章节,我会在后续tensorflow的量化教程(2)中 介绍如何实现论文中的直接量化和量化训练 )
感兴趣的童鞋可以在这里下载相关的demo:https://download.csdn.net/download/u012101561/10843616,没有积分可以私信我,邮箱发你
想了解相关原理的,可以参考这里:http://fjdu.github.io/machine/learning/2016/07/07/quantize-neural-networks-with-tensorflow.html
不同于网上教程,我没有用bazel编译量化工具,当时用bazel build tensorflow/tools/quantization:quantize_graph时,一直出现 Error Download xxx package 的错误,可能是墙的原因。后来发现,其实可以直接运行python脚本,在tensorflow-master根目录下运行
cd tensorflow/tools/quantization
在该目录下,有BUILD graph_to_dot.py quantize_graph.py quantize_graph_test.py这几个文件,运行quantize_graph.py即可
以GoogLeNet 为例,下载模型,并解压
curl http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz -o /tmp/inceptionv3.tgz
tar xzf /tmp/inceptionv3.tgz
运行脚本
python quantize_graph.py --input=/tmp/classify_image_graph_def.pb --output_node_names="softmax" --output=/tmp/quantized_graph.pb --mode=eightbit
在/tmp目录下,就会生成quantized_graph.pb了。
模型由94M压缩到了24M
接下来我们测试一下,模型压缩后的性能如何
回到tensflow根目录上,编写一个脚本文件,内容如下,主要路径
echo "float model result:"
python tensorflow/tensorflow/examples/label_image/label_image.py \
--graph=/home/yqli/model/inception_v3/classify_image_graph_def.pb \
--input_width=299 \
--input_height=299 \
--input_mean=128 \
--input_std=128 \
--input_layer="Mul" \
--output_layer="softmax" \
--labels="/tmp/imagenet_synset_to_human_label_map.txt" \
--image="/tmp/cropped_panda.jpg"
echo "-----------------------------------------------------------------"
echo "quantized model result:"
python tensorflow/tensorflow/examples/label_image/label_image.py \
--graph=/home/yqli/model/inception_v3/quantized_graph.pb \
--input_width=299 \
--input_height=299 \
--input_mean=128 \
--input_std=128 \
--input_layer="Mul" \
--output_layer="softmax" \
--labels="/tmp/imagenet_synset_to_human_label_map.txt" \
--image="/tmp/cropped_panda.jpg"
如果不知道输入输出,可采用如下方式查看
import tensorflow as tf
gf = tf.GraphDef()
gf.ParseFromString(open('./testpb/test.pb','rb').read())
for n in gf.node:
print ( n.name +' ===> '+n.op )
结果如下: