"Estimator Engineering Implementation" Series, Part 2: Fine-tuning with a Scaffold When Using Estimator

Code environment: TensorFlow r1.12

1. Review of the tf.estimator.EstimatorSpec API

tf.estimator.EstimatorSpec
The interface is identical in r1.12, r1.13 and r2.0:

@staticmethod
__new__(
    cls,
    mode,
    predictions=None,
    loss=None,
    train_op=None,
    eval_metric_ops=None,
    export_outputs=None,
    training_chief_hooks=None,
    training_hooks=None,
    scaffold=None,
    evaluation_hooks=None,
    prediction_hooks=None
)

1.1 Official API description

  • Args:
    • mode: A ModeKeys. Specifies if this is training, evaluation or prediction.
    • predictions: Predictions Tensor or dict of Tensor.
    • loss: Training loss Tensor. Must be either scalar, or with shape [1].
    • train_op: Op for the training step.
    • eval_metric_ops: Dict of metric results keyed by name. The values of the dict can be one of the following: (1) instance of Metric class. (2) Results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically it is a pure computation result based on variables). For example, it should not trigger the update_op or require any input fetching.
    • export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:
      • name: An arbitrary name for this output.
      • output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. If no entry is provided, a default PredictOutput mapping to predictions will be created.
    • training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.
    • training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.
    • scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.
    • evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.
    • prediction_hooks: Iterable of tf.train.SessionRunHook objects to run during predictions.
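The Args list does not say which fields each mode actually needs. According to the EstimatorSpec docstring, TRAIN requires loss and train_op, EVAL requires loss, and PREDICT requires predictions; everything else, including scaffold, is optional. The pure-Python sketch below encodes that rule so it can be checked without TensorFlow. `REQUIRED_FIELDS` and `missing_fields` are hypothetical names, and the mode strings are assumed to match the values of tf.estimator.ModeKeys ('train', 'eval', 'infer').

```python
# Mode strings mirror tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"

# Hypothetical table: which EstimatorSpec fields must be set in each mode.
REQUIRED_FIELDS = {
    TRAIN: ("loss", "train_op"),
    EVAL: ("loss",),
    PREDICT: ("predictions",),
}


def missing_fields(mode, **spec_fields):
    """Return the required fields that are absent (None) for the given mode."""
    if mode not in REQUIRED_FIELDS:
        raise ValueError("unknown mode: %r" % (mode,))
    return [name for name in REQUIRED_FIELDS[mode]
            if spec_fields.get(name) is None]


# scaffold is optional in every mode, so it never shows up as missing.
print(missing_fields(TRAIN, loss="loss_tensor", scaffold="my_scaffold"))
# → ['train_op']
```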

Note four of these arguments in particular: training_hooks, scaffold, evaluation_hooks and prediction_hooks.

1.2 Official description of the Scaffold class

Structure to create or gather pieces commonly needed to train a model.

When you build a model for training you usually need ops to initialize
variables, a Saver to checkpoint them, an op to collect summaries for
the visualizer, and so on.
Various libraries built on top of the core TensorFlow library take care of
creating some or all of these pieces and storing them in well known
collections in the graph. The Scaffold class helps pick these pieces from
the graph collections, creating and adding them to the collections if needed.
If you call the scaffold constructor without any arguments, it will pick
pieces from the collections, creating default ones if needed when
scaffold.finalize() is called. You can pass arguments to the constructor to
provide your own pieces. Pieces that you pass to the constructor are not
added to the graph collections.

In other words, a scaffold lets us control the saver, variable initialization, summaries and so on, much like a hook does.
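To make the "pick from the collection, otherwise create a default" rule concrete, here is a toy model in plain Python. `ToyScaffold` is hypothetical and far simpler than tf.train.Scaffold, which also resolves init_op, ready_op, local_init_op, summary_op and more; this sketch only tracks the saver, and a dict stands in for the graph collections.

```python
class ToyScaffold(object):
    """Toy stand-in for tf.train.Scaffold; it only models the saver piece."""

    def __init__(self, saver=None):
        self._user_saver = saver  # piece passed to the constructor, if any
        self.saver = None

    def finalize(self, collections):
        """Resolve the saver; `collections` is a dict standing in for graph collections."""
        if self._user_saver is not None:
            # A piece passed to the constructor is used as-is and is NOT
            # added to the collections.
            self.saver = self._user_saver
        elif collections.get("savers"):
            # Otherwise reuse whatever a library already put in the collection.
            self.saver = collections["savers"][0]
        else:
            # Nothing available: create a default and register it.
            self.saver = "default_saver"
            collections.setdefault("savers", []).append(self.saver)
        return self


graph_collections = {}
ToyScaffold().finalize(graph_collections)
print(graph_collections)  # → {'savers': ['default_saver']}
ToyScaffold(saver="my_saver").finalize(graph_collections)
print(graph_collections)  # → {'savers': ['default_saver']}
```

The second call shows the last sentence of the quoted description: the user-supplied saver is used, but the collection is left untouched.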

2. Example: loading pretrained parameters through a scaffold

All we need to do is implement an init_fn and pass it to the scaffold.
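Two details of init_fn are easy to miss. tf.train.Scaffold calls it as init_fn(scaffold, session), and in TF 1.x the session manager calls it (right after init_op) only when the session is *not* restored from a checkpoint in model_dir. The toy simulation below mirrors that order. `prepare_session` and `fine_tune_init_fn` here are hypothetical simplifications, not TensorFlow's SessionManager.prepare_session, and the "session" is just a list that records events.

```python
def prepare_session(session, checkpoint_in_model_dir, init_op, scaffold, init_fn):
    """Toy model of the restore-or-initialize decision.

    session: a list used as an event log in place of a real tf.Session.
    """
    if checkpoint_in_model_dir:
        session.append("restore model_dir checkpoint")
        return session  # init_op and init_fn are both skipped
    session.append("run " + init_op)
    if init_fn is not None:
        init_fn(scaffold, session)  # same calling convention as Scaffold's init_fn
    return session


def fine_tune_init_fn(scaffold, session):
    # Stand-in for the callback returned by get_init_fn_for_scaffold,
    # which would call saver.restore(session, checkpoint_path).
    session.append("restore pretrained weights")


print(prepare_session([], False, "init_op", None, fine_tune_init_fn))
# → ['run init_op', 'restore pretrained weights']
print(prepare_session([], True, "init_op", None, fine_tune_init_fn))
# → ['restore model_dir checkpoint']
```

Because init_fn runs after init_op, the pretrained values overwrite the random initial values; and once model_dir contains a checkpoint, training resumes from it and the pretrained weights are not reloaded.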

import tensorflow as tf  # written against TF r1.12 (tf.train / tf.logging APIs)


def get_init_fn_for_scaffold(checkpoint_path, model_dir, checkpoint_exclude_scopes,
                             ignore_missing_vars, use_v1=False):
    """Build an init_fn for tf.train.Scaffold that restores pretrained weights.

    Must be called inside model_fn, after the model graph has been built,
    because it collects the trainable variables of the current graph.
    Returns None when there is nothing to fine-tune from.
    """
    flags_checkpoint_path = checkpoint_path
    # If model_dir already holds a checkpoint, training resumes from it and
    # the pretrained checkpoint is ignored.
    if tf.train.latest_checkpoint(model_dir):
        tf.logging.info('Ignoring --checkpoint_path because a checkpoint already exists in %s' % model_dir)
        return None
    if flags_checkpoint_path is None:
        return None

    # Comma-separated name prefixes that should NOT be restored (e.g. a new head).
    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [scope.strip() for scope in checkpoint_exclude_scopes.split(',')]

    # Only trainable variables are considered, so non-trainable ones
    # (e.g. batch-norm moving averages) are not restored here.
    variables_to_restore = []
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)

    # Accept either a checkpoint file prefix or a directory of checkpoints.
    if tf.gfile.IsDirectory(flags_checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(flags_checkpoint_path)
        if checkpoint_path is None:
            raise ValueError('No checkpoint found in %s' % flags_checkpoint_path)
    else:
        checkpoint_path = flags_checkpoint_path

    tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, ignore_missing_vars))

    if not variables_to_restore:
        raise ValueError('variables_to_restore cannot be empty')
    if ignore_missing_vars:
        # Keep only the variables that actually exist in the checkpoint.
        reader = tf.train.NewCheckpointReader(checkpoint_path)
        if isinstance(variables_to_restore, dict):
            var_dict = variables_to_restore
        else:
            var_dict = {var.op.name: var for var in variables_to_restore}
        available_vars = {}
        for var in var_dict:
            if reader.has_tensor(var):
                available_vars[var] = var_dict[var]
            else:
                tf.logging.warning('Variable %s missing in checkpoint %s', var, checkpoint_path)
        variables_to_restore = available_vars
    if variables_to_restore:
        # A dedicated Saver that covers only the selected variables.
        saver = tf.train.Saver(variables_to_restore, reshape=False,
                               write_version=tf.train.SaverDef.V1 if use_v1 else tf.train.SaverDef.V2)
        saver.build()

        # tf.train.Scaffold calls init_fn as init_fn(scaffold, session),
        # after init_op has run.
        def callback(scaffold, session):
            saver.restore(session, checkpoint_path)
        return callback
    else:
        tf.logging.warning('No Variables to restore')
        return None
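The variable-selection part of get_init_fn_for_scaffold can be checked without TensorFlow. The sketch below reproduces its two filters on plain variable names: scope exclusion by string prefix, then (as with ignore_missing_vars) dropping names the checkpoint does not contain. `select_variables` is a hypothetical helper, and the variable and checkpoint names are made up for illustration. One consequence of the plain startswith match: excluding `logits` also excludes `logits_aux`; write `logits/` if that is not intended.

```python
def select_variables(var_names, checkpoint_exclude_scopes=None,
                     checkpoint_names=None):
    """Mirror the filtering in get_init_fn_for_scaffold on plain strings.

    var_names: model variable names (var.op.name in the original code).
    checkpoint_exclude_scopes: comma-separated prefixes to skip, or None.
    checkpoint_names: names stored in the checkpoint; if given, behaves like
        ignore_missing_vars=True and drops names the checkpoint lacks.
    """
    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [s.strip() for s in checkpoint_exclude_scopes.split(",")]

    # Step 1: plain prefix match, like var.op.name.startswith(exclusion).
    kept = [name for name in var_names
            if not any(name.startswith(ex) for ex in exclusions)]

    # Step 2: optional ignore_missing_vars filtering.
    if checkpoint_names is not None:
        available = set(checkpoint_names)
        kept = [name for name in kept if name in available]
    return kept


model_vars = ["resnet/conv1/weights", "resnet/new_head/weights",
              "logits/weights", "logits_aux/weights"]
ckpt_vars = ["resnet/conv1/weights", "logits/weights", "logits_aux/weights"]

print(select_variables(model_vars, "logits"))
# → ['resnet/conv1/weights', 'resnet/new_head/weights']
print(select_variables(model_vars, "logits/", ckpt_vars))
# → ['resnet/conv1/weights', 'logits_aux/weights']
```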

To use it, call the function above inside model_fn to obtain the init_fn, build a tf.train.Scaffold with it, and pass the scaffold to EstimatorSpec.

# Inside model_fn, after predictions, loss, train_op and metrics have been built:

# set up the fine-tuning scaffold
scaffold = tf.train.Scaffold(init_fn=get_init_fn_for_scaffold(
    params['checkpoint_path'], params['model_dir'],
    params['checkpoint_exclude_scopes'], params['ignore_missing_vars']))

# create the EstimatorSpec
return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=loss,
    train_op=train_op,
    eval_metric_ops=metrics,
    scaffold=scaffold)

References:

  1. TensorFlow official documentation:
    https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/estimator/EstimatorSpec
    https://www.tensorflow.org/api_docs/python/tf/train/Scaffold
    https://www.tensorflow.org/api_docs/python/tf/train/SessionRunHook
  2. Related article:
    https://www.jianshu.com/p/1df991a4b815
  3. tf.fashionAI GitHub project:
    https://github.com/HiKapok/tf.fashionAI
