https://blog.csdn.net/YiRanZhiLiPoSui/article/details/81143166
参考入门文章:
https://blog.csdn.net/u012436149/article/details/53341372
给出了简单的完整流程,便于入门理解
https://www.jianshu.com/p/7490ebfa3de8
tensorflow官网出的Supervisor介绍 的中文翻译版:长期训练好帮手
https://www.tensorflow.org/versions/r1.1/programmers_guide/supervisor
tensorflow官网出的Supervisor介绍
https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
官方的Supervisor接口文档。不过缺乏完整的例子。
在不使用Supervisor的时候,我们的代码经常是这么组织的:
variables
...
ops
...
summary_op
...
merge_op = tf.summary.merge_all()
saver
init_op
with tf.Session() as sess:
writer = tf.summary.FileWriter()
sess.run(init)
saver.restore()
for ...:
train
merged_summary = sess.run(merge_op)
writer.add_summary(merged_summary,i)
saver.save
使用一个logdir目录来同时保存模型图和权重参数。
sv = tf.train.Supervisor(logdir=logs_path,init_op=init_op,summary_op=None) #logdir用来保存checkpoint和summary
注意有个参数是summary_op
如果不传入summary_op=None(即保持默认值),Supervisor会启用自带的summary服务
使用sv = tf.train.Supervisor() 会自动初始化。
无参数也可以,最好加上logdir,同时,两个logdir可以不同
# Example 1: Supervisor with its built-in summary service.
# No explicit variable initialization is needed — managed_session() handles it.
import tensorflow as tf

tf.reset_default_graph()

# Tiny demo graph: repeatedly fold a + b back into a.
a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a, b)
update = tf.assign(a, c)

logs_path = './logaa'

# No initialization required:
# init_op = tf.global_variables_initializer()
# sv = tf.train.Supervisor(logdir=logs_path, init_op=init_op)  # logdir stores checkpoints and summaries

# A bare Supervisor() also works, but passing logdir is recommended.
sv = tf.train.Supervisor(logdir=logs_path)

# managed_session() looks for a checkpoint in logdir; if none is found,
# it runs variable initialization automatically.
with sv.managed_session() as sess:
    for i in range(71):
        update_ = sess.run(update)
        print(update_)
        # if i % 10 == 0:
        #     merged_summary = sess.run(merged_summary_op)
        #     sv.summary_computed(sess, merged_summary)
        if i % 10 == 0:
            sv.saver.save(sess, logs_path + '/model', global_step=i)
如果传入summary_op=None,则Supervisor不再自动记录summary,需要自建summary服务
# Example 2: summary_op=None — the Supervisor's summary thread is disabled,
# so summaries must be computed and recorded manually via summary_computed().
import tensorflow as tf

tf.reset_default_graph()

# Tiny demo graph: repeatedly fold a + b back into a.
a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a, b)
update = tf.assign(a, c)

logs_path = './logaa/'

tf.summary.scalar('a', a)
init_op = tf.global_variables_initializer()
merged_summary_op = tf.summary.merge_all()

# logdir stores both checkpoints and summaries; summary_op=None opts out of
# the built-in summary service.
sv = tf.train.Supervisor(logdir=logs_path, init_op=init_op, summary_op=None)

# managed_session() restores from a checkpoint in logdir if one exists,
# otherwise runs init_op automatically.
with sv.managed_session() as sess:
    for i in range(1000):
        update_ = sess.run(update)
        print(update_)
        if i % 10 == 0:
            # Hand the manually-computed summary to the Supervisor's writer.
            merged_summary = sess.run(merged_summary_op)
            sv.summary_computed(sess, merged_summary)
        if i % 100 == 0:
            # BUG FIX: the original passed the bare directory path as the save
            # prefix, so tf.train.Saver wrote checkpoints with an empty
            # basename ('./logaa/-0', ...). Use an explicit model prefix.
            sv.saver.save(sess, logs_path + 'model', global_step=i)
一个完整的例子
# -*- coding: utf-8 -*-
# Complete example: single-layer MNIST softmax regression trained under a
# tf.train.Supervisor with manual summary recording (summary_op=None).
import tensorflow as tf
tf.reset_default_graph()
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

##### Build the graph
# Inputs: flattened 28x28 images and one-hot class labels.
x = tf.placeholder(tf.float32, [None, 784], name='input_x')
y_ = tf.placeholder(tf.float32, [None, 10], name='input_y')
# Model parameters.
W = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1), name='weights')
b = tf.Variable(tf.constant(0.1, shape=[10]), name='bias')
# Softmax linear classifier.
y_output = tf.nn.softmax(tf.matmul(x, W) + b)
# Cross-entropy loss. BUG FIX: clip the softmax output before the log —
# the original tf.log(y_output) yields NaN as soon as any probability
# underflows to 0.
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y_output, 1e-10, 1.0)))
# Monitor the loss in TensorBoard.
tf.summary.scalar('loss', cross_entropy)
# tf.summary.scalar('loss', cross_entropy, collections=['loss'])
# Optimizer / training op.
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# Accuracy: fraction of samples whose argmax prediction matches the label.
correct_prediction = tf.equal(tf.argmax(y_output, 1), tf.argmax(y_, 1))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

##### Build the session
# logdir holds both checkpoints and summaries.
logs_path = 'logsbbb/'
merged_summary_op = tf.summary.merge_all()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# summary_op=None: summaries are computed and recorded manually below.
sv = tf.train.Supervisor(logdir=logs_path, init_op=tf.global_variables_initializer(), summary_op=None)
with sv.managed_session(config=config) as sess:
    # Hyper-parameters.
    ITERATION = 1000 + 1
    BATCH_SIZE = 64
    ITERATION_SHOW = 100
    for step in range(ITERATION):
        # One training step on a fresh mini-batch.
        batch = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})
        if step % ITERATION_SHOW == 0:
            # Evaluate summary and accuracy on the current training batch.
            merged_summary, accuracy = sess.run(
                [merged_summary_op, accuracy_op],
                feed_dict={x: batch[0], y_: batch[1]})
            sv.summary_computed(sess, merged_summary, global_step=step)
            # BUG FIX: the original message misspelled "accuracy" as "accuarcy".
            print("step %d, accuracy:%.4g" % (step, accuracy))
            # BUG FIX: the original saved to the bare directory path, producing
            # checkpoints with an empty basename ('logsbbb/-0', ...). Save to an
            # explicit model prefix inside logdir instead.
            sv.saver.save(sess, logs_path + 'model', global_step=step)