tf.train.Supervisor
已经被弃用了官方建议使tf.train.MonitoredTrainingSession()
作为代替。
tf.supervisor
Supervisor
可以自动的帮我们做一些事情比如:
- 自动去
checkpoint
加载数据(如果有checkpoint
那么就加载这个checkpoint
,如果没有的话那么就从0
开始训练)。 - 自动进行全局变量的初始化。
- 自身有一个
saver
,用于保存checkpoint
。 - 自身就有一个
summary_computed
用来保存summary
,不需要我们自己写入。
import tensorflow as tf
a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a,b)
update = tf.assign(a,c)
tf.scalar_summary("a",a)
init_op = tf.global_variables_initializer()
merged_summary_op = tf.merge_all_summaries()
sv = tf.train.Supervisor(logdir="/home/keith/tmp/",init_op=init_op) #logdir用来保存checkpoint和summary
saver=sv.saver #创建saver
with sv.managed_session() as sess: #会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化
for i in xrange(1000):
update_ = sess.run(update)
print update_
if i % 10 == 0:
merged_summary = sess.run(merged_summary_op)
sv.summary_computed(sess, merged_summary,global_step=i)
if i%100 == 0:
saver.save(sess,logdir="/home/keith/tmp/",global_step=i)
tf.train.MonitoredTrainingSession()
tf.train.MonitoredTrainingSession(
master='',
is_chief=True,
checkpoint_dir=None,
scaffold=None,
hooks=None,
chief_only_hooks=None,
save_checkpoint_secs=USE_DEFAULT,
save_summaries_steps=USE_DEFAULT,
save_summaries_secs=USE_DEFAULT,
config=None,
stop_grace_period_secs=120,
log_step_count_steps=100,
max_wait_secs=7200,
save_checkpoint_steps=USE_DEFAULT,
summary_dir=None
)
-
is_chief
意思是在分布式集群中是否是主节点。 - 如果
save_summaries_steps
和save_summaries_secs
都是None的时候,则默认100个step保存一次。 -
config
: 一个tf.ConfigProto
类,用来配置session
。
一个使用的例子。
with monitored_session.MonitoredTrainingSession(
master=master,
is_chief=is_chief,
checkpoint_dir=logdir,
scaffold=scaffold,
hooks=hooks,
chief_only_hooks=chief_only_hooks,
save_checkpoint_secs=save_checkpoint_secs,
save_summaries_steps=save_summaries_steps,
config=config,
max_wait_secs=max_wait_secs) as session:
loss = None
while not session.should_stop():
loss = session.run(train_op)
return loss
一个常用的小例子:
with tf.train.MonitoredTrainingSession(
checkpoint_dir = './Checkpoints',
hooks = [hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step),
tf.train.NanTensorHook(loss)]
save_checkpoint_steps = 100) as sess:
一个用于训练的完整的小例子:
with tf.train.MonitoredTrainingSession(
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step),
tf.train.NanTensorHook(loss),
_LoggerHook()], # 将上面定义的_LoggerHook传入
config=tf.ConfigProto(
log_device_placement=False)) as sess:
coord = tf.train.Coordinator()
# 开启文件读取队列
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
while not sess.should_stop():
sess.run(train_op)
coord.request_stop()
coord.join(threads)
MonitoredTrainingSession继承自MonitoredSession
当MonitoredSession初始化的时候,会按顺序执行下面操作:
调用
hook
的begin()
函数,我们一般在这里进行一些hook内的初始化。比如在上面猫狗大战中的_LoggerHook
里面的_step
属性,就是用来记录执行步骤的,但是该参数只在本类中起作用。通过调用
scaffold.finalize()
初始化计算图创建会话
通过初始化
Scaffold
提供的操作(op)来初始化模型如果
checkpoint
存在的话,restore
模型的参数launches queue runners
调用
hook.after_create_session()
然后,当run()
函数运行的时候,按顺序执行下列操作:调用
hook.before_run()
调用TensorFlow的
session.run()
调用
hook.after_run()
返回用户需要的
session.run()
的结果如果发生了
AbortedError
或者UnavailableError
,则在再次执行run()
之前恢复或者重新初始化会话
最后,当调用close()
退出时,按顺序执行下列操作:调用
hook.end()
关闭队列和会话
阻止
OutOfRange
错误
需要注意的是:该类不是一个tf.Session()
,因为它不能被设置为默认会话,不能被传递给saver.save
,也不能被传递给tf.train.start_queue_runners
,这也解释了为什么在开启会话后我们必须手动调用tf.train.start_queue_runners()
各种Hook
-
tf.train.SummarySaverHook
:如果summary_writer没有给定,但是output_dir给定了那么就会创建一个writer
__init__(
save_steps=None,
save_secs=None,
output_dir=None,
summary_writer=None,
scaffold=None,
summary_op=None
)
-
chekpointSaverHook
:
saver_hook = tf.train.CheckpointSaverHook(
checkpoint_dir = model_dir,
save_steps = 100
)