TensorFlow入门(六、模型的保存和载入)

保存模型

使用TensorFlow的saver()类先实例化一个saver对象,然后在session中通过saver的save方法将模型保存起来。代码示例如下:

#初始化所有变量
init = tf.global_variable_initializer()

#定义saver和保存路径
saver = tf.train.Saver()
saverdir = "save_path"

#启动Session
with tf.Session() as sess:
    sess.run(init)
    #使用saver的save方法保存
    saver.save(sess,saverdir + "file_name")

        其中,filename如果不存在,程序会自动创建。

打印模型中的内容

使用inspect_checkpoint包中的print_tensors_in_checkpoint_file方法将模型中的具体内容打印出来。代码示例如下:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
form tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

saverdir = "log/"
print_tensors_in_checkpoint_file(savedir + "linearmodel.cpkt",None,True)

保存模型的其他方法

使用saver()类保存模型时,可以在函数中放入参数来实现更高级的功能,如指定存储变量名字与变量的对应关系。代码示例如下:

W = tf.Variable(1.0,name = "weight")
b = tf.Variable(2.0,name = "bias")

saver = tf.train.Saver({'weight':W,'bias':b})
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess,savedir + "linearmodel.cpkt")
print_tensors_in_checkpoint_file(savedir + "linearmodel.cpkt",None,True)

载入模型

通过调用saver的restore()函数,从指定的路径找到模型文件,并覆盖到相关参数中。代码示例如下:

#初始化所有变量
init = tf.global_variable_initializer()

#定义saver和保存路径
saver = tf.train.Saver()
saverdir = "save_path"

#启动Session
with tf.Session() as sess:
    sess.run(init)
    #使用saver的restore方法载入模型
    print("x=0.2,z=",sess.run(z,feed_dict = {X:0.2}))

你可能感兴趣的:(TensorFlow入门,tensorflow,人工智能,python)