tensorflow中用saver保存模型

  我们在用tensorflow训练模型时,可能需要训练很长很长一段时间,为了方便下次使用,应该将模型保存起来。在sklearn中,我们可以使用pickle模块进行模型保存;而在tensorflow中,我们可以使用它自带的Saver()类进行模型的保存。

(一)Saver类

  Saver类是用于保存和恢复变量的。它有将变量保存到checkpoint和从checkpoint中恢复变量的操作。

  Checkpoints是一个二进制文件,它的属性值和tensor变量值一一对应。最好的检查checkpoints内容的方法就是用一个Saver去加载它。

  Saver可以自动的为chackpoint文件进行计数。这可以让你在训练模型时,保存多个checkpoint(通过计数来区分)。例如你可以通过训练的epoch来标识你的checkpoint文件。为了防止过分使用内存,你可以为saver设置最多保存的checkpoint文件数量。

  你可以通过为save()函数传入global_step参数值来标识checkpoint文件例如:

saver.save(sess, 'my-model', global_step=0)           ==>filename: 'my-model-0'
saver.save(sess, 'my-model', global_step=1000)        ==>filename: 'my-model-1000'

属性:

last_checkpoints
    当前所有保存的checkpoint文件的名字的list集合。
    你可以将这个返回的文件名list的任意一个元素作为restore()函数的参数,用于恢复指定的checkpoint。

returns:
    返回checkpoints文件名列表,从最旧到最新的排序。

主要方法:

1.__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None)

    Saver类的构造函数。

    Args:
        var_list:需要保存的参数列表,可以dict或list形式。如果这个参数为None,则默认保存所有可以保存的对象。一般使用缺省值即可。
        max_to_keep:保存的checkpoint文件的最大数量。默认为只保存最后5个。
        keep_checkpoint_every_n_hours:多久保存一次checkpoint文件,默认10000小时每次。
        其他参数不常用。


2.save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True,
    strip_default_attrs=False)

    保存变量。
    这个方法用来保存变量,它需要一个session参数来指明哪个图。保存的参数必须已经被初始化过了。

    args:
        sess:保存变量需要的session
        save_path:checkpoint文件保存的路径。
        global_step:如果指定了,则会将这个数字添加到save_path后面,用于唯一标识checkpoint文件。
        latest_filename:和save_path在同一个文件夹中,用于最后一个checkpoint文件的命名。默认为checkpoint。
        其他不常用。

3.restore(
    sess,
    save_path)
    从save_path中恢复模型的参数。
    它需要一个session,需要恢复的参数不需要初始化,因为恢复本身就是一种初始化变量的方法。而参数save_path
    就是save()函数产生的文件的路径名。

    args:
        sess:一个session
        save_path:保存的路径

(二)使用举例

1.保存模型:
  使用Saver保存模型的参数时,一定要将saver = tf.train.Saver定义在你保存的的参数定义之后,即定义在需要的tf.Variable之后。定义在saver之后的参数无法被保存,切记切记!!!

  在下面的例子中,我们生成了y = (x - 1) ^ 2 - 2的样本,然后加上了一些噪音,试着用tensorflow训练出拟合该曲线的参数。

__author__ = 'liuwei'

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


h = 1
v = -2

#prepare data
x_train = np.linspace(-2, 4, 201)                        #x样本
noise = np.random.randn(*x_train.shape) * 0.4            #噪音
y_train = (x_train - h) ** 2 + v + noise                 #y样本

n = x_train.shape[0]

x_train = np.reshape(x_train, (n, 1))                    #重塑
y_train = np.reshape(y_train, (n, 1))

#画出产生的数据的形状
'''
plt.rcParams['figure.figsize'] = (10, 6)
plt.scatter(x_train, y_train)
plt.xlabel('x_train')
plt.ylabel('y_train')
plt.show()
'''
#create variable
X = tf.placeholder(tf.float32, [1])                      #两个占位符,x和y
Y = tf.placeholder(tf.float32, [1])

h_est = tf.Variable(tf.random_uniform([1], -1, 1))       #定义需要训练的参数,在saver之前定义
v_est = tf.Variable(tf.random_uniform([1], -1, 1))

saver = tf.train.Saver()                                 #保存模型参数的saver

value = (X - h_est) ** 2 + v_est                         #拟合的曲线

loss = tf.reduce_mean(tf.square(value - Y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    for epoch in range(100):                             #100个epoch
        for (x, y) in zip(x_train, y_train):

            sess.run(optimizer, feed_dict={X: x, Y: y})
        #保存checkpoint
        saver.save(sess, './model_iter', global_step=epoch)


    #saver the final model
    saver.save(sess, './final_model')                    #最后一个epoch对应的checkpoint
    h_ = sess.run(h_est)
    v_ = sess.run(v_est)

    print(h_, v_)

运行结果如下:
tensorflow中用saver保存模型_第1张图片

2.恢复参数
  恢复参数时,我们只需要定义保存的Variable,不需要初始化,因为恢复过程其实就是一种初始化。恢复参数的代码如下:

__author__ = 'liuwei'

import tensorflow as tf 
import numpy as np 

h_est = tf.Variable(tf.random_uniform([1], -1, 1))     #只定义,没有初始化
v_est = tf.Variable(tf.random_uniform([1], -1, 1))


saver = tf.train.Saver()                      #saver类

path = './final_model'                        #要恢复的checkpoint路径

with tf.Session() as sess:
    saver.restore(sess, path)                 #恢复参数

    print(sess.run(h_est), sess.run(v_est))

运行结果为:
tensorflow中用saver保存模型_第2张图片

完整代码:https://github.com/liuwei1206/tensorflow-study/tree/master/2.saver/test

你可能感兴趣的:(数据挖掘与机器学习)