TensorFlow2.0 Guide官方教程 学习笔记9- Estimators

本笔记参照TensorFlow官方教程,主要是对‘Estimators’教程内容翻译和内容结构编排,原文链接:Estimators
目录
一、优势
二、评估器的能力
三、预制评估器(Pre-made Estimators)
3.1预制评估器程序结构
3.2预制评估器的好处
四、定制评估器
五、推荐工作流
六、从Keras模型中创建一个评估器


本篇笔记介绍高级TensorFlow API——tf.estimator。评估器将以下动作封装在一起:

  • 训练(training)
  • 评价(evaluating)
  • 预测(prediction)
  • 出口服务(export for serving)
    我们既可以用keras里预制好的评估器,也可以自己定制评估器。所有预制或者定制的评估器都是基于类tf.estimator.Estimator的类。

一、优势

与tf.keras.Model相同,一个评估器是一个模型级别的抽象。tf.estimator提供一些tf.keras尚在开发的能力,比如:

  • 基于参数服务器的训练
  • 全TFX整合

二、评估器能力

评估器提供以下好处:

  • 您可以在本地主机或分布式多服务器环境上运行基于评估器的模型,而无需更改模型。此外,您可以在cpu、gpu或TPUs上运行基于估计器的模型,而无需重新编写模型。
  • 评估器提供了一个安全的分布式训练循环,可以控制如何以及何时进行
    -加载数据
    -处理异常
    -创建检查点文件和从失败中恢复
    -保存TensorBoard总结
    在用评估器编写应用时,我们必须将数据输入管道与模型分开。这种分离简化了不同数据集的实验。

三、预制评估器

预构建的估计器使您能够在比基本TensorFlow api更高的概念级别上工作。您不再需要担心创建计算图形或会话,因为评估程序将为您处理所有的“管道”。此外,预构建的评估器允许您通过仅进行最小的代码更改来试验不同的模型体系结构。例如:tf.estimator.DNNClassifier,是一个预制的评估器类,用来训练基于密集前馈神经网络的分类模型。
3.1预制评估器程序的结构
一个基于预制评估器的TensorFlow程序通常包括以下四个步骤:
(1)编写一个或多个数据集导入函数:例如我们可以创建一个函数来导入训练集,另一个函数来导入测试集。每个数据集导入函数必须返回两个对象:

  • 一个字典,键是特征名称,值是包含相应特征数据的张量(或稀疏张量)
  • 一个张量,包含一个或多个标签
    例如,下面的代码演示输入函数的基本框架:
def input_fn(dataset):
    ...  # manipulate dataset, extracting the feature dict and the label
    return feature_dict, label

(2)定义特征列:每一个tf.feature_column标识特性名称、类型和任何输入预处理。例如,下面的代码段创建了三个包含整数或浮点数据的特性列。前两个特性列只是标识特性的名称和类型。第三个特性列还指定了一个lambda程序将调用来缩放原始数据:

# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column(
  'median_education',
  normalizer_fn=lambda x: x - global_education_mean)

(3)实例化相关预制评估器
例如:下面是一个名为LinearClassifier的预制评估器的示例示例:

# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
  feature_columns=[population, crime_rate, median_education])

(4)调用一个训练,评估,或推断方法
例如:所有评估器提供了‘训练’方法,可以用来训练模型:

# `input_fn` is the function created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)

3.2预制评估器的好处
预制评估器把最佳实践编码,提供以下好处:
(1)确定计算图的不同部分应该在何处运行的最佳实践,在单机或集群上实现策略。
(2)事件(摘要)写作和普遍有用摘要的最佳实践。

	注意:如果不使用预制评估器,我们必须自己实现前面说的特性。

四、定制评估器

每一个评估器的核心——无论是预先制作的还是定制的——都是它的模型函数,它是一种为训练、评估和预测构建图形的方法。当您使用预做的评估器时,其他人已经实现了模型函数。依赖自定义评估器时,必须自己编写模型函数。

五、推荐工作流

  1. 假设存在一个适当的预制评估器,使用它来构建我们的第一个模型,并使用它的结果来建立一个基线
  2. 构建和测试整个管道,包括这个预制评估器数据的完整性和可靠性
  3. 如果有合适地可选预制评估器可用,运行试验来决定哪个预制评估器产生最后结果
  4. 可能的话,通过构建我们自己的自定义评估器来进一步提升我们的模型。
from __future__ import absolute_import, division, print_function, unicode_literals
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

六、从Keras模型中创建一个评估器

我们可以用tf.keras.estimator.model_to_estimator将keras模型转化成评估器。这样我们的keras模型可以拥有评估器的能力,比如分布训练。
下面我们实例化一个Keras MobileNet V2模型,并使用优化器、代价和要培训的指标来编译模型:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation='softmax')
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'])
Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 2s 0us/step

从编译的Keras模型中创建一个评估器。Keras模型的初始化状态将在创建的评估器里保留:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4qcdq5pd
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4qcdq5pd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

对待派生的评估器就像对待任何其他评估器一样。

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

调用评估器训练函数来训练:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)
WARNING:tensorflow:From /tensorflow-2.0.0/python3.6/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading and preparing dataset cats_vs_dogs (786.68 MiB) to /root/tensorflow_datasets/cats_vs_dogs/2.0.1...
/usr/local/lib/python3.6/dist-packages/urllib3/connectionpool.py:847: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning)
WARNING:absl:1738 images were corrupted and were skipped
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
WARNING:absl:Warning: Setting shuffle_files=True because split=TRAIN and shuffle_files=None. This behavior will be deprecated on 2019-08-06, at which point shuffle_files=False will be the default for all splits.
Dataset cats_vs_dogs downloaded and prepared to /root/tensorflow_datasets/cats_vs_dogs/2.0.1. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpoymsrf37/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpoymsrf37/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpoymsrf37/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmpoymsrf37/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpoymsrf37/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpoymsrf37/model.ckpt.
INFO:tensorflow:loss = 9.104111, step = 0
INFO:tensorflow:loss = 9.104111, step = 0
INFO:tensorflow:global_step/sec: 3.66157
INFO:tensorflow:global_step/sec: 3.66157
INFO:tensorflow:loss = 8.145783, step = 100 (27.319 sec)
INFO:tensorflow:loss = 8.145783, step = 100 (27.319 sec)
INFO:tensorflow:global_step/sec: 3.8213
INFO:tensorflow:global_step/sec: 3.8213
INFO:tensorflow:loss = 6.2291284, step = 200 (26.168 sec)
INFO:tensorflow:loss = 6.2291284, step = 200 (26.168 sec)
INFO:tensorflow:global_step/sec: 3.81693
INFO:tensorflow:global_step/sec: 3.81693
INFO:tensorflow:loss = 6.708292, step = 300 (26.195 sec)
INFO:tensorflow:loss = 6.708292, step = 300 (26.195 sec)
INFO:tensorflow:global_step/sec: 3.78591
INFO:tensorflow:global_step/sec: 3.78591
INFO:tensorflow:loss = 6.708292, step = 400 (26.419 sec)
INFO:tensorflow:loss = 6.708292, step = 400 (26.419 sec)
INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmpoymsrf37/model.ckpt.
INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmpoymsrf37/model.ckpt.
INFO:tensorflow:Loss for final step: 7.187456.
INFO:tensorflow:Loss for final step: 7.187456.

同样地,调用评估器评价函数来进行评价:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
WARNING:absl:Warning: Setting shuffle_files=True because split=TRAIN and shuffle_files=None. This behavior will be deprecated on 2019-08-06, at which point shuffle_files=False will be the default for all splits.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-10-19T05:41:13Z
INFO:tensorflow:Starting evaluation at 2019-10-19T05:41:13Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpoymsrf37/model.ckpt-500
INFO:tensorflow:Restoring parameters from /tmp/tmpoymsrf37/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Finished evaluation at 2019-10-19-05:41:18
INFO:tensorflow:Finished evaluation at 2019-10-19-05:41:18
INFO:tensorflow:Saving dict for global step 500: accuracy = 0.490625, global_step = 500, loss = 7.8103685
INFO:tensorflow:Saving dict for global step 500: accuracy = 0.490625, global_step = 500, loss = 7.8103685
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmpoymsrf37/model.ckpt-500
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmpoymsrf37/model.ckpt-500
{'accuracy': 0.490625, 'global_step': 500, 'loss': 7.8103685}

更多信息,请参考tf.keras.estimator.model_to_estimator文档

你可能感兴趣的:(TensorFlow学习笔记)