tensorflow的freeze_graph方法

前言：使用 TensorFlow 训练得到的结果，通常是权重（checkpoint）和计算图（graph）分开保存的。在发布产品时，我们需要把权重和 graph 固化到同一个文件中，这时就会用到 freeze_graph 了。

freeze_graph 目前主要有两种使用方式：一种是用 bazel 编译 tensorflow/python/tools/freeze_graph.py，然后通过命令行把 tf.train.write_graph() 生成的 pb 文件与 tf.train.Saver() 生成的 ckpt 文件固化在一起，重新生成一个 pb 文件，具体见我的上一篇文章；另一种方式是通过 TensorFlow 的 Python API 把变量转换成常量，再写入 pb 文件，这种方式需要我们自己写代码完成。下面介绍第二种方式，直接上代码：

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import graph_util
import tensorflow as tf
import argparse
import os
import sys
from six.moves import xrange

def main(args):
    """Restore a checkpoint into its metagraph, freeze it and write a .pb file.

    Args:
        args: parsed command-line namespace (see ``parse_arguments``) providing
            ``gpu_idx``, ``meta_file``, ``ckpt_file`` and ``output_file``, and
            optionally ``output_node_names`` (comma-separated; falls back to
            ``'softmax'`` when absent or empty).
    """
    # Set before any session is created so TensorFlow only sees these GPU(s).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_idx
    # Namespaces built without the output_node_names option keep the
    # previously hard-coded output node.
    output_node_names = getattr(args, 'output_node_names', None) or 'softmax'

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Load the model metagraph and checkpoint
            print('Metagraph file: %s' % args.meta_file)
            print('Checkpoint file: %s' % args.ckpt_file)
            # clear_devices=True drops the device placements recorded at
            # training time so the graph can be loaded on this machine.
            saver = tf.train.import_meta_graph(args.meta_file, clear_devices=True)
            # Initialize everything first; restore() then overwrites the
            # variables that are present in the checkpoint.
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, args.ckpt_file)

            # The batch-norm node rewrite happens inside freeze_graph_def.
            input_graph_def = sess.graph.as_graph_def()
            output_graph_def = freeze_graph_def(sess, input_graph_def, output_node_names)

    # Serialize and dump the frozen graph to the filesystem
    with tf.gfile.GFile(args.output_file, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph: %s" % (len(output_graph_def.node), args.output_file))


def _drop_attr(node, attr_name):
    """Delete ``attr_name`` from ``node.attr`` if it is present."""
    # Checked with ``in`` first: indexing a protobuf message map would create
    # the entry rather than raise.
    if attr_name in node.attr:
        del node.attr[attr_name]


def _fix_batch_norm_nodes(graph_def):
    """Rewrite variable-updating batch-norm ops in ``graph_def`` in place."""
    # Ops that write to variables (RefSwitch/AssignSub/AssignAdd) are replaced
    # by their stateless counterparts; this is the common workaround for
    # freezing graphs that contain batch norm -- verify for your model.
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Point moving_* inputs at '<input>/read' instead of the variable
            # ref. Assumes such a '/read' node exists for each of them (as
            # tf.Variable creates) -- confirm for your graph.
            for index, input_name in enumerate(list(node.input)):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            # The plain Sub op does not define a use_locking attribute.
            _drop_attr(node, 'use_locking')
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            # The plain Add op does not define a use_locking attribute.
            _drop_attr(node, 'use_locking')


def _collect_whitelist(graph_def, prefixes):
    """Return the names of nodes in ``graph_def`` starting with any of ``prefixes``."""
    return [node.name for node in graph_def.node if node.name.startswith(prefixes)]


def freeze_graph_def(sess, input_graph_def, output_node_names, whitelist_prefixes=None):
    """Replace the variables of a GraphDef with constants of the same values.

    Variable-updating batch-norm ops are first rewritten into stateless ops,
    then every variable whose node name starts with one of
    ``whitelist_prefixes`` is converted into a constant.

    :param sess: session in which the variable values have been restored;
        those values are baked into the returned graph.
    :param input_graph_def: tensorflow GraphDef to freeze. NOTE: it is
        modified in place by the batch-norm rewrite.
    :param output_node_names: string, comma-separated output node names
        (e.g. ``'softmax'`` or ``'logits, probs'``); whitespace around each
        name is ignored and empty entries are dropped.
    :param whitelist_prefixes: optional prefix string or iterable of prefix
        strings selecting which variables to convert. Defaults to the
        prefixes of the original network (``MobilenetV2``, ``phase_train``,
        ``Bottleneck``, ``dropoutprob``, ``xinput``); adjust them to your own
        architecture.
    :return: the frozen GraphDef returned by
        ``graph_util.convert_variables_to_constants``.
    :raises ValueError: if ``output_node_names`` contains no node name.
    """
    node_names = [name.strip() for name in output_node_names.split(',') if name.strip()]
    if not node_names:
        raise ValueError('output_node_names must contain at least one node name, got %r'
                         % (output_node_names,))

    if whitelist_prefixes is None:
        whitelist_prefixes = ('MobilenetV2', 'phase_train', 'Bottleneck',
                              'dropoutprob', 'xinput')
    elif isinstance(whitelist_prefixes, str):
        # A bare string would otherwise be treated as a sequence of characters.
        whitelist_prefixes = (whitelist_prefixes,)
    prefixes = tuple(whitelist_prefixes)

    _fix_batch_norm_nodes(input_graph_def)
    whitelist_names = _collect_whitelist(input_graph_def, prefixes)

    # Replace the whitelisted variables in the graph with constants of the same values
    return graph_util.convert_variables_to_constants(
        sess, input_graph_def, node_names,
        variable_names_whitelist=whitelist_names)
  
def parse_arguments(argv):
    """Parse the command-line arguments of the freeze script.

    Args:
        argv: list of argument strings, normally ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``gpu_idx``, ``ckpt_file``, ``meta_file``,
        ``output_file`` and ``output_node_names``.
    """
    parser = argparse.ArgumentParser(
        description='Freeze a TensorFlow metagraph + checkpoint into a single .pb file.')
    parser.add_argument('--gpu_idx', type=str, default='0',
                        help='GPU index(es) to expose via CUDA_VISIBLE_DEVICES')
    # The three paths below are needed by main(); failing fast here gives a
    # usage message instead of a crash deep inside TensorFlow.
    parser.add_argument('--ckpt_file', type=str, required=True,
                        help='Path to the checkpoint (ckpt) file containing model parameters')
    parser.add_argument('--meta_file', type=str, required=True,
                        help='Path to the metagraph (.meta) file')
    parser.add_argument('--output_file', type=str, required=True,
                        help='Filename for the exported graphdef protobuf (.pb)')
    parser.add_argument('--output_node_names', type=str, default='softmax',
                        help='Comma-separated names of the graph output nodes')
    return parser.parse_args(argv)

if __name__ == '__main__':
    # Script entry point: parse everything after the program name, then run.
    cli_args = parse_arguments(sys.argv[1:])
    main(cli_args)

这种方式主要调用的接口是代码中 freeze_graph_def 函数里的 graph_util.convert_variables_to_constants()。它接收会话 sess、计算图定义 input_graph_def、输出节点名列表（由 output_node_names 按逗号拆分得到），以及变量白名单 variable_names_whitelist。它负责把 TensorFlow 训练时以变量形式保存的权重转换为常量；传入白名单时，只有白名单中的变量会被转换。

你可能感兴趣的:(Tensorflow,tensorflow固化模型,freeze_graph)