tf.estimator API技术手册(5)——BestExporter(最佳模型输出器)

tf.estimator API技术手册(5)——BestExporter(最佳模型输出器)

  • (一)简 介
  • (二)初始化
  • (三)方 法(Methods)
    • (1)export

(一)简 介

BestExporter继承自Exporter类,定义在tensorflow/python/estimator/exporter.py中,它用于导出最优模型的计算图和checkpoints文件,每当新的模型的表现优于旧模型时,它就会启动将最新的模型导出。创建一个BestExporter用于训练和评估:

def make_train_and_eval_fn():
  # 创建特征列
  categorial_feature_a = (
      tf.feature_column.categorical_column_with_hash_bucket(...))
  categorial_feature_a_emb = embedding_column(
      categorical_column=categorial_feature_a, ...)
  ...  # 其他特征列
  estimator = tf.estimator.DNNClassifier(
      config=tf.estimator.RunConfig(
          model_dir='/my_model', save_summary_steps=100),
      feature_columns=[categorial_feature_a_emb, ...],
      hidden_units=[1024, 512, 256])

  serving_feature_spec = tf.feature_column.make_parse_example_spec(
      categorial_feature_a_emb)
  serving_input_receiver_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(
      serving_feature_spec))

  exporter = tf.estimator.BestExporter(
      name="best_exporter",
      serving_input_receiver_fn=serving_input_receiver_fn,
      exports_to_keep=5)

  train_spec = tf.estimator.TrainSpec(...)

  eval_spec = [tf.estimator.EvalSpec(
    input_fn=eval_input_fn,
    steps=100,
    exporters=exporter,
    start_delay_secs=0,
    throttle_secs=5)]

  return tf.estimator.DistributedTrainingSpec(estimator, train_spec,
                                              eval_spec)

(二)初始化

初始化方法如下:

__init__(
    name='best_exporter',
    serving_input_receiver_fn=None,
    event_file_pattern='eval/*.tfevents.*',
    compare_fn=_loss_smaller,
    assets_extra=None,
    as_text=False,
    exports_to_keep=5
)

参数如下:

  • name:
    一个Exporter的名称,将被用在导出路径中

  • serving_input_receiver_fn:
    一个无参数函数,并将返回一个ServingInputReceiver。

  • event_file_pattern:
    event file name pattern relative to model_dir. If None, however, the exporter would not be preemption-safe. To be preemption-safe, event_file_pattern should be specified.

  • compare_fn:
    一个比较函数,如果当前评估结果更优,就返回True。参照以下标准:

    • 参数:
      • best_eval_result:
        最优模型的评估结果
      • current_eval_result:
        当前候选模型的评估结果
  • assets_extra:
    一个可选的字典,用来指定如何填充导出的保存模型中的assets.extra目录,例如: {‘my_asset_file.txt’: ‘/path/to/my_asset_file.txt’}.

  • as_text:
    是否在文本格式中写入SaveModel原型,默认为False。

  • exports_to_keep:
    旧的导出结果的保存数量,旧的导出结果将会被“垃圾回收”,默认值为5,设置为None时则不进行回收操作。

(三)方 法(Methods)

(1)export

export(
    estimator,
    export_path,
    checkpoint_path,
    eval_result,
    is_the_final_export
)

参数如下:

  • estimator:
    需要导出的Estimator。

  • export_path:
    导出的路径
    A string containing a directory where to write the export.

  • checkpoint_path:
    导出checkpoints的路径

  • eval_result:
    Estimator.evaluate方法记录在checkpoint中的输出结果

  • is_the_final_export:
    当训练结束输出时,此布尔值为True。在训练中,此值为False。如果TrainSpec.max_steps为None时,把Exporter传递给tf.estimator.train_and_evaluate is_the_final_export,该值总是为False。

你可能感兴趣的:(tf.estimator,API技术手册)