tensorflow 模型的保存和加载

为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型。

1. 保存模型

tensorflow提供了一个API可以方便的保存和还原神经网络的模型。这个API就是tf.train.saver类。

import tensorflow as tf

# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
# 声明tf.train.Saver()类用于保存模型
saver = tf.train.Saver()
# 加载保存了两个变量和的模型
with tf.Session() as sess:
    print("save here...")
    sess.run(init_op)
    # 保存模型到下面路径下
    saver.save(sess,"/Users/lilong/Desktop/tt/model.ckpt")
    print(sess.run(result))

运行结果:

save here...
[-1.6226364]

这里的代码实现了一个简单的加法功能,通过saver.save函数把模型保存到了相应的路径下,这里一定要注意第一次保存一定是saver.save,而不是saver.restore
虽然上面的模型保存路径只提供了一个,但是这个目录下一般会出现三个文件,这是因为tensorflow会将计算图的结构和图上的参数值分开保存。

  • model.ckpt.meta:保存了计算图的网路结构
  • model.ckpt.data.:保存了变量的取值
  • checkpoint:保存了一个目录下的所有的模型文件列表

2. 加载保存的模型

import tensorflow as tf

# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

# 加载模型的代码和保存模型的代码的区别是:没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来
#init_op = tf.global_variables_initializer()

saver = tf.train.Saver()
# 加载保存了两个变量和的模型
with tf.Session() as sess:
    print("Reading checkpoints...")
    # 加载已经保存的模型
    saver.restore(sess,"/Users/lilong/Desktop/tt/model.ckpt")
    print(sess.run(result))

这里要注意的是加载模型和保存模型的区别是:加载模型的代码没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。

上面是单独加载模型,当然也可以如下面这样保存好模型后直接加载:

import tensorflow as tf

# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
# 声明tf.train.Saver()类用于保存模型
saver = tf.train.Saver()
# 加载保存了两个变量和的模型
with tf.Session() as sess:
    print("save here...")
    sess.run(init_op)
    # 保存模型到下面路径下
    saver.save(sess,"/Users/lilong/Desktop/tt/model.ckpt")
    print(sess.run(result))


# 加载保存了两个变量和的模型
with tf.Session() as sess:
    print("Reading checkpoints...")
    # 加载已经保存的模型
    saver.restore(sess,"/Users/lilong/Desktop/tt/model.ckpt")
    print(sess.run(result))

运行结果:

save here...
[-1.6226364]
Reading checkpoints...
INFO:tensorflow:Restoring parameters from /Users/lilong/Desktop/tt/model.ckpt
[-1.6226364]

还可以这样加载已经持久化的模型:

import tensorflow as tf
#  直接加载持久化的图。
saver = tf.train.import_meta_graph("/Users/lilong/Desktop/tt/model.ckpt.meta")

with tf.Session() as sess:
    print('get here...')
    saver.restore(sess, "/Users/lilong/Desktop/tt/model.ckpt")
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

输出:

get here...
INFO:tensorflow:Restoring parameters from /Users/lilong/Desktop/tt/model.ckpt
[-1.6226364]

这里得到的是指定的张量的值。

4. 加载模型时给变量重命名

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 16 16:17:17 2018

@author: lilong
"""

import tensorflow as tf


# 保存计算两个变量和的模型

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result1 = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
# 加载保存了两个变量和的模型
with tf.Session() as sess:
    print("save here...")
    sess.run(init_op)
    # 保存模型到下面路径下
    saver.save(sess,"/Users/lilong/Desktop/qq/model.ckpt")
    print(sess.run(result1))

for variables in tf.global_variables(): 
    print ('variables_1:',variables.name)



# 这里声明的变量和已经保存的模型中的变量名称不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2")
result2 = v1 + v2

saver1 = tf.train.Saver({"v1": v1, "v2": v2})
#saver1 = tf.train.Saver()


# 加载保存了两个变量和的模型
with tf.Session() as sess:
    print("Reading checkpoints...")
    # 加载已经保存的模型
    saver1.restore(sess,"/Users/lilong/Desktop/qq/model.ckpt")
    print(sess.run(result2))

for variables in tf.global_variables(): 
    print ('variables_2:',variables.name)

运行:

save here...
[3.]
variables_1: v1:0
variables_1: v2:0
Reading checkpoints...
INFO:tensorflow:Restoring parameters from /Users/lilong/Desktop/qq/model.ckpt
[3.]
variables_2: v1:0
variables_2: v2:0
variables_2: other-v1:0
variables_2: other-v2:0

这里对变量v1,v2的名称进行了修改,所以如果直接使用tf.train.Saver()来保存默认的模型,那么程序就会报找不到变量的错误,因为模型保存时和加载时的名称不一致,这个时候可以使用字典把模型保存时的变量名和需要加载的变量联系起来。
这样的好处之一是方便使用变量的滑动平均值,在tensorflow中的每一个变量的滑动平均值是通过影子变量维护的,所以要获取变量的滑动平均值实际就是获取影子变量的值,如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时就不需要调用函数来获取滑动平均值了。

4. 保存滑动平均模型

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
# 在没有申请滑动平均值时只有一个变量
for variables in tf.global_variables(): 
    print('Before MovingAverage:',variables.name)

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
# 在申请滑动平均模型之后,tensorflow会自动生成一个影子变量:v/ExponentialMovingAverage:0
for variables in tf.global_variables(): 
    print ('After MovingAverage:',variables.name)


# 保存滑动平均模型
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 保存的时候会将v:0  v/ExponentialMovingAverage:0这两个变量都存下来。
    saver.save(sess, "model/model2.ckpt")
    print ('last:',sess.run([v, ema.average(v)])) # 输出:[10.0, 0.099999905]

# 通过变量重命名直接读取变量的滑动平均值,通过这个方法就可以用完全一样的代码来计算滑动平均模型的前向传播的结果
v_1 = tf.Variable(0, dtype=tf.float32, name="v")
# 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v_1})
with tf.Session() as sess:
    saver.restore(sess, "model/model2.ckpt")
    print('here:',sess.run(v_1)) # 输出0.099999905,这个值就是原来模型中变量v的滑动平均值

运行结果:

Before MovingAverage: v:0
After MovingAverage: v:0
After MovingAverage: v/ExponentialMovingAverage:0
last: [10.0, 0.099999905]
INFO:tensorflow:Restoring parameters from model/model2.ckpt
here: 0.099999905

可以看到通过变量重命名直接读取变量的滑动平均值。

为了方便加载时重命名滑动平均变量,tensorflow提供了variables_to_restore()函数,来生成tf.train.Saver类需要的变量重命名字典:

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
# variables_to_restore()函数可以直接生成字典
print('here:',ema.variables_to_restore())

#saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    print("Reading checkpoints...")
    saver.restore(sess, "model/model2.ckpt")
    print ('run:',sess.run(v))   

运行结果:

here: {'v/ExponentialMovingAverage': Variable 'v:0' shape=() dtype=float32_ref>}
Reading checkpoints...
INFO:tensorflow:Restoring parameters from model/model2.ckpt
run: 0.099999905

使用tf.train.Saver会保存运行tensorflow中程序所需要的全部信息,而某些情况下并不需要全部的信息,比如测试或离线预测时,只需知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要其他的一些信息,有时将变量取值和计算图分成不同的文件存储也不方便,于是有了convert_variables_to_constants函数,该函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个tensorflow图可以统一保存在一个文件中。
示例:

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    # 导出当前计算图的graphdef部分,只需要这一部分就可以完成从输入层到输出层的计算过程
    graph_def = tf.get_default_graph().as_graph_def()
    # 将图中的变量及其取值转化为常量,同时将图中不必要的节点。这里我们只关心程序中的某些计算节点,
    # 和这些无关的计算节点就没有计算并保存了。'add'是计算机节点名字
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    # 将导入的模型存入文件
    with tf.gfile.GFile("model/combined_model.pb", "wb") as f:
           f.write(output_graph_def.SerializeToString())

# 通过下面的程序就可以直接计算定义的加法运算的结果,该方法可以用于迁移学习
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename = "model/combined_model.pb"
    # 读取保存的模型文件,并将文件解析成对应的graph protocol buffer
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # graph_def中保存的图加载到当前的图中。return_elements给出返回的张量的名称。
    # 这里的add不是计算机节点的名称,而是张量的名称,所以会是add:0
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print('run:',sess.run(result))

输出:

INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
run: [array([3.], dtype=float32)]

参考:《Tensorflow实战Google深度学习框架》

你可能感兴趣的:(Tensorflow)