TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式

        本文承接上文 TensorFlow-slim 训练 CNN 分类模型(续),阐述通过 tf.contrib.slim 的函数 slim.learning.train 训练的模型,怎么通过人为地加入数据入口(即占位符)来克服无法用于图像推断的问题。要解决这个问题,最简单和最省时的方法是模仿。我们模仿的代码是 TensorFlow 实现的目标检测 API 中的文件 exporter.py,该文件的目的正是要将 TensorFlow-slim 训练的目标检测模型由 .ckpt 格式转化为 .pb 格式,而且其代码中人为添加占位符的操作也正是我们所需要的。坦白地说,我之所以会用 TensorFlow 的 tf.contrib.slim 模块来构建和训练模型,正是受 TensorFlow models 项目的影响:当时我需要训练目标检测器,因此便配置了 models 这个子项目,并且从头到尾地阅读了其中 object_detection 中的 Faster RCNN 的源代码,切实感受到了 slim 模块的简便和高效(学习 TensorFlow 最好的办法除了查阅文档之外,便是看 models 中各种项目的源代码)。

        言归正传,现在我们回到主题,怎么加入占位符,将前一篇文章训练的 CNN 分类器用于图像分类。这个问题在我们知道通过模仿 exporter.py 就可以解决它的时候,就变得异常简单了。我们先来理顺一下解决这个问题的逻辑:

1.定义数据入口,即定义占位符 inputs = tf.placeholder(···);
2.将模型作用于占位符,得到数据出口,即分类结果;
3.将训练文件从 .ckpt 格式转化为 .pb 格式。

按照这个逻辑顺序,下面我们详细地来看一下自定义模型导出,即模型格式转化的代码(命名为 exporter.py,如果没有特别说明,exporter.py 指的都是我们修改 TensorFlow 目标检测中的 exporter.py 后的自定义文件):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:13:27 2018
@author: shirhe-lyh
"""

"""Functions to export inference graph.
Modified from: TensorFlow models/research/object_detection/exporter.py
"""

import logging
import os
import tempfile
import tensorflow as tf

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import saver as saver_lib

# Shorthand for the TF-Slim module; used below for
# slim.get_or_create_global_step().
slim = tf.contrib.slim


# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when
# newer version of Tensorflow becomes more common.
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph_def: A tf.GraphDef describing the inference graph. When
            `clear_devices` is True its nodes' device fields are cleared in
            place.
        input_saver_def: A SaverDef used to restore the checkpoint. If falsy,
            a Saver is built from the checkpoint variables that also exist
            in the graph.
        input_checkpoint: Checkpoint path (or prefix for Saver V2 format).
        output_node_names: Comma-separated names of the output nodes.
        restore_op_name: Unused; kept for interface compatibility.
        filename_tensor_name: Unused; kept for interface compatibility.
        clear_devices: Whether to strip explicit device placements.
        initializer_nodes: Nodes to run after restoring; only used when no
            `input_saver_def` is given.
        variable_names_blacklist: Comma-separated variable names that must
            not be converted to constants.

    Returns:
        A tf.GraphDef in which the variables are replaced by constants.

    Raises:
        ValueError: If the checkpoint does not exist or no output node names
            are given.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        # The original message had mismatched quotes and printed the literal
        # text "' + input_checkpoint + '" instead of the path.
        raise ValueError(
            "Input checkpoint '{}' does not exist!".format(input_checkpoint))

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')
        config = tf.ConfigProto(graph_options=tf.GraphOptions())
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example
                        # it's 'global_step' or a similar housekeeping element)
                        # so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)

            variable_names_blacklist = (variable_names_blacklist.split(',') if
                                        variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)
    return output_graph_def


def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Writes a checkpoint whose variables hold their moving-average values.

    The shadow (moving-average) values stored in `current_checkpoint_file`
    are restored into the regular variables of `graph`, and the result is
    saved as a new checkpoint.

    Args:
        graph: A tf.Graph object containing the model variables.
        current_checkpoint_file: A checkpoint holding both the original
            variables and their moving averages.
        new_checkpoint_file: File path to write the new checkpoint to.
    """
    with graph.as_default():
        # The decay is irrelevant here: the object is only used to build the
        # shadow-name -> variable mapping via variables_to_restore().
        ema = tf.train.ExponentialMovingAverage(0.0)
        restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            tf.train.Saver(restore_map).restore(sess, current_checkpoint_file)
            tf.train.Saver().save(sess, new_checkpoint_file)


def _image_tensor_input_placeholder(input_shape=None):
    """Creates a uint8 image placeholder named 'image_tensor'.

    Args:
        input_shape: Optional static shape of the placeholder. Defaults to
            (None, None, None, 3).

    Returns:
        A (placeholder, image_tensor) tuple. Both items are the same tensor,
        because raw image input needs no decoding step.
    """
    shape = (None, None, None, 3) if input_shape is None else input_shape
    image_placeholder = tf.placeholder(
        dtype=tf.uint8, shape=shape, name='image_tensor')
    return image_placeholder, image_placeholder


def _encoded_image_string_tensor_input_placeholder():
    """Creates a placeholder for a batch of encoded PNG/JPEG strings.

    Returns:
        A tuple (placeholder, images). The first item is the 1-D tf.string
        placeholder named 'encoded_image_string_tensor'. The second is the
        uint8 tensor obtained by decoding every element into a 3-channel
        image.
    """
    encoded_images = tf.placeholder(
        dtype=tf.string,
        shape=[None],
        name='encoded_image_string_tensor')

    def _decode_one(encoded):
        decoded = tf.image.decode_image(encoded, channels=3)
        # decode_image does not fix a static rank, so pin it to (H, W, 3).
        decoded.set_shape((None, None, 3))
        return decoded

    decoded_images = tf.map_fn(
        _decode_one,
        elems=encoded_images,
        dtype=tf.uint8,
        parallel_iterations=32,
        back_prop=False)
    return encoded_images, decoded_images


# Maps each supported `input_type` string to the function that builds its
# input placeholder. Each function returns a tuple
# (placeholder_tensor, tensor_fed_to_the_model).
# NOTE: 'tf_example' is not implemented in this file (its entry is commented
# out), so _build_model_graph raises ValueError for it.
input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
#    'tf_example': _tf_example_input_placeholder,
    }


def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
    """Adds output nodes.
    
    Adjust according to specified implementations.
    
    Adds the following nodes for output tensors:
        * classes: A float32 tensor of shape [batch_size] containing class
            predictions.
    
    Args:
        postprocessed_tensors: A dictionary containing the following fields:
            'classes': [batch_size].
        output_collection_name: Name of collection to add output tensors to.
        
    Returns:
        A tensor dict containing the added output tensor nodes.
    """
    outputs = {}
    classes = postprocessed_tensors.get('classes') # Assume containing 'classes'
    outputs['classes'] = tf.identity(classes, name='classes')
    for output_key in outputs:
        tf.add_to_collection(output_collection_name, outputs[output_key])
    return outputs


def write_frozen_graph(frozen_graph_path, frozen_graph_def):
    """Serializes a frozen GraphDef to disk and logs its op count.

    Args:
        frozen_graph_path: Destination path of the .pb file.
        frozen_graph_def: tf.GraphDef holding the frozen graph.
    """
    with gfile.GFile(frozen_graph_path, 'wb') as output_file:
        serialized = frozen_graph_def.SerializeToString()
        output_file.write(serialized)
    node_count = len(frozen_graph_def.node)
    logging.info('%d ops in the final graph.', node_count)
    
    
def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
    """Writes SavedModel to disk.

    The frozen graph is imported into a fresh graph and exported as a
    SavedModel tagged SERVING, with one default serving signature. That
    signature maps 'inputs' to `inputs` and each key of `outputs` to its
    tensor. Because the graph is frozen, the weights are already stored as
    constants inside it.

    Args:
        saved_model_path: Path to write SavedModel.
        frozen_graph_def: tf.GraphDef holding frozen graph.
        inputs: The input placeholder tensor.
        outputs: A tensor dictionary containing the outputs of a slim model.
    """
    with tf.Graph().as_default():
        with session.Session() as sess:
            # Imported with name='' so node names match those of the graph
            # `inputs`/`outputs` came from.
            tf.import_graph_def(frozen_graph_def, name='')

            builder = tf.saved_model.builder.SavedModelBuilder(
                saved_model_path)

            # build_tensor_info records a tensor's name, dtype and shape, so
            # tensors of the original graph can describe the identically
            # named nodes imported above.
            tensor_info_inputs = {
                'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
            tensor_info_outputs = {
                key: tf.saved_model.utils.build_tensor_info(tensor)
                for key, tensor in outputs.items()}

            prediction_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs=tensor_info_inputs,
                    outputs=tensor_info_outputs,
                    method_name=signature_constants.PREDICT_METHOD_NAME))

            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        prediction_signature,
                },
            )
            builder.save()


def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Re-saves the trained checkpoint together with the inference graph.

    Args:
        inference_graph_def: tf.GraphDef of the inference graph. Its nodes'
            device fields are cleared in place.
        model_path: Checkpoint prefix to write (e.g. .../model.ckpt).
        input_saver_def: SaverDef used both to restore and to save.
        trained_checkpoint_prefix: Checkpoint to restore variable values from.
    """
    for graph_node in inference_graph_def.node:
        graph_node.device = ''
    export_graph = tf.Graph()
    with export_graph.as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            checkpoint_saver = saver_lib.Saver(saver_def=input_saver_def,
                                               save_relative_paths=True)
            checkpoint_saver.restore(sess, trained_checkpoint_prefix)
            checkpoint_saver.save(sess, model_path)


def _get_outputs_from_inputs(input_tensors, model,
                             output_collection_name):
    """Runs `model` on the inputs and adds the named output nodes.

    The inputs are cast to float32 and passed through the model's
    preprocess -> predict -> postprocess pipeline.

    Returns:
        The dict returned by _add_output_tensor_nodes.
    """
    float_inputs = tf.to_float(input_tensors)
    postprocessed_tensors = model.postprocess(
        model.predict(model.preprocess(float_inputs)))
    return _add_output_tensor_nodes(postprocessed_tensors,
                                    output_collection_name)
    
    
def _build_model_graph(input_type, model, input_shape,
                       output_collection_name, graph_hook_fn):
    """Builds the inference graph in the current default graph.

    Args:
        input_type: Key of input_placeholder_fn_map selecting the input node.
        model: Model object exposing preprocess/predict/postprocess.
        input_shape: Optional fixed input shape; only allowed for
            'image_tensor'.
        output_collection_name: Collection to add output tensors to.
        graph_hook_fn: Optional no-argument callable, invoked after the graph
            (including the global step) has been built.

    Returns:
        A tuple (outputs, placeholder_tensor).

    Raises:
        ValueError: If `input_type` is unknown, or if `input_shape` is given
            for an input type other than 'image_tensor'.
    """
    if input_type not in input_placeholder_fn_map:
        raise ValueError('Unknown input type: {}'.format(input_type))
    if input_shape is not None and input_type != 'image_tensor':
        raise ValueError("Can only specify input shape for 'image_tensor' "
                         'inputs.')
    placeholder_args = ({} if input_shape is None
                        else {'input_shape': input_shape})
    make_placeholder = input_placeholder_fn_map[input_type]
    placeholder_tensor, input_tensors = make_placeholder(**placeholder_args)
    outputs = _get_outputs_from_inputs(
        input_tensors=input_tensors,
        model=model,
        output_collection_name=output_collection_name)

    # Add global step to the graph
    slim.get_or_create_global_step()

    if graph_hook_fn:
        graph_hook_fn()

    return outputs, placeholder_tensor


def export_inference_graph(input_type,
                           model,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           use_moving_averages=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           graph_hook_fn=None):
    """Exports inference graph for the desired graph.

    Writes into `output_directory` a re-saved checkpoint ('model.ckpt*'),
    'frozen_inference_graph.pb' and a 'saved_model' directory.

    Args:
        input_type: Type of input for the graph. Must be a key of
            input_placeholder_fn_map: 'image_tensor' or
            'encoded_image_string_tensor' ('tf_example' is not supported in
            this file).
        model: A model defined by model.py, exposing preprocess, predict and
            postprocess.
        trained_checkpoint_prefix: Path to the trained checkpoint file.
        output_directory: Path to write outputs.
        input_shape: Sets a fixed shape for an 'image_tensor' input. If not
            specified, will default to [None, None, None, 3].
        use_moving_averages: A boolean indicating whether the
            tf.train.ExponentialMovingAverage should be used or not.
        output_collection_name: Name of collection to add output tensors to;
            passed through to _add_output_tensor_nodes.
        additional_output_tensor_names: List of additional output tensors to
            include in the frozen graph.
        graph_hook_fn: Optional no-argument callable, run after the inference
            graph has been built and before it is exported.
    """
    tf.gfile.MakeDirs(output_directory)
    frozen_graph_path = os.path.join(output_directory,
                                     'frozen_inference_graph.pb')
    saved_model_path = os.path.join(output_directory, 'saved_model')
    model_path = os.path.join(output_directory, 'model.ckpt')

    outputs, placeholder_tensor = _build_model_graph(
        input_type=input_type,
        model=model,
        input_shape=input_shape,
        output_collection_name=output_collection_name,
        graph_hook_fn=graph_hook_fn)

    saver_kwargs = {}
    if use_moving_averages:
        # This check is to be compatible with both version of SaverDef:
        # a V1 checkpoint is a single file, a V2 checkpoint is a prefix.
        if os.path.isfile(trained_checkpoint_prefix):
            saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
            # Only the unique temporary path is used as checkpoint path.
            temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
        else:
            temp_checkpoint_prefix = tempfile.mkdtemp()
        replace_variable_values_with_moving_averages(
            tf.get_default_graph(), trained_checkpoint_prefix,
            temp_checkpoint_prefix)
        checkpoint_to_use = temp_checkpoint_prefix
    else:
        checkpoint_to_use = trained_checkpoint_prefix

    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()

    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=checkpoint_to_use)

    # In Python 3 dict.keys() is a view that cannot be concatenated with a
    # list using '+', so build an explicit list of node names.
    output_node_names_list = list(outputs.keys())
    if additional_output_tensor_names is not None:
        output_node_names_list.extend(additional_output_tensor_names)
    output_node_names = ','.join(output_node_names_list)

    frozen_graph_def = freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=output_node_names,
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        initializer_nodes='')
    write_frozen_graph(frozen_graph_path, frozen_graph_def)
    write_saved_model(saved_model_path, frozen_graph_def,
                      placeholder_tensor, outputs)

首先看定义占位符的函数 _image_tensor_input_placeholder 和 _encoded_image_string_tensor_input_placeholder,重点关注前一个函数,因为它的输入为一个批量图像组成的 4 维张量(正是我们需要的),这个函数仅仅定义了一个图像占位符 input_tensor:

input_tensor = tf.placeholder(dtype=tf.uint8, shape=input_shape, name='image_tensor')

简单至极。接下来看 _build_model_graph 函数,这个函数根据 input_type(第一个参数)创建占位符 input_tensor,再将其通过模型 model(第二个参数)作用后得到的结果 outputs 返回。其中引用的函数 _get_outputs_from_inputs,顾名思义,由输入数据得到分类结果。它又引用了函数 _add_output_tensor_nodes,这个函数比较重要,因为它定义了数据输出结点:

outputs['classes'] = tf.identity(classes, name='classes')

以上这些便是这个自定义文件 exporter.py 的精华,因为它实现了数据入口(name='image_tensor')和出口(name='classes')结点的定义。另一方面,这个自定义文件 exporter.py 可以作为模型导出的通用文件,而针对每一个特定的模型我们只需要修改与参数 model(表示某个特定模型) 相关的函数即可,而所有这些函数就是以上列出的函数。

        为了描述的完整性,也来看一看剩下的不需要修改的函数。我们从主函数 export_inference_graph 开始,它是实际被调用的函数。它首先创建了用于保存输出文件的文件夹,然后根据参数 model 创建了模型数据入口和出口,接下来的 if 语句是说,如果使用移动平均,则将原始 graph 中的变量用它的移动平均值来替换(函数 replace_variable_values_with_moving_averages)。再下来的 write_graph_and_checkpoint 函数相当于将上一篇文章的训练输出文件复制到当前指定的输出路径 output_directory,最后的函数 freeze_graph_with_def_protos 将 graph 中的变量变成常量,然后通过函数 write_frozen_graph 和函数 write_saved_model 写出到输出路径。

        最后来解释一下函数

export_inference_graph(input_type,
                       model,
                       trained_checkpoint_prefix,
                       output_directory,
                       input_shape=None,
                       use_moving_averages=None,
                       output_collection_name='inference_op',
                       additional_output_tensor_names=None,
                       graph_hook_fn=None)

的各个参数:1.input_type,指的是输入数据的类型,exporter.py 指定了只能从以下的字典中

input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
#    'tf_example': _tf_example_input_placeholder,
    }

选出其中一种,一般我们选择图像作为输入,即 image_tensor;2.model,指的是自己构建的模型,是一个类对象,如上一篇文章定义的 Model 类的一个实例:

cls_model = model.Model(is_training=False, num_classes=10)

3.trained_checkpoint_prefix,指定要导出的 .ckpt 文件路径;4.output_directory,指定导出文件的存储路径(是一个文件夹);5.input_shape,输入数据的形状,缺省时为 [None, None, None, 3];6.use_moving_averages,是否使用移动平均;7.output_collection_name,输出的 collection 名,直接使用默认名,不需要修改;8.additional_output_tensor_names,指定额外的输出张量名;9.graph_hook_fn,一个可选的无参函数,会在推断图(包括 global_step)构建完成之后被调用,可用于在导出前对图做额外的修改,一般无需指定。

        实际调用的时候,我们一般只需要指定前四个参数,如(命名为 export_inference_graph.py):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:46:16 2018
@author: shirhe-lyh
"""

r"""Tool to export a model for inference.
Outputs inference graph, associated checkpoint files, a frozen inference
graph and a SavedModel (https://tensorflow.github.io/serving_basic.html).
The inference graph contains one of the following input nodes depending on
the user specified option.
    * 'image_tensor': Accepts a uint8 4-D tensor of shape [None, None, None, 3]
    * 'encoded_image_string_tensor': Accepts a 1-D string tensor of shape 
        [None] containing encoded PNG or JPEG images.
    * 'tf_example': Accepts a 1-D string tensor of shape [None] containing
        serialized TFExample protos. (Not implemented in exporter.py.)
        
and the following output nodes returned by the model.postprocess(..):
    * 'classes': Outputs a tensor of the form [batch_size] containing
        the classes for the predictions.
        
Example Usage:
---------------
python/python3 export_inference_graph.py \
    --input_type image_tensor \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory
    
The exported output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
with contents:
    - checkpoint
    - model.ckpt.data-00000-of-00001
    - model.ckpt.index
    - model.ckpt.meta
    - frozen_inference_graph.pb
    + saved_model (a directory)
"""
import tensorflow as tf

import exporter
import model

slim = tf.contrib.slim
flags = tf.app.flags

# Only the input types implemented in exporter.py's input_placeholder_fn_map
# are advertised; the documented default shape is the one main() uses.
flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can '
                    "be one of ['image_tensor', 'encoded_image_string_tensor']")
flags.DEFINE_string('input_shape', None, "If input_type is 'image_tensor', "
                    "this can be used to explicitly set the shape of this "
                    "input to a fixed size. The dimensions are to be provided "
                    "as a comma-separated list of integers. A value of -1 can "
                    "be used for unknown dimensions. If not specified, for an "
                    "'image_tensor', the default shape will be "
                    "'[None, 28, 28, 3]'.")
flags.DEFINE_string('trained_checkpoint_prefix', None,
                    'Path to trained checkpoint, typically of the form '
                    'path/to/model.ckpt')
flags.DEFINE_string('output_directory', None, 'Path to write outputs')
tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix')
tf.app.flags.mark_flag_as_required('output_directory')
FLAGS = flags.FLAGS


def main(_):
    """Builds the classifier and exports it with exporter.py.

    Reads --input_type, --input_shape, --trained_checkpoint_prefix and
    --output_directory from FLAGS.
    """
    cls_model = model.Model(is_training=False, num_classes=10)
    if FLAGS.input_shape:
        # Flag values are strings. Parse each dimension to int *before*
        # comparing with -1 (the string '-1' never equals the int -1), so
        # that -1 really maps to an unknown (None) dimension.
        input_shape = [
            None if dim == -1 else dim
            for dim in map(int, FLAGS.input_shape.split(','))
        ]
    elif FLAGS.input_type == 'image_tensor':
        # Default fixed shape for this classifier: batches of 28x28
        # 3-channel images.
        input_shape = [None, 28, 28, 3]
    else:
        # exporter.py only accepts an explicit input_shape for
        # 'image_tensor' inputs; leave it unset for other input types.
        input_shape = None
    exporter.export_inference_graph(FLAGS.input_type,
                                    cls_model,
                                    FLAGS.trained_checkpoint_prefix,
                                    FLAGS.output_directory,
                                    input_shape)
    

# tf.app.run() parses the command-line flags defined above, then calls main().
if __name__ == '__main__':
    tf.app.run()

在终端运行命令:

python3 export_inference_graph.py \
    --trained_checkpoint_prefix path/to/.ckpt-xxxx \
    --output_directory path/to/output/directory

很快会在 output_directory 指定的文件夹中生成一系列文件,其中的 frozen_inference_graph.pb 便是我们需要的最终用于推断的文件。至于如何读取 .pb 文件用于推断,则可以访问这个系列的文章 TensorFlow 模型保存与恢复 的第二部分。为了方便阅读,我们承接上一篇文章,使用如下代码来对训练的模型进行验证:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr  2 14:02:05 2018
@author: shirhe-lyh
"""

r"""Evaluate the trained CNN model.
Example Usage:
---------------
python3 evaluate.py \
    --frozen_graph_path path/to/frozen_inference_graph.pb
"""

import numpy as np
import tensorflow as tf

from captcha.image import ImageCaptcha

flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
# Fail fast with a clear flag error instead of an obscure file error when the
# path is omitted, as export_inference_graph.py does for its path flags.
flags.mark_flag_as_required('frozen_graph_path')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    """Renders `text` as a 28x28 captcha image.

    Args:
        text: The string to draw.

    Returns:
        A numpy uint8 array of the rendered image.
    """
    captcha_generator = ImageCaptcha(width=28, height=28, font_sizes=[24])
    return np.array(captcha_generator.generate_image(text), dtype=np.uint8)


def main(_):
    """Classifies 10 random captcha digits with the frozen model.

    For each sample, prints the predicted class next to the true label.
    """
    graph_def = tf.GraphDef()
    model_graph = tf.Graph()
    with model_graph.as_default():
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as pb_file:
            graph_def.ParseFromString(pb_file.read())
            tf.import_graph_def(graph_def, name='')

    with model_graph.as_default(), tf.Session(graph=model_graph) as sess:
        # Input/output node names defined by exporter.py.
        image_tensor = model_graph.get_tensor_by_name('image_tensor:0')
        classes_tensor = model_graph.get_tensor_by_name('classes:0')
        for _ in range(10):
            true_label = np.random.randint(0, 10)
            batch = np.expand_dims(generate_captcha(str(true_label)), axis=0)
            predicted = sess.run(classes_tensor,
                                 feed_dict={image_tensor: batch})
            print(predicted, ' vs ', true_label)
            
            
# tf.app.run() parses the command-line flags defined above, then calls main().
if __name__ == '__main__':
    tf.app.run()

简单运行:

python3 evaluate.py --frozen_graph_path path/to/frozen_inference_graph.pb

可以看到验证结果。

        本文(及前文)的所有代码都在 github: slim_cnn_test,欢迎访问并下载。

预告:下一篇文章将介绍 TensorFlow 如何使用预训练文件来精调分类模型。

你可能感兴趣的:(TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式)