fully_connected_feed代码说明

代码来源:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/fully_connected_feed.py

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py


对fully_connected_feed.py代码进行说明,包括模型保存和加载(可用于模型在线预测与更新),tensorboard生成。

代码流程图如下:

fully_connected_feed代码说明_第1张图片


tensorboard的生成

包括4步:
1. 绑定summary ops:tf.summary.scalar('loss', loss)
2. 建立summary tensor:summary = tf.summary.merge_all()
3. 实例化SummaryWriter,以输出summaries and the Graph:summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
4. 更新事件文件:
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

其中输出graph需要建立不同的ops,如:
with tf.name_scope('hidden1'):
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
        name='weights')
则该weights则位于hidden1/weights

模型保存与加载

1. 保存模型

包括2步:
1)创建saver: saver = tf.train.Saver()
2)保存模型:saver.save(sess, checkpoint_file, global_step=step)
其中checkpoint_file是模型的路径,通过
checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
两部分组成,其中保存的模型包括四个文件,分别是:
checkpoint,model.ckpt-1999.data-00000-of-00001,model.ckpt-1999.index,model.ckpt-1999.meta
其中global_step标识了模型的保留时的训练步数,每个包括的模型在checkpoint文件中都会有记录。如果不设置global_step参数,则保留的模型文件中没有标识步数。
细节参考官方文档:https://www.tensorflow.org/api_docs/python/tf/train/Saver,包括保留最近的N个checkpoint文件

2. 加载
也是2步:
1)创建saver:saver = tf.train.Saver()
2)恢复模型:saver.restore(sess, checkpoint_file)
其中checkpoint_file是之前保存模型的地址,但要注意:如果之前保存时带有global_step参数,保存时模型地址也要带有参数,具体模型地址可以通过查看模型文件中的checkpoint文件得知。如果需要恢复制定的变量,可以在创建saver时输入变量列表,细节参考官方文档。


你可能感兴趣的:(fully_connected_feed代码说明)