INFO:tensorflow:Create CheckpointSaverHook.
2018-01-15 16:24:33.513942: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
2018-01-15 16:24:34.390763: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Found device 0 with properties:
name: GeForce GTX 1080 Ti major: 6 minor: 1 memoryClockRate(GHz): 1.582
pciBusID: 0000:89:00.0
totalMemory: 10.91GiB freeMemory: 10.75GiB
2018-01-15 16:24:34.390813: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:89:00.0, compute capability: 6.1)
2018-01-15 16:25:58.010092: I tensorflow/core/kernels/shuffle_dataset_op.cc:110] Filling up shuffle buffer (this may take a while): 499 of 1000
2018-01-15 16:26:07.689469: I tensorflow/core/kernels/shuffle_dataset_op.cc:121] Shuffle buffer filled.
INFO:tensorflow:Saving checkpoints for 1 into /train/mymodels/model.ckpt.
INFO:tensorflow:loss = 22.2663, step = 1
......
DEBUG:tensorflow:Skipping evaluation due to same checkpoint /train/mymodels/model.ckpt-1 for step 100 as for step 50.
The execution flow is as follows:
experiment.train_and_evaluate()

# The evaluation part is implemented via a hook (a ValidationMonitor):
if self._min_eval_frequency:
  self._train_monitors += [
      monitors.ValidationMonitor(
          input_fn=self._eval_input_fn,
          eval_steps=self._eval_steps,
          metrics=self._eval_metrics,
          every_n_steps=self._min_eval_frequency,
          name=eval_dir_suffix,
          hooks=self._eval_hooks)
  ]
# The training part eventually calls estimator._train_model(); note that the
# very first training step already saves a checkpoint!
self.train(delay_secs=0)

# Call chain:
experiment.train(delay_secs=0) -> experiment._estimator.train -> estimator._train_model()
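For context, this code path is reached when an Experiment is built around a contrib Estimator with a non-zero min_eval_frequency. A minimal, hypothetical sketch (my_model_fn and the input fns are placeholders, not taken from the run logged above):

import tensorflow as tf
from tensorflow.contrib.learn import Experiment, Estimator

experiment = Experiment(
    estimator=Estimator(model_fn=my_model_fn),  # placeholder model_fn
    train_input_fn=my_train_input_fn,           # placeholder input fns
    eval_input_fn=my_eval_input_fn,
    min_eval_frequency=50)  # non-zero, so the ValidationMonitor above is added
experiment.train_and_evaluate()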
# Code from estimator._train_model():
# ...
# 1. Add loss monitoring (via hooks).
# Check if the user created a loss summary, and add one if they didn't.
# We assume here that the summary is called 'loss'. If it is not, we will
# make another one with the name 'loss' to ensure it shows up in the right
# graph in TensorBoard.
if not any([x.op.name == 'loss'
            for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
  summary.scalar('loss', estimator_spec.loss)
ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

worker_hooks.extend(hooks)
worker_hooks.extend([
    training.NanTensorHook(estimator_spec.loss),
    training.LoggingTensorHook(
        {
            'loss': estimator_spec.loss,
            'step': global_step_tensor
        },
        every_n_iter=100)
])
worker_hooks.extend(estimator_spec.training_hooks)

# 2. Create a saver if none was provided.
if not (estimator_spec.scaffold.saver or
        ops.get_collection(ops.GraphKeys.SAVERS)):
  ops.add_to_collection(
      ops.GraphKeys.SAVERS,
      training.Saver(
          sharded=True,
          max_to_keep=self._config.keep_checkpoint_max,
          keep_checkpoint_every_n_hours=(
              self._config.keep_checkpoint_every_n_hours),
          defer_build=True,
          save_relative_paths=True))

chief_hooks = []
all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
saver_hooks = [
    h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
if (self._config.save_checkpoints_secs or
    self._config.save_checkpoints_steps):
  if not saver_hooks:
    # 3. The CheckpointSaverHook: the key piece for checkpoint saving.
    chief_hooks = [
        training.CheckpointSaverHook(
            self._model_dir,
            save_secs=self._config.save_checkpoints_secs,
            save_steps=self._config.save_checkpoints_steps,
            scaffold=estimator_spec.scaffold)
    ]
    saver_hooks = [chief_hooks[0]]
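Two details in this excerpt connect back to the log at the top. The LoggingTensorHook added here is what prints `loss = 22.2663, step = 1` (it fires on the first iteration, then every 100). And whether a CheckpointSaverHook is created at all is gated by the RunConfig fields save_checkpoints_secs / save_checkpoints_steps. A hedged sketch of switching the core tf.estimator.RunConfig to step-based saving (values are illustrative; the two fields are mutually exclusive, hence clearing the secs field):

import tensorflow as tf

# With this config, the CheckpointSaverHook above is created with
# save_steps=50 instead of the default time-based save_secs=600.
config = tf.estimator.RunConfig().replace(
    save_checkpoints_steps=50,
    save_checkpoints_secs=None,
    keep_checkpoint_max=5)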
CheckpointSaverHook
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances.
        Used for callbacks that run immediately before or after this hook saves
        the checkpoint.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of saver or scaffold should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    self._listeners = listeners or []

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def before_run(self, run_context):  # pylint: disable=unused-argument
    if self._timer.last_triggered_step() is None:
      # We do write graph and saver_def at the first call of before_run.
      # We cannot do this in begin, since we let other hooks to change graph and
      # add variables in begin. Graph is finalized after all begin calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir,
          "graph.pbtxt")
      saver_def = self._get_saver().saver_def if self._get_saver() else None
      graph = ops.get_default_graph()
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True),
          saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    # This check is critical!!! It returns True both on the very first run
    # and whenever it is time to save the next checkpoint.
    if self._timer.should_trigger_for_step(stale_global_step + 1):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        self._save(run_context.session, global_step)

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint."""
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    for l in self._listeners:
      l.before_save(session, step)
    self._get_saver().save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
    for l in self._listeners:
      l.after_save(session, step)

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]
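The hook is not tied to Estimator; it can be used directly with a MonitoredTrainingSession. A minimal runnable sketch (directory and step counts are illustrative; save_checkpoint_secs=None disables the session's own default saver hook so only ours runs):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)  # stand-in for a real train op

saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir="/tmp/ckpt_demo",  # illustrative path
    save_steps=100)

with tf.train.MonitoredTrainingSession(
    checkpoint_dir="/tmp/ckpt_demo",
    save_checkpoint_secs=None,  # avoid a second, default CheckpointSaverHook
    hooks=[saver_hook]) as sess:
  for _ in range(300):
    sess.run(train_op)
# Expected saves: step 1 (first trigger), then around step 101 and 201,
# plus a final save from end() at step 300.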
SecondOrStepTimer.should_trigger_for_step
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps."""

  def __init__(self, every_secs=None, every_steps=None):
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    if self._every_secs is None and self._every_steps is None:
      raise ValueError("Either every_secs or every_steps should be provided.")
    if (self._every_secs is not None) and (self._every_steps is not None):
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the last
      trigger exceeds `every_secs`, or if the difference between the current
      step and the last triggered step exceeds `every_steps`. False otherwise.
    """
    # First invocation: always trigger.
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step
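The first-call behavior is why the very first training step already writes model.ckpt-1 in the log at the top. A small illustration, assuming the class is exported as tf.train.SecondOrStepTimer (it is in TF 1.x):

import tensorflow as tf

timer = tf.train.SecondOrStepTimer(every_steps=50)
print(timer.should_trigger_for_step(1))   # True: first call always triggers
timer.update_last_triggered_step(1)
print(timer.should_trigger_for_step(49))  # False: 49 < 1 + 50
print(timer.should_trigger_for_step(51))  # True: 51 >= 1 + 50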
# experiment.train_and_evaluate()
self._train_monitors += [
    monitors.ValidationMonitor(
        input_fn=self._eval_input_fn,
        eval_steps=self._eval_steps,
        metrics=self._eval_metrics,
        every_n_steps=self._min_eval_frequency,
        name=eval_dir_suffix,
        hooks=self._eval_hooks)
]
class ValidationMonitor(EveryN):
  """Runs evaluation of a given estimator, at most every N steps.

  Note that the evaluation is done based on the saved checkpoint, which will
  usually be older than the current step.

  Can do early stopping on validation metrics if `early_stopping_rounds` is
  provided.
  """

  # ...... (omitted)

  def every_n_step_end(self, step, outputs):
    super(ValidationMonitor, self).every_n_step_end(step, outputs)
    # Check that we are not running evaluation on the same checkpoint.
    latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
    if latest_path is None:
      logging.debug("Skipping evaluation since model has not been saved yet "
                    "at step %d.", step)
      return False
    if latest_path is not None and latest_path == self._latest_path:
      # Prevents evaluating the same checkpoint twice! This is exactly the
      # DEBUG message from the log at the top.
      logging.debug("Skipping evaluation due to same checkpoint %s for step %d "
                    "as for step %d.", latest_path, step,
                    self._latest_path_step)
      return False
    self._latest_path = latest_path
    self._latest_path_step = step
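Putting it all together: the DEBUG line at the top appears because between the two ValidationMonitor triggers (steps 50 and 100) no new checkpoint was written; only the initial model.ckpt-1 existed, so every_n_step_end bails out early. A hedged remedy is to align the checkpoint cadence with the evaluation cadence, e.g. via the contrib RunConfig passed to the Estimator in the earlier sketch:

from tensorflow.contrib.learn import RunConfig

# Save a checkpoint every 50 steps, matching min_eval_frequency=50 above,
# so each ValidationMonitor trigger sees a fresh checkpoint. Illustrative.
run_config = RunConfig(save_checkpoints_steps=50)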