最近在做模型的量化,量化的模型是人脸检测网络mtcnn,我从Onet开始入手,原先这个模型使用的权重文件是ckpt,这种存储格式适合训练,如果要做量化的话,需要先转化为pb文件,把其中的变量都持久化。再进一步做量化
生成的思路是给加载ckpt文件的onet网络导入一张48x48的人头图像,输出softmax值和box数值,再把网络加载方式换成生成的pb文件,再送一样的一幅图进去,查看输出结果,一样则转化成功。然后接下来就可以在生成的pb文件上做int8量化。
pb是protocol(协议) buffer(缓冲)的缩写。TensorFlow训练模型后存成的pb文件,是一种表示模型(神经网络)结构的二进制文件,不带有源代码。
pb文件中可以只存参数,也可以存参数加网络结构,我们这里要生成的是存参数+网络结构,这样在推断的时候,可以不用重新在代码中定义网络结构,直接送入图像就可以输出结果,很方便。google现在也推荐这种文件格式。
我们在原网络中加载ckpt模型,然后回复成sess,再从sess保存到pb文件
代码如下:
import sys
import argparse
import time
import os
os.environ['CUDA_VISIBLE_DEVICES']='3'
import tensorflow as tf
import cv2
import numpy as np
from tensorflow.python.framework import graph_util
from src.mtcnn import PNet, RNet, ONet
from tools import detect_face, get_model_filenames
def main(args):
out_pb_path="onet_trained2.pb"
img = cv2.imread(args.image_path)
img48 = (img - 127.5) * (1. / 128.0)
img_x = np.expand_dims(img48, 0)
file_paths = get_model_filenames(args.model_dir)
with tf.device('/gpu:3'):
with tf.Graph().as_default():
config = tf.ConfigProto(allow_soft_placement=True)
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
with tf.Session(config=config) as sess:
if len(file_paths) == 3:
image_onet = tf.placeholder(tf.float32, [None, 48, 48, 3])
onet = ONet({'data': image_onet}, mode='test')
out_tensor_onet = onet.get_all_output()
saver_onet = tf.train.Saver(
[v for v in tf.global_variables()
if v.name[0:5] == "onet/"])
saver_onet.restore(sess, file_paths[2])
sess.run(out_tensor_onet, feed_dict={image_onet: img_x})
graph = tf.get_default_graph() # 获得默认的图
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
# for op in graph.get_operations():
# print(op.name, op.values())
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess = sess,
input_graph_def = input_graph_def,# 等于:sess.graph_def input_graph_def
output_node_names = ['softmax/softmax','onet/conv6-2/onet/conv6-2'])# 如果有多个输出节点,以逗号隔开
with tf.gfile.GFile(out_pb_path, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
代码中最关键的是输出节点名称的确定,只要写对了程序基本没有问题,我在这一块卡了好久。查节点的方法有直接看原网络的输出节点名称、可视化工具tensorflow、netron。我使用的是netron,很方便,在网页中上次模型文件即可,具体教程看博客:
https://blog.csdn.net/qqqzmy/article/details/86060131
在生成文件后先用netron看一下结构对不对
ckpt的结构如下,非常乱,包含很多我们推断过程中用不到的节点:
转成pb文件以后,就清爽多了:
除了必须的节点,没有多余的,加载速度快很多
我们再对生成的pb文件加载进网络,测试onet的输出结果是否和ckpt文件的输出结果一样
下面是ckpt文件的输出结果,上面是sotfmax的两个输出值,下面是box的四个输出值:
测试的代码如下:
import tensorflow as tf
#from create_tf_record import *
import os
from tensorflow.python.framework import graph_util
import cv2
import numpy as np
os.environ['CUDA_VISIBLE_DEVICES']='2'
#imgPath = "/1t_second/myzhuang2/quantization/mtcnn_tf_quant/img48x48.jpg"
model_path = "/1t_second/myzhuang2/quantization/mtcnn_tf_quant/onet.pb"
def freeze_graph_test(pb_path, image_path):
'''
:param pb_path:pb文件的路径
:param image_path:测试图片的路径
:return:
'''
with tf.device('/gpu:2'):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")
config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
# 定义输入的张量名称,对应网络结构的输入张量
input_image_tensor = sess.graph.get_tensor_by_name("Placeholder:0")
softmax = sess.graph.get_tensor_by_name("softmax/softmax:0")
conv62 = sess.graph.get_tensor_by_name("onet/conv6-2/onet/conv6-2:0")
# 定义输出的张量名称
img = cv2.imread(image_path)
img = (img - 127.5) * 0.0078125
img_x = np.expand_dims(img, 0)
img_x.astype(np.float32)
# 读取测试图片
# im=read_image(image_path,resize_height,resize_width,normalization=True)
# im=im[np.newaxis,:]
# 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
out1, out2 = sess.run([softmax, conv62], feed_dict = {input_image_tensor:img_x})
out_conv61 = np.array(out1)
out_conv62 = np.array(out2)
print("---------------")
print(out_conv61)
print("---------------")
print(out_conv62)
print("test done")
https://blog.csdn.net/yjl9122/article/details/78341689
https://blog.csdn.net/guyuealian/article/details/81560537
https://blog.csdn.net/michael_yt/article/details/74737489
https://blog.csdn.net/lujiandong1/article/details/53385092
https://blog.csdn.net/wc781708249/article/details/78043099
https://blog.csdn.net/guyuealian/article/details/82218092#commentBox