tensorflow-模型保存和加载(一)

模型保存和加载(一)

TensorFlow的模型格式有很多种,针对不同场景可以使用不同的格式。

格式 简介
Checkpoint 用于保存模型的权重,主要用于模型训练过程中参数的备份和模型训练热启动。
GraphDef 用于保存模型的Graph,不包含模型权重,加上checkpoint后就有模型上线的全部信息。
SavedModel 使用saved_model接口导出的模型文件,包含模型Graph和权限可直接用于上线,TensorFlow和Keras模型推荐使用这种模型格式。
FrozenGraph 使用freeze_graph.py对checkpoint和GraphDef进行整合和优化,可以直接部署到Android、iOS等移动设备上。
TFLite 基于flatbuf对模型进行优化,可以直接部署到Android、iOS等移动设备上,使用接口和FrozenGraph有些差异。

在训练模型的时候需要保存模型的中间训练结果Checkpoint,以便下次迭代训练或者用作预测。Tensorflow针对这一需求提供了Saver类

1.Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值。
2.只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
3.为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

方式一:

模型保存:

通过下面的一段代码创建saver对象来管理模型中的变量(默认情况下是所有的变量,也可以自行选择)。

import tensorflow as tf

# save to file
W = tf.Variable([[1, 2, 8], [4, 5, 8]], dtype=tf.float32, name='weight')
b = tf.Variable([[1, 2, 8]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
        sess.run(init)
        save_path = saver.save(sess, "./my_net/save_net.ckpt")
        print ("save to path:", save_path)

模型加载:

用同一个Saver对象来恢复变量(需要把模型的结构重新定义一遍)。注意,当你从文件恢复变量时,不需要对它进行初始化,否则会报错。

import tensorflow as tf
import numpy as np

W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name='weight')
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name='biases')

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "./my_net/save_net.ckpt")
    print ("weights:", sess.run(W))
    print ("biases:", sess.run(b))

恢复训练:

import tensorflow as tf
import numpy as np
import os

#输入数据
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0,0.05, x_data.shape)
y_data = np.square(x_data)-0.5+noise

#输入层
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

#隐层
W1 = tf.Variable(tf.random_normal([1,10]))
b1 = tf.Variable(tf.zeros([1,10])+0.1)
Wx_plus_b1 = tf.matmul(xs,W1) + b1
output1 = tf.nn.relu(Wx_plus_b1)

#输出层
W2 = tf.Variable(tf.random_normal([10,1]))
b2 = tf.Variable(tf.zeros([1,1])+0.1)
Wx_plus_b2 = tf.matmul(output1,W2) + b2
output2 = Wx_plus_b2

#损失
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-output2),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

#模型保存加载工具
saver = tf.train.Saver()

#判断模型保存路径是否存在,不存在就创建
if not os.path.exists('tmp/'):
    os.mkdir('tmp/')

#初始化
sess = tf.Session()
if os.path.exists('tmp/checkpoint'): #判断模型是否存在
    saver.restore(sess, 'tmp/model.ckpt') #存在就从模型中恢复变量
else:
    init = tf.global_variables_initializer() #不存在就初始化变量
    sess.run(init)

#训练
for i in range(1000):
    _,loss_value = sess.run([train_step,loss], feed_dict={xs:x_data,ys:y_data})
    if(i%50==0): #每50次保存一次模型
        save_path = saver.save(sess, 'tmp/model.ckpt') #保存模型到tmp/model.ckpt,注意一定要有一层文件夹,否则保存不成功!!!
        print("模型保存:%s 当前训练损失:%s"%(save_path, loss_value))

这种方法不方便的在于,在载入模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

方式二:

不需重新定义网络结构的方法: tf.train.import_meta_graph

import_meta_graph(
    meta_graph_or_file,
    clear_devices=False,
    import_scope=None,
    **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的 y,则应该在训练阶段将它添加到collection中。(这个是不是不需要手动添加)

模型保存:

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # 保存checkpoint, 同时也默认导出一个meta_graph
        # graph名为'my-model-{global_step}.meta'.
        saver.save(sess, 'my-model', global_step=step)

模型加载:

checkpoint_file=tf.train.latest_checkpoint(checkpoint_directory)
graph=tf.Graph()

  with graph.as_default():
    session_conf = tf.ConfigProto(allow_safe_placement=True, log_device_placement =False)
    sess = tf.Session(config = session_conf)
    with sess.as_default():
      saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
      saver.restore(sess,checkpoint_file)

      # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
      y = tf.get_collection('pred_network')[0]

      # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
      input_x = graph.get_operation_by_name('input_x').outputs[0]
      keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

      # 使用y进行预测  
      sess.run(y, feed_dict={input_x:....,  keep_prob:1.0})

这里有两点需要注意的: 
一、 saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如

my-model-10000.meta     元模型文件,保存图的结构

my-model-10000.index

my-model-10000.data-00000-of-00001    权重文件

import_meta_graph时填的是meta文件名。权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用这个方法tf.train.latest_checkpoint(checkpoint_dir)获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

你可能感兴趣的:(tensorflow-模型保存和加载(一))