tensorflow笔记:网络模型的保存和读取

目录

1、神经网络模型的保存

2、神经网络模型的加载


1、神经网络模型的保存

一般神经网络训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,保存下来,可以提供给之后进行使用。TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。下面代码给出了保存TensorFlow模型的方法:

import tensorflow as tf

#声明2个变量
v1=tf.Variable(tf.random_normal([1,2]),name='v1')
v2=tf.Variable(tf.random_normal([2,3]),name='v2')
saver=tf.train.Saver()		#实例化saver对象,用于保存模型
with tf.Session() as sess:
	init_op=tf.global_variables_initializer()#初始化全部变量
	sess.run(init_op)
	print('v1:',sess.run(v1)) #打印v1和v2的值一会读取之后对比
	print('v2:',sess.run(v2))
	my_model=saver.save(sess,'zhang/mymodel')	#将模型保存当该文件中
	print('my model saved in file:',my_model)

运行结果:

(tf) love@iZuf69ps3de0b3n4a50j7nZ:~/tf$ python one.py 
v1: [[3.182594  0.8925665]]
v2: [[-0.15594512 -0.88531935  0.42543846]
 [ 0.9683068  -0.82360727  1.177089  ]]
my model saved in file: zhang/mymodel

注意:这段代码中,通过saver.save函数将TensorFlow模型保存到了zhang/mymodel文件中,这里代码中指定路径为"zhang/mymodel",也就是保存到了当前程序所在文件夹里面的zhang文件夹中。保存后在zhang这个文件夹中会出现4个文件,如下图所示:

tensorflow笔记:网络模型的保存和读取_第1张图片

TensorFlow会将计算图的结构和图上参数取值分开保存。在反向传播过程中,一般会间隔一定轮数保存一次神经网络模型,其中:保存当前图结构的.meta文件、保存当前参数名字的.index文件,保存当前参数的.data文件,checkpoint文件保存了一个目录下所有的模型文件列表。

2、神经网络模型的加载

2.1 重复定义图上的运算,恢复模型

在测试网络效果的时候,需要将训练好的神经网络模型加载,下面代码给出了加载tensoflow模型的方法:可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?

import tensorflow as tf

#使用和保存模型中一样的方式来声明变量
v1=tf.Variable(tf.random_normal([1,2]),name='v1')
v2=tf.Variable(tf.random_normal([2,3]),name='v2')
saver=tf.train.Saver() #实例化saver对象
with tf.Session() as sess:
	#把模型恢复到当前会话sess
	saver.restore(sess, "zhang/mymodel") # 即将固化到硬盘中的Session从保存路径再读取出来
	print('v1:',sess.run(v1)) #打印v1和v2的值一会读取之后对比
	print('v2:',sess.run(v2))
	print('model restored!')

运行结果:

(tf) love@iZuf69ps3de0b3n4a50j7nZ:~/tf$ python two.py 
v1: [[-1.282081  -1.4206177]]
v2: [[ 0.03054714 -0.6736177   0.37405923]
 [-0.03247713  0.83015114 -0.47652403]]
model restored!

注意:这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。 也就是说使用TensorFlow完成了一次模型的保存和读取的操作。

2.2 不希望重复定义图上的运算,也可以直接加载已经持久化的图

import tensorflow as tf

# 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量
# 直接加载持久化的图
saver=tf.train.import_meta_graph('zhang/mymodel.meta')	#注意后缀.meta
with tf.Session() as sess:
	saver.restore(sess,'zhang/mymodel')
	#通过张量的名称来获取张量
	print('v1张量为:',sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
	print('v2张量为:',sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))

运行结果:

(tf) love@iZuf69ps3de0b3n4a50j7nZ:~/tf$ python three.py 
v1张量为: [[-1.282081  -1.4206177]]
v2张量为: [[ 0.03054714 -0.6736177   0.37405923]
 [-0.03247713  0.83015114 -0.47652403]]

注意:saver=tf.train.import_meta_graph('zhang/mymodel.meta')    #注意和2.1中相比多加了一个后缀.meta,

你可能感兴趣的:(Linux,Ubuntu,Tensorflow,Deep,Learning,Tensorflow笔记)