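TensorFlow 2.0 Official Guide Study Notes 16: 'Training checkpoints'

These notes follow the official TensorFlow guide. They are mainly a translation and reorganization of the 'Save a model: Training checkpoints' guide. Original article: Training checkpoints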

Training checkpoints

  • 1. Saving from tf.keras training APIs
  • 2. Writing checkpoints
    • 2.1 Manual checkpointing
  • 3. Loading mechanics
    • 3.1 Delayed restorations
    • 3.2 Manually inspecting checkpoints
    • 3.3 List and dictionary tracking
  • 4. Saving object-based checkpoints with Estimator
  • 5. Summary


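The phrase "saving a TensorFlow model" usually means one of two things:
- Checkpoints
- SavedModel
Checkpoints capture the exact values of all parameters (tf.Variable objects) used by a model. They contain no description of the computation the model defines, so they are typically useful only when the source code that will use the saved parameter values is available.
The SavedModel format, on the other hand, includes a serialized description of the computation defined by the model in addition to the parameter values (the checkpoint). Models in this format are independent of the source code that created them. They are therefore suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C#, etc. TensorFlow APIs).
These notes cover writing and reading checkpoints.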
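
To make the point that a checkpoint stores values only, here is a minimal sketch that is not part of the original guide. The variable name `weight` and the temporary directory are made up for illustration. The restoring side has to rebuild a matching object (the same edge name "weight") in code before the saved value can be used.

```python
import os
import tempfile

import tensorflow as tf

# Save the value of one variable. Only the number is written to disk;
# nothing about how the variable is used is stored.
weight = tf.Variable(3.0)
prefix = os.path.join(tempfile.mkdtemp(), "demo")  # hypothetical location
save_path = tf.train.Checkpoint(weight=weight).save(prefix)

# To use the checkpoint, the program rebuilds a matching structure itself
# (same keyword name "weight") and restores into it.
restored_weight = tf.Variable(0.0)
tf.train.Checkpoint(weight=restored_weight).restore(save_path)
restored_value = float(restored_weight.numpy())
print(restored_value)
```

If the second half of the script did not know that the value belongs under the name "weight", the checkpoint alone would not say what to do with it. That is the gap SavedModel fills.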

Setup

from __future__ import absolute_import, division, print_function, unicode_literals
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

1. Saving from tf.keras training APIs

See Study Notes 1: Keras overview.
tf.keras.Model.save_weights saves a TensorFlow checkpoint.

net.save_weights('easy_checkpoint')

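2. Writing checkpoints

The persistent state of a TensorFlow model is stored in tf.Variable objects. These can be constructed directly, but are often created through high-level APIs such as tf.keras.layers or tf.keras.Model.
The easiest way to manage variables is to attach them to Python objects and then reference those objects.
Subclasses of tf.train.Checkpoint, tf.keras.layers.Layer, and tf.keras.Model automatically track variables assigned to their attributes. The following example constructs a simple linear model, then writes checkpoints containing the values of all the model's variables.
A model checkpoint can be saved simply with Model.save_weights.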
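
As a rough illustration of attribute tracking, the sketch below uses tf.Module, which these notes do not cover but which tracks attributes in the same way. The class name `Scale` and the attribute `factor` are hypothetical. The variable is never registered anywhere by hand, yet it shows up in the checkpoint under a key built from the attribute names.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical module holding a single variable."""

    def __init__(self):
        super().__init__()
        # Tracked automatically because it is assigned to an attribute.
        self.factor = tf.Variable(2.0)


root = tf.train.Checkpoint(scale=Scale())
save_path = root.save(os.path.join(tempfile.mkdtemp(), "tracked"))

# List the keys written to the checkpoint; one of them starts with
# "scale/factor/", i.e. keyword name, then attribute name.
keys = [name for name, _shape in tf.train.list_variables(save_path)]
print(keys)
```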

2.1 Manual checkpointing

  • Setup
    To demonstrate the features of tf.train.Checkpoint, define a toy dataset and an optimization step:
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat(10).batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss
  • Create the checkpoint objects
    Creating a checkpoint manually requires a tf.train.Checkpoint object. The objects to be checkpointed are set as attributes on this object.
    A tf.train.CheckpointManager can help manage multiple checkpoints.
opt = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
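  • Train and checkpoint the model
    The following training loop works with the model and optimizer instances that were gathered into the tf.train.Checkpoint object above. It calls the training step on each batch of data and periodically writes checkpoints to disk.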
def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for example in toy_dataset():
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.96
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.37
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 16.81
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.39
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.83
  • Restore and continue training
    After the first run, a new model and manager can be passed in, and training picks up exactly where it left off:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 2.76
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.02
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.76
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 1.11
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.33

The tf.train.CheckpointManager object deletes old checkpoints. Above it is configured to keep only the three most recent ones.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
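
The pruning behaviour can be reproduced in isolation. This is a small sketch rather than the guide's own code: it checkpoints a single counter variable into a temporary directory, with max_to_keep lowered to 2 so the effect is visible after only four saves.

```python
import os
import tempfile

import tensorflow as tf

step = tf.Variable(0)
ckpt = tf.train.Checkpoint(step=step)
# Temporary directory used only for this illustration.
manager = tf.train.CheckpointManager(ckpt, tempfile.mkdtemp(), max_to_keep=2)

for _ in range(4):
    step.assign_add(1)
    manager.save()  # older checkpoints are deleted as new ones arrive

# Only the two most recent checkpoint prefixes remain, oldest first.
kept = [os.path.basename(path) for path in manager.checkpoints]
print(kept)
```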

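These paths (e.g. './tf_ckpts/ckpt-10') are not files on disk. They are prefixes for an index file and one or more data files that contain the variable values. The prefixes are grouped together in a single 'checkpoint' file ('./tf_ckpts/checkpoint'), in which the CheckpointManager saves its state.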

!ls ./tf_ckpts
checkpoint		     ckpt-8.data-00001-of-00002
ckpt-10.data-00000-of-00002  ckpt-8.index
ckpt-10.data-00001-of-00002  ckpt-9.data-00000-of-00002
ckpt-10.index		     ckpt-9.data-00001-of-00002
ckpt-8.data-00000-of-00002   ckpt-9.index
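
The same thing can be checked from Python instead of ls. A minimal sketch, assuming a throwaway checkpoint with one variable written to a temporary directory: the returned save path does not exist as a file, but files sharing that prefix do.

```python
import glob
import os
import tempfile

import tensorflow as tf

ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
save_path = ckpt.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# The returned path is only a prefix, not a file on disk.
prefix_is_a_file = os.path.exists(save_path)
# The index file and the data shard(s) share that prefix.
index_exists = os.path.exists(save_path + ".index")
data_shards = glob.glob(save_path + ".data-*")
print(save_path, prefix_is_a_file, index_exists, len(data_shards))
```

When copying or deleting a checkpoint by hand, all files with the prefix need to move together.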

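3. Loading mechanics

TensorFlow matches variables to checkpointed values by traversing a directed graph with named edges, starting from the object being loaded. Edge names typically come from attribute names in objects, for example the 'l1' in self.l1 = tf.keras.layers.Dense(5). tf.train.Checkpoint uses its keyword argument names, as in the 'step' in tf.train.Checkpoint(step=...).
The dependency graph from the example above looks like this: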
[Figure 1: dependency graph for the example training loop]
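The optimizer is in red, regular variables are in blue, and optimizer slot variables are in orange. The other nodes, for example the one representing the tf.train.Checkpoint, are black.

Slot variables are part of the optimizer's state, but are created for a specific variable. For example, the 'm' edges above correspond to momentum, which the Adam optimizer tracks for each variable. Slot variables are saved in a checkpoint only if both the variable and the optimizer would be saved, hence the dashed edges.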
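
The idea of matching by named edges can be pictured with a toy model in plain Python. This is only an illustration of the concept, not TensorFlow's implementation: nested dicts stand in for objects, dict keys stand in for edge names, and all the values are made up.

```python
def flatten(node, prefix=""):
    """Walk a nested dict of named edges and yield (path, leaf) pairs."""
    for edge_name, child in node.items():
        path = f"{prefix}/{edge_name}" if prefix else edge_name
        if isinstance(child, dict):
            yield from flatten(child, path)
        else:
            yield path, child


# Hypothetical "saved" object graph; edge names mirror the example above.
saved = {"step": 50, "net": {"l1": {"kernel": [[4.6]], "bias": [2.5]}}}
checkpoint_values = dict(flatten(saved))

# A smaller graph built later. Only paths present in both are matched.
wanted = {"net": {"l1": {"bias": None}}}
matched = {
    path: checkpoint_values[path]
    for path, _ in flatten(wanted)
    if path in checkpoint_values
}
print(matched)  # {'net/l1/bias': [2.5]}
```

The point of the toy is that a value is found through the chain of names leading to it, so a program that rebuilds only part of that chain can still load the part it needs.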

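Calling restore() on a tf.train.Checkpoint object queues the requested restorations, restoring variable values as soon as there is a matching path from the Checkpoint object. For example, we can load just the bias from the model defined above by reconstructing one path to it through the network and the layer.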

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
[0. 0. 0. 0. 0.]
[2.5321825 2.0783062 2.4567614 4.824098  5.1221457]

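The dependency graph for these new objects is a much smaller subgraph of the larger checkpoint written above. It includes only the bias and a save counter that tf.train.Checkpoint uses to number checkpoints.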
[Figure 2: dependency subgraph containing only the bias and the save counter]
restore() returns a status object, which has optional assertions. All of the objects created in the new Checkpoint have been restored, so status.assert_existing_objects_matched() passes.

status.assert_existing_objects_matched()

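There are many objects in the checkpoint that have not been matched, including the layer's kernel and the optimizer's variables. status.assert_consumed() passes only if the checkpoint and the program match exactly, and it would throw an exception here.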
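
The difference between the two assertions can be seen on a tiny example. This is a sketch with made-up variable names `a` and `b`, independent of the model above: the checkpoint holds two variables, but the restoring program rebuilds only one of them.

```python
import os
import tempfile

import tensorflow as tf

# Write a checkpoint containing two variables.
saved = tf.train.Checkpoint(a=tf.Variable(1.0), b=tf.Variable(2.0))
save_path = saved.save(os.path.join(tempfile.mkdtemp(), "partial"))

# Restore into a structure that knows only about "a".
partial = tf.train.Checkpoint(a=tf.Variable(0.0))
status = partial.restore(save_path)

# Passes: every object in the program found a value in the checkpoint.
status.assert_existing_objects_matched()

# Fails: "b" is in the checkpoint but nothing in the program consumed it.
try:
    status.assert_consumed()
    fully_consumed = True
except AssertionError:
    fully_consumed = False
print(fully_consumed)
```

Roughly speaking, assert_existing_objects_matched is the looser check suited to partial loading, while assert_consumed is the stricter one for cases where nothing in the checkpoint should be left unused.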

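3.1 Delayed restorations

Layer objects in TensorFlow may delay the creation of variables until their first call, when input shapes are available. For example, the shape of a Dense layer's kernel depends on both the layer's input and output shapes, so the output shape required as a constructor argument is not enough information to create the variable on its own. Since calling a layer also reads the variable's value, a restore must happen between the variable's creation and its first use.
To support this idiom, tf.train.Checkpoint queues restores that do not yet have a matching variable.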

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6374264 4.8115244 4.9366684 4.769622  4.8615403]]
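
The same queuing behaviour can be tried without any layers. A minimal sketch, assuming a throwaway checkpoint with a single variable under the made-up name `w`: restore() is called on an empty Checkpoint first, and the value arrives only when a variable is attached under the matching name.

```python
import os
import tempfile

import tensorflow as tf

saved = tf.train.Checkpoint(w=tf.Variable([1.0, 2.0]))
save_path = saved.save(os.path.join(tempfile.mkdtemp(), "delayed"))

# Nothing named "w" exists yet, so the restore for it is queued.
late_root = tf.train.Checkpoint()
late_root.restore(save_path)

late_w = tf.Variable([0.0, 0.0])
before = late_w.numpy().tolist()  # still zeros
late_root.w = late_w              # attaching triggers the queued restore
after = late_w.numpy().tolist()
print(before, after)
```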

3.2 Manually inspecting checkpoints

tf.train.list_variables lists the checkpoint keys and the shapes of the variables in a checkpoint. Checkpoint keys are paths in the graph displayed above.

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]
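
Beyond listing keys and shapes, a stored value can be read directly by key. The sketch below uses tf.train.load_checkpoint, which the original notes do not mention. It writes its own small checkpoint to a temporary directory, with the value 7 chosen arbitrarily, so that it does not depend on the ./tf_ckpts directory above.

```python
import os
import tempfile

import tensorflow as tf

ckpt = tf.train.Checkpoint(step=tf.Variable(7))
save_path = ckpt.save(os.path.join(tempfile.mkdtemp(), "inspect"))

# A reader gives access to raw values by checkpoint key,
# without rebuilding any Python objects.
reader = tf.train.load_checkpoint(save_path)
shape_map = reader.get_variable_to_shape_map()
step_value = int(reader.get_tensor("step/.ATTRIBUTES/VARIABLE_VALUE"))
print(sorted(shape_map), step_value)
```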

3.3 List and dictionary tracking

As with direct attribute assignments like self.l1 = tf.keras.layers.Dense(5), assigning lists and dictionaries to attributes tracks their contents.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

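You may notice wrapper objects for lists and dictionaries. These wrappers are checkpointable versions of the underlying data structures. Just like attribute-based loading, these wrappers restore a variable's value as soon as it is added to the container.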

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

The same tracking is automatically applied to subclasses of tf.keras.Model, and may be used, for example, to track lists of layers.

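4. Saving object-based checkpoints with Estimator

For details, see: estimator
By default, Estimators save checkpoints with variable names rather than the object graph described in the previous sections. tf.train.Checkpoint will accept name-based checkpoints, but variable names may change when moving parts of a model outside of the Estimator's model_fn. Saving object-based checkpoints makes it easier to train a model inside an Estimator and then use it outside of one.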

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tensorflow-2.0.0/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tensorflow-2.0.0/python3.6/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:loss = 4.6649957, step = 0
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Loss for final step: 39.683983.

tf.train.Checkpoint can then load the Estimator's checkpoints from its model_dir.

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

5. Summary

TensorFlow objects provide an easy automatic mechanism for saving and restoring the values of the variables they use.
