1.文件说明
1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt"),生成4个文件:
a) .meta文件,Meta graph:
这是一个协议缓冲区,它保存了完整的Tensorflow图形;即所有变量、操作、集合等。该文件以.meta作为扩展名。b) .ckpt.index和.ckpt.meta文件,Checkpoint file:
这是一个二进制文件,它包含了所有的权重、偏差、梯度和其他所有变量的值。这个文件有一个扩展名.ckpt。现在,我们有两个文件,而不是单个.ckpt文件:
mymodel.data-00000-of-00001 #
.data文件是包含我们训练变量的文件,我们待会将会使用它。
mymodel.index
c) checkpoint文件
与此同时,Tensorflow也有一个名为checkpoint的文件,它只保存的最新保存的checkpoint文件的记录。
2.保存模型:
在完成培训之后,我们希望将所有的变量和网络结构保存到一个文件中,以便将来使用。因此,在Tensorflow中,我们希望保存所有参数的图和值,我们将创建一个tf.train.Saver()类的实例。
saver = tf.train.Saver(max_to_keep=3)#max_to_keep不写默认为1
请记住,Tensorflow变量仅在会话中存在。因此,您必须在一个会话中保存模型,调用您刚刚创建的save方法。
saver.save(sess, 'my-test-model')
例子
#定义一个保存的实例 saver =tf.train.Saver(max_to_keep=3) # sess.run(tf.global_variables_initializer()) with tf.Session() as sess: sess.run(init) for step in range(501): sess.run(train_op) # 运行合并op summary = sess.run(merged) # lo,prediction =sess.run([loss,prediction]) # print(lo,prediction) # 建一个事件文件, filewriter = tf.summary.FileWriter('./data/summary', graph=sess.graph) # 写入收集merged到事件文件 filewriter.add_summary(summary, step) if step % 10 == 0: # print('%f,%.3f,%.3f'%(step,w.eval(),b.eval())) print(step, w.eval(), b.eval()) #模型的保存, saver.save(sess,'tmp/ckpt/')#后面的斜杠要加上,不然容易出错
3.加载恢复网络:
a)创建网络
你可以通过编写python代码创建网络,以手工创建每一层,并将其作为原始模型。但是,如果你考虑一下,我们已经在.meta文件中保存了这个网络,我们可以使用tf.train.import()函数来重新创建这个网络:
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
记住,import_meta_graph将在.meta文件中定义的网络附加到当前图。因此,这将为你创建图形/网络,但是我们仍然需要加载我们在这张图上训练过的参数的值。
b)载入参数
我们可以通过调用这个保护程序的实例来恢复网络的参数,它是tf.train.Saver()类的一个实例。
with tf.Session() as sess: new_saver = tf.train.import_meta_graph('my_test_model-1000.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./'))
手工构建之前的网络并加载模型
def save(): print('This is save') # build neural network tf_x = tf.placeholder(tf.float32, x.shape) # input x tf_y = tf.placeholder(tf.float32, y.shape) # input y l = tf.layers.dense(tf_x, 10, tf.nn.relu) # hidden layer o = tf.layers.dense(l, 1) # output layer loss = tf.losses.mean_squared_error(tf_y, o) # compute cost train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss) sess = tf.Session() sess.run(tf.global_variables_initializer()) # initialize var in graph saver = tf.train.Saver() # define a saver for saving and restoring for step in range(100): # train sess.run(train_op, {tf_x: x, tf_y: y}) saver.save(sess, './params', write_meta_graph=False) # meta_graph is not recommended # plotting pred, l = sess.run([o, loss], {tf_x: x, tf_y: y}) plt.figure(1, figsize=(10, 5)) plt.subplot(121) plt.scatter(x, y) plt.plot(x, pred, 'r-', lw=5) plt.text(-1, 1.2, 'Save Loss=%.4f' % l, fontdict={'size': 15, 'color': 'red'}) def reload(): print('This is reload') # build entire net again and restore tf_x = tf.placeholder(tf.float32, x.shape) # input x tf_y = tf.placeholder(tf.float32, y.shape) # input y l_ = tf.layers.dense(tf_x, 10, tf.nn.relu) # hidden layer o_ = tf.layers.dense(l_, 1) # output layer loss_ = tf.losses.mean_squared_error(tf_y, o_) # compute cost sess = tf.Session() # don't need to initialize variables, just restoring trained variables saver = tf.train.Saver() # define a saver for saving and restoring saver.restore(sess, './params') # plotting pred, l = sess.run([o_, loss_], {tf_x: x, tf_y: y}) plt.subplot(122) plt.scatter(x, y) plt.plot(x, pred, 'r-', lw=5) plt.text(-1, 1.2, 'Reload Loss=%.4f' % l, fontdict={'size': 15, 'color': 'red'}) plt.show() save() # destroy previous net tf.reset_default_graph() reload()
reference:
https://blog.csdn.net/tan_handsome/article/details/79303269一个快速完整的教程,以保存和恢复Tensorflow模型。