MonitoredTrainingSession
中使用 HookEstimator
中使用 Hookslim
中使用 HookSessionRunHook
用来扩展那些将 session
封装起来的高级 API 的 session.run
的行为。
SessionRunHook
对于追踪训练过程、报告进度、实现提前停止等非常有用。
SessionRunHook
以观察者模式运行。SessionRunHook
的设计中有几个非常重要的时间点:
session
使用前session.run()
调用之前session.run()
调用之后session
关闭前SessionRunHook
封装了一些可重用、可组合的计算,并且可以顺便完成 session.run()
的调用。利用 Hook,我们可以为 run()
调用添加任何的 ops或tensor/feeds;并且在 run()
调用完成后获得请求的输出。Hook 可以利用 hook.begin()
方法向图中添加 ops,但请注意:在 begin()
方法被调用后,计算图就 finalized 了。
TensorFlow 中已经内置了一些 Hook:
StopAtStepHook
:根据 global_step 来停止训练。CheckpointSaverHook
:保存 checkpoint。LoggingTensorHook
:以日志的形式输出一个或多个 tensor 的值。NanTensorHook
:如果给定的 Tensor
包含 Nan,就停止训练。SummarySaverHook
:保存 summaries 到一个 summary writer。上节,我们已经介绍了预制 Hook,使用其可以实现一些常见功能。如果这些 Hook 不能满足你的需求,那么自定义 Hook 是比较好的选择。
下面是自定义 Hook 的编写模板:
class ExampleHook(tf.train.SessionRunHook):
def begin(self):
# You can add ops to the graph here.
print('Starting the session.')
self.your_tensor = ...
def after_create_session(self, session, coord):
# When this is called, the graph is finalized and
# ops can no longer be added to the graph.
print('Session created.')
def before_run(self, run_context):
print('Before calling session.run().')
return SessionRunArgs(self.your_tensor)
def after_run(self, run_context, run_values): # run_values 为 sess.run 的结果
print('Done running one step. The value of my tensor: %s',
run_values.results)
if you-need-to-stop-loop:
run_context.request_stop()
def end(self, session):
print('Done with the session.')
上面是官方给的解释,下面是我设计的一个设置学习速率的Hook:
class _LearningRateSetterHook(tf.train.SessionRunHook):
"""Sets learning_rate based on global step."""
def begin(self):
self._global_step_tensor = tf.train.get_or_create_global_step()
self._lrn_rate_tensor = tf.get_default_graph().get_tensor_by_name('learning_rate:0') # 注意,这里根据name来索引tensor,所以请在定义学习速率的时候,为op添加名字
self._lrn_rate = 0.1 # 第一阶段的学习速率
def before_run(self, run_context):
return tf.train.SessionRunArgs(
self._global_step_tensor, # Asks for global step value.
feed_dict={self._lrn_rate_tensor: self._lrn_rate}) # Sets learning rate
def after_run(self, run_context, run_values):
train_step = run_values.results
if train_step < 10000:
pass
elif train_step < 20000:
self._lrn_rate = 0.01 # 第二阶段的学习速率
elif train_step < 30000:
self._lrn_rate = 0.001 # 第三阶段的学习速率
else:
self._lrn_rate = 0.0001 # 第四阶段的学习速率
在那些将 session
封装起来的高阶 API 中,我们可以使用 Hook 来扩展这些这些 API 的 session.run()
的行为。
首先,我们梳理一下将 session
封装起来的高阶 API 有哪些?这些 API 包括,但不限于:
tf.train.MonitoredTrainingSession
:tf.estimator.Estimator
:tf.contrib.slim
:MonitoredTrainingSession
中使用 Hookwith tf.train.MonitoredTrainingSession(hooks=your_hooks, ...) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(your_fetches)
Estimator
中使用 Hook在 tf.estimator.Estimator
的 train
、evaluate
、predict
方法中都可以使用 Hook。
下面是这些方法的 API:
# 训练
# 这里的 est 是一个 Estimator 实例
est.train(input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None)
# 评估
est.evaluate(input_fn,
steps=None,
hooks=None,
checkpoint_path=None,
name=None)
# 预测
est.predict(input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None,
yield_single_examples=True)
slim
中使用 HookSlim 是 TensorFlow 中一个非常优秀的高阶 API,其可以极大地简化模型的构建、训练、评估。
未完待续。。。。
通过自定义 Hook 的过程,我们了解到一个 Hook 包括 begin
、after_create_session
、before_run
、after_run
、end
五个方法。
下面的伪代码演示了 Hook 的运行过程:
# 伪代码
call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested:
call hooks.before_run()
try:
results = sess.run(merged_fetches, feed_dict=merged_feeds)
except (errors.OutOfRangeError, StopIteration):
break
call hooks.after_run()
call hooks.end()
sess.close()
注意:如果 sess.run()
引发 OutOfRangeError
、StopIteration
或其它异常,那么 hooks.after_run()
和 hooks.end()
将不会被执行。
预制的 Hook 比较多,这里我们以 tf.train.StopAtStepHook
为例,来看看内置 Hook 是怎么编写的。
# tf.train.StopAtStepHook 的定义
class StopAtStepHook(tf.train.SessionRunHook):
"""Hook that requests stop at a specified step."""
def __init__(self, num_steps=None, last_step=None):
"""Initializes a `StopAtStepHook`.
This hook requests stop after either a number of steps have been
executed or a last step has been reached. Only one of the two options can be
specified.
if `num_steps` is specified, it indicates the number of steps to execute
after `begin()` is called. If instead `last_step` is specified, it
indicates the last step we want to execute, as passed to the `after_run()`
call.
Args:
num_steps: Number of steps to execute.
last_step: Step after which to stop.
Raises:
ValueError: If one of the arguments is invalid.
"""
if num_steps is None and last_step is None:
raise ValueError("One of num_steps or last_step must be specified.")
if num_steps is not None and last_step is not None:
raise ValueError("Only one of num_steps or last_step can be specified.")
self._num_steps = num_steps
self._last_step = last_step
def begin(self):
self._global_step_tensor = tf.train.get_or_create_global_step()
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created to use StopAtStepHook.")
def after_create_session(self, session, coord):
if self._last_step is None:
global_step = session.run(self._global_step_tensor)
self._last_step = global_step + self._num_steps
def before_run(self, run_context): # pylint: disable=unused-argument
return tf.train.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values):
global_step = run_values.results + 1
if global_step >= self._last_step:
# Check latest global step to ensure that the targeted last step is
# reached. global_step read tensor is the value of global step
# before running the operation. We're not sure whether current session.run
# incremented the global_step or not. Here we're checking it.
step = run_context.session.run(self._global_step_tensor)
if step >= self._last_step:
run_context.request_stop()
SessionRunHook
源码:linktf.train.SessionRunHook()
类详解:link注意:欢迎大家转载,但需注明出处哦
\quad \quad    \; https://blog.csdn.net/u014061630/article/details/82998116