# Minimal training loop that periodically saves checkpoints and writes
# TensorBoard summaries.
tf.summary.scalar('accuracy', acc)
merge_summary = tf.summary.merge_all()
# ... (cross-entropy, optimizer and other graph definitions go here)
saver = tf.train.Saver()
with tf.Session() as sess:
    # The writer records `sess.graph`, so it can only be created once the
    # session exists (and after the whole graph above has been defined).
    train_writer = tf.summary.FileWriter(dir, sess.graph)
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # `range` instead of `xrange` so the snippet also runs on Python 3.
    for step in range(training_step):
        # NOTE(review): the optimizer's train op is never run in this
        # snippet; presumably it is elided along with the definitions above.
        if step % 1000 == 0:
            # Saved under the prefix `checkpoint_dir`, suffixed with
            # "-<step>" (e.g. "model/mnist_model-49001").
            saver.save(sess, checkpoint_dir, global_step=step)
        # feed_dict contents are elided in this example.
        train_summary = sess.run(merge_summary, feed_dict={...})
        train_writer.add_summary(train_summary, step)
保存后主要生成三类文件:.data 文件(保存变量的值,如网络的权重和偏置)、.index 文件(对每个张量,记录其内容存放在哪个 data 文件中、在该文件中的偏移量、校验和以及一些辅助数据)和 .meta 文件(保存图结构)。此外,目录中还有一个名为 checkpoint 的文本文件,我们主要看一下这个 checkpoint 文件,打开如下:
可以看到其中保存的都是模型路径,第一行(model_checkpoint_path)记录的是最新保存的模型路径。
def checkpoint_load(path):
    """Print the checkpoint state recorded in directory *path* (demo only)."""
    banner = 'Reading Checkpoints... .. .\n'
    print(banner)
    state = tf.train.get_checkpoint_state(path)
    print(state)
print 的输出如下:
model_checkpoint_path: "model/mnist_model-49001"
all_model_checkpoint_paths: "model/mnist_model-45001"
all_model_checkpoint_paths: "model/mnist_model-46001"
all_model_checkpoint_paths: "model/mnist_model-47001"
all_model_checkpoint_paths: "model/mnist_model-48001"
all_model_checkpoint_paths: "model/mnist_model-49001"
可以看到,tf.train.get_checkpoint_state(path) 返回的是一个 CheckpointState 对象(目录中没有 checkpoint 文件时返回 None),它包含两个字段:
ckpt.model_checkpoint_path
ckpt.all_model_checkpoint_paths
断点续训时,一般只需判断 ckpt.model_checkpoint_path 是否存在,然后加载最新的模型即可:
# Restore the newest checkpoint (if any) and recover its global step.
if ckpt and ckpt.model_checkpoint_path:
    ckpt_path = str(ckpt.model_checkpoint_path)
    self.saver.restore(self.sess, os.path.join(os.getcwd(), ckpt_path))
    # The step is the suffix after the LAST '-' of the file name
    # ("mnist_model-49001" -> 49001). Indexing [-1] instead of [1] keeps
    # this correct when the prefix itself contains '-'.
    step = int(os.path.basename(ckpt_path).split('-')[-1])
    print("\nCheckpoint Loading Success! %s\n" % ckpt_path)
例如当 ckpt.model_checkpoint_path = "model/mnist_model-49001" 时,上面的 step = int(os.path.basename(ckpt_path).split('-')[...]) 会取出文件名中 '-' 后面的数字,得到 49001,即已经训练的步数。
用法示例:
# Training loop that resumes from the newest checkpoint, if one exists.
tf.summary.scalar('accuracy', acc)
merge_summary = tf.summary.merge_all()
# ... (cross-entropy, optimizer and other graph definitions go here)
saver = tf.train.Saver()
with tf.Session() as sess:
    # The writer records `sess.graph`, so it must be created inside the session.
    train_writer = tf.summary.FileWriter(dir, sess.graph)
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # Restore AFTER initialization, otherwise the initializer would overwrite
    # the restored weights. Returns the restored step, or 0 on a fresh start.
    start_step = checkpoint_load(sess, saver, checkpoint_dir)
    # Continue numbering right after the restored step, so summaries and
    # checkpoint names pick up where the previous run stopped. Binding the
    # loop variable to a separate name keeps the restored value from being
    # discarded. On a fresh start this iterates 1..training_step.
    # NOTE(review): assumes training_step is the total step budget across
    # runs, not the number of additional steps — confirm.
    for step in range(start_step + 1, training_step + 1):
        if step % 1000 == 0:
            saver.save(sess, checkpoint_dir, global_step=step)
        # feed_dict contents are elided in this example.
        train_summary = sess.run(merge_summary, feed_dict={...})
        train_writer.add_summary(train_summary, step)
checkpoint_load定义如下:
def checkpoint_load(sess, saver, path):
    """Restore the newest checkpoint found in directory *path* into *sess*.

    Args:
        sess: session to restore variables into. Call this after running the
            global variables initializer; otherwise the initializer would
            overwrite the restored values.
        saver: the tf.train.Saver used to restore the variables.
        path: directory containing the `checkpoint` bookkeeping file.

    Returns:
        The global step taken from the numeric suffix of the checkpoint file
        name (e.g. "mnist_model-49001" -> 49001). Returns 0 when no checkpoint
        is found, or when the file name carries no numeric step suffix.
    """
    print('Reading Checkpoints... .. .\n')
    ckpt = tf.train.get_checkpoint_state(path)
    if not (ckpt and ckpt.model_checkpoint_path):
        # No checkpoint available: the caller starts training from step 0.
        print('Checkpoint load failed')
        return 0
    ckpt_path = ckpt.model_checkpoint_path
    saver.restore(sess, os.path.join(os.getcwd(), ckpt_path))
    try:
        # Use the text after the last '-' so prefixes containing '-' still work.
        step = int(os.path.basename(ckpt_path).split('-')[-1])
    except ValueError:
        # Saved without global_step (e.g. "model.ckpt"): the weights are
        # restored, but the step count is unknown, so restart counting at 0.
        step = 0
    return step
测试时一般不需要返回的 step:构建好网络之后直接调用 checkpoint_load 函数,即可将模型加载到当前图中,忽略其返回值即可。需要注意的是,必须在全局初始化之后再加载参数,否则先加载的模型参数会被随后的初始化覆盖,加载就白做了。
完整的训练过程参考:https://blog.csdn.net/Li_haiyu/article/details/80846657