TensorFlow学习笔记-tf.estimator

  • tfestimatorEstimator
    • 属性
    • 方法

tf.estimator.Estimator

Estimator class训练和测试TF模型。Estimator对象封装好通过model_fn指定的模型,给定输入和其它超参数,返回ops执行training, evaluation or prediction. 所有的输出(包含checkpoints, event files, etc.)被写入model_dir

属性

  • config
    传入 model_fn,如果 model_fn有参数named “config”
  • model_dir
  • model_fn
    The model_fn with following signature: def model_fn(features, labels, mode, config)
  • params

方法

  • __init__
__init__(
    model_fn,
    model_dir=None,
    config=None,
    params=None # 将要传入model_fn的超参数字典
)
  • evaluate

对训练模型评价

evaluate(
    input_fn, # 输入函数,返回元组features和labels
    steps=None,
    hooks=None, # List of SessionRunHook subclass instances
    checkpoint_path=None, # if none, 用model_dir中latest checkpoint
    name=None
)
  • export_savemodel
    导出inference graph作为一个SavedModel
export_savedmodel(
    export_dir_base, # 目录
    serving_input_receiver_fn, # 返回ServingInputReceiver的函数
    assets_extra=None,
    as_text=False,
    checkpoint_path=None
)
  • get_variable_names

    get_variable_names()
    返回模型中所有变量名字的列表

  • get_variable_value(name)
    根据变量name返回value

  • latest_checkpoint()
    model_dir中找到最近保存的checkpoint

  • predict
    根据给定的features产生预测

predict(
    input_fn,
    predict_keys=None,
    hooks=None,
    checkpoint_path=None
)
  • train

给定训练数据后训练model

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

你可能感兴趣的:(TensorFlow)