tf.train.Saver() 模型保存和加载

1.文件说明

1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt"),生成4个文件:

a) .meta文件,Meta graph:
     这是一个协议缓冲区,它保存了完整的Tensorflow图形;即所有变量、操作、集合等。该文件以.meta作为扩展名。

b) .ckpt.index和.ckpt.meta文件,Checkpoint file

    这是一个二进制文件,它包含了所有的权重、偏差、梯度和其他所有变量的值。这个文件有一个扩展名.ckpt。现在,我们有两个文件,而不是单个.ckpt文件:

  1. mymodel.data-00000-of-00001    #.data文件是包含我们训练变量的文件,我们待会将会使用它。

  2. mymodel.index

c) checkpoint文件

与此同时,Tensorflow也有一个名为checkpoint的文件,它只保存的最新保存的checkpoint文件的记录。

2.保存模型:

在完成培训之后,我们希望将所有的变量和网络结构保存到一个文件中,以便将来使用。因此,在Tensorflow中,我们希望保存所有参数的图和值,我们将创建一个tf.train.Saver()类的实例。

saver = tf.train.Saver(max_to_keep=3)#max_to_keep不写默认为1

请记住,Tensorflow变量仅在会话中存在。因此,您必须在一个会话中保存模型,调用您刚刚创建的save方法。

saver.save(sess, 'my-test-model')

 例子

    #定义一个保存的实例
    saver =tf.train.Saver(max_to_keep=3)


    # sess.run(tf.global_variables_initializer())
    with tf.Session() as sess:
        sess.run(init)

        for step in range(501):
            sess.run(train_op)
            # 运行合并op
            summary = sess.run(merged)
            # lo,prediction =sess.run([loss,prediction])
            # print(lo,prediction)
            # 建一个事件文件,
            filewriter = tf.summary.FileWriter('./data/summary', graph=sess.graph)
            # 写入收集merged到事件文件
            filewriter.add_summary(summary, step)
            if step % 10 == 0:
                # print('%f,%.3f,%.3f'%(step,w.eval(),b.eval()))
                print(step, w.eval(), b.eval())

        #模型的保存,
        saver.save(sess,'tmp/ckpt/')#后面的斜杠要加上,不然容易出错

 

3.加载恢复网络:

a)创建网络

你可以通过编写python代码创建网络,以手工创建每一层,并将其作为原始模型。但是,如果你考虑一下,我们已经在.meta文件中保存了这个网络,我们可以使用tf.train.import()函数来重新创建这个网络:

saver = tf.train.import_meta_graph('my_test_model-1000.meta')

 

记住,import_meta_graph将在.meta文件中定义的网络附加到当前图。因此,这将为你创建图形/网络,但是我们仍然需要加载我们在这张图上训练过的参数的值。

b)载入参数

我们可以通过调用这个保护程序的实例来恢复网络的参数,它是tf.train.Saver()类的一个实例。

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))

 

 

手工构建之前的网络并加载模型 

def save():
    print('This is save')
    # build neural network
    tf_x = tf.placeholder(tf.float32, x.shape)  # input x
    tf_y = tf.placeholder(tf.float32, y.shape)  # input y
    l = tf.layers.dense(tf_x, 10, tf.nn.relu)          # hidden layer
    o = tf.layers.dense(l, 1)                     # output layer
    loss = tf.losses.mean_squared_error(tf_y, o)   # compute cost
    train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())  # initialize var in graph

    saver = tf.train.Saver()  # define a saver for saving and restoring

    for step in range(100):                             # train
        sess.run(train_op, {tf_x: x, tf_y: y})

    saver.save(sess, './params', write_meta_graph=False)  # meta_graph is not recommended

    # plotting
    pred, l = sess.run([o, loss], {tf_x: x, tf_y: y})
    plt.figure(1, figsize=(10, 5))
    plt.subplot(121)
    plt.scatter(x, y)
    plt.plot(x, pred, 'r-', lw=5)
    plt.text(-1, 1.2, 'Save Loss=%.4f' % l, fontdict={'size': 15, 'color': 'red'})


def reload():
    print('This is reload')
    # build entire net again and restore
    tf_x = tf.placeholder(tf.float32, x.shape)  # input x
    tf_y = tf.placeholder(tf.float32, y.shape)  # input y
    l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)          # hidden layer
    o_ = tf.layers.dense(l_, 1)                     # output layer
    loss_ = tf.losses.mean_squared_error(tf_y, o_)   # compute cost

    sess = tf.Session()
    # don't need to initialize variables, just restoring trained variables
    saver = tf.train.Saver()  # define a saver for saving and restoring
    saver.restore(sess, './params')

    # plotting
    pred, l = sess.run([o_, loss_], {tf_x: x, tf_y: y})
    plt.subplot(122)
    plt.scatter(x, y)
    plt.plot(x, pred, 'r-', lw=5)
    plt.text(-1, 1.2, 'Reload Loss=%.4f' % l, fontdict={'size': 15, 'color': 'red'})
    plt.show()


save()

# destroy previous net
tf.reset_default_graph()

reload()

 

reference:

https://blog.csdn.net/tan_handsome/article/details/79303269一个快速完整的教程,以保存和恢复Tensorflow模型。

 

你可能感兴趣的:(tensorflow)