Estimator是与模型协同工作的高级工具。用于封装通过model_dn定义的模型,该模型给定了输入与一些其他的参数。返回完成training, evaluation, or predictions的操作。
__init__(
model_fn,
model_dir=None,
config=None,
params=None,
warm_start_from=None
)
model_fn: Model function
model_dir: Directory to save model parameters, graph and etc.
重要函数:
evaluate: 评估模型
evaluate(
input_fn,
steps=None,
hooks=None,
checkpoint_path=None,
name=None
)
predict: 预测模型
predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None,
yield_single_examples=True
)
train: 训练模型
train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)
**steps:**训练步骤
Evaluates the model given evaluation data input_fn.
返回一个input_fn,实现将numpy数组作为输入送入模型。
tf.estimator.inputs.numpy_input_fn(
x,
y=None,
batch_size=128,
num_epochs=1,
shuffle=None,
queue_capacity=1000,
num_threads=1
)
This returns a function outputting features and targets based on the dict of numpy arrays. The dict features has the same keys as the x. The dict targets has the same keys as the y if y is a dict.【注意:y可能是字典类型,此时与x一定有相同的关键词才能进行匹配】
实例
age = np.arange(4) * 1.0
height = np.arange(32, 36)
x = {'age': age, 'height': height}
y = np.arange(-32, -28)
with tf.Session() as session:
input_fn = numpy_io.numpy_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
对于模型有三种模式:训练;评估;预测。这三种模式通过三个关键字【类的成员】存储在ModelKeys类中。
TRAIN: training mode.
EVAL: evaluation mode.
PREDICT: inference mode.
tf.estimator.ModeKeys.PREDICT:表示处于预测模式
将来自model_fn的操作方式和操作对象传入Estimator。EstimatorSpec用以定义使用Estimator运行的model。
@staticmethod
__new__(
cls,
mode,
predictions=None,
loss=None,
train_op=None,
eval_metric_ops=None,
export_outputs=None,
training_chief_hooks=None,
training_hooks=None,
scaffold=None,
evaluation_hooks=None,
prediction_hooks=None
)
参数:
mode: A ModeKeys. Specifies if this is training, evaluation or prediction.
predictions: Predictions Tensor or dict of Tensor.
loss: Training loss Tensor. Must be either scalar, or with shape [1].
train_op: Op for the training step.
eval_metric_ops: Dict of metric results keyed by name. (1) instance of Metric class. (2) Results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.).