1、神经网络模型的保存
2、神经网络模型的加载
一般神经网络训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,保存下来,可以提供给之后进行使用。TensorFlow提供了一个非常简单的API,即tf.train.Saver
类来保存和还原一个神经网络模型。下面代码给出了保存TensorFlow模型的方法:
import tensorflow as tf
#声明2个变量
v1=tf.Variable(tf.random_normal([1,2]),name='v1')
v2=tf.Variable(tf.random_normal([2,3]),name='v2')
saver=tf.train.Saver() #实例化saver对象,用于保存模型
with tf.Session() as sess:
init_op=tf.global_variables_initializer()#初始化全部变量
sess.run(init_op)
print('v1:',sess.run(v1)) #打印v1和v2的值一会读取之后对比
print('v2:',sess.run(v2))
my_model=saver.save(sess,'zhang/mymodel') #将模型保存当该文件中
print('my model saved in file:',my_model)
运行结果:
(tf) love@iZuf69ps3de0b3n4a50j7nZ:~/tf$ python one.py
v1: [[3.182594 0.8925665]]
v2: [[-0.15594512 -0.88531935 0.42543846]
[ 0.9683068 -0.82360727 1.177089 ]]
my model saved in file: zhang/mymodel
注意:这段代码中,通过
saver.save
函数将TensorFlow模型保存到了zhang/mymodel
文件中,这里代码中指定路径为"zhang/mymodel"
,也就是保存到了当前程序所在文件夹里面的zhang文件夹中。保存后在zhang
这个文件夹中会出现4个文件,如下图所示:TensorFlow会将计算图的结构和图上参数取值分开保存。在反向传播过程中,一般会间隔一定轮数保存一次神经网络模型,其中:保存当前图结构的
.meta文件、保存当前参数名字的.index文件,保存当前参数的.data文件,
checkpoint文件保存了一个目录下所有的模型文件列表。
2.1 重复定义图上的运算,恢复模型
在测试网络效果的时候,需要将训练好的神经网络模型加载,下面代码给出了加载tensoflow
模型的方法:可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?
import tensorflow as tf
#使用和保存模型中一样的方式来声明变量
v1=tf.Variable(tf.random_normal([1,2]),name='v1')
v2=tf.Variable(tf.random_normal([2,3]),name='v2')
saver=tf.train.Saver() #实例化saver对象
with tf.Session() as sess:
#把模型恢复到当前会话sess
saver.restore(sess, "zhang/mymodel") # 即将固化到硬盘中的Session从保存路径再读取出来
print('v1:',sess.run(v1)) #打印v1和v2的值一会读取之后对比
print('v2:',sess.run(v2))
print('model restored!')
运行结果:
(tf) love@iZuf69ps3de0b3n4a50j7nZ:~/tf$ python two.py
v1: [[-1.282081 -1.4206177]]
v2: [[ 0.03054714 -0.6736177 0.37405923]
[-0.03247713 0.83015114 -0.47652403]]
model restored!
注意:这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个
tf.train.Saver
类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。 也就是说使用TensorFlow完成了一次模型的保存和读取的操作。
2.2 不希望重复定义图上的运算,也可以直接加载已经持久化的图
import tensorflow as tf
# 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量
# 直接加载持久化的图
saver=tf.train.import_meta_graph('zhang/mymodel.meta') #注意后缀.meta
with tf.Session() as sess:
saver.restore(sess,'zhang/mymodel')
#通过张量的名称来获取张量
print('v1张量为:',sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
print('v2张量为:',sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))
运行结果:
(tf) love@iZuf69ps3de0b3n4a50j7nZ:~/tf$ python three.py
v1张量为: [[-1.282081 -1.4206177]]
v2张量为: [[ 0.03054714 -0.6736177 0.37405923]
[-0.03247713 0.83015114 -0.47652403]]
注意:saver=tf.train.import_meta_graph('zhang/mymodel.meta') #注意和2.1中相比多加了一个后缀.meta,