在开始学习之前推荐大家可以多在FlyAI竞赛服务平台多参加训练和竞赛,以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例,支持算法能力变现以及快速的迭代算法模型。
附件一:sklearn上的用法
import tensorflow as tf
import numpy as np
W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b')
init = tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess,"save/model.ckpt")
import tensorflow as tf
import numpy as np
W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess,"save/model.ckpt")
2. 默认情况下:saver.save(sess,"save/model.ckpt")产生4个文件:
checkpoint文件保存最新的模型;
model.ckpt.data 以字典的形式保存权重偏置项等训练参数
model.ckpt.index:存储训练好的参数索引
model.ckpt.meta : 元文件(meta) 中保存了MetaGraphDef 的持久化数据,即模型数据,计算图的网络结构信息,完整的graph、variables、operation、collection。
3. 如何知道tensor的名字,最好是定义tensor的时候就指定名字,如上面代码中的name='w',如果你没有定义name,tensorflow也会设置name,只不过这个name就是根据你的tensor或者操作的性质。所以最好还是自己定义好name。
【说明:这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。】
tf.train.import_meta_graph
import_meta_graph(
meta_graph_or_file,
clear_devices=False,
import_scope=None,
**kwargs
)
这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。
比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下:
和1.1一样,保持不变
import tensorflow as tf
import numpy as np
# W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')
# b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')
# saver = tf.train.Saver()
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph("save/model.ckpt.meta")
new_saver.restore(sess, "save/model.ckpt")
【个人理解:model.ckpt.meta : 保存了TensorFlow计算图的网络结构信息,import_meta_graph("save/model.ckpt.meta")这句拉取了结构,故不用重新定义。】
实现了 (y = x + b)当输入一个x 那么输出的结果y就等于输入x加上b。
3.1 保存
# Author:yifan
import os
import tensorflow as tf # 以下所有代码默认导入
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
# 保存模型路径
PATH = './models'
# 创建一个变量
one = tf.Variable(3.0)
# 创建一个占位符,在 Tensorflow 中需要定义 placeholder 的 type ,一般为 float32 形式
num = tf.placeholder(tf.float32,name='input')
# 创建一个加法步骤,注意这里并没有直接计算
sum = tf.add(num,one,name='output')
# 初始化变量,如果定义Variable就必须初始化
init = tf.global_variables_initializer()
# 创建会话sess
with tf.Session() as sess:
sess.run(init)
print(sess.run(sum, feed_dict={num: 5.0}))
# #保存SavedModel模型,使用以下三句
builder = tf.saved_model.builder.SavedModelBuilder(PATH)
signature = predict_signature_def(inputs={'input':num}, outputs={'output':sum})
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING],signature_def_map={'predict': signature})
builder.save()
说明:
---- 标签也可以选用系统定义好的参数,tf.saved_model.tag_constants.SERVING与 tf.saved_model.tag_constants.TRAINING等。
运行结果:8.0,和保存的模型:
注意:当前目录下不可以存在models文件夹,否则会报错。
# Author:yifan
import tensorflow as tf # 以下所有代码默认导入
PATH = './models'
with tf.Session() as sess:
tf.saved_model.loader.load(sess, ["serve"], PATH)
#一种载入变量的方式:
in_x =tf.saved_model.loader.load(sess, ["serve"], PATH).signature_def['predict'].inputs['input'].name
#另一种载入变量的方式:
# in_x = sess.graph.get_tensor_by_name('input:0') #加载输入变量
y = sess.graph.get_tensor_by_name('output:0') #加载输出变量
scores = sess.run(y, feed_dict={in_x: 3.})
print(scores)
说明:
结果:6.0
传统的导入 需要用get_tensor_by_name , 这样就需要记录tensor的name熟悉,很麻烦。通过signature,我们可以指定变量的别名,方便存取。但如果我们拿到了别人的含有signature一个SavedModel模型而且并不知道"标签"那么怎么调用呢?
---Tensorflow官方已经为我们准备好了一个脚本,tensorflow下的saved_model_cli.py文件可以帮到。
我们可以'WIN+R'输入'cmd'然后回车打开你的CMD,然后指定路径到你的模型目录下,运行:
saved_model_cli show --dir=./ --all
打印出的信息中我们就可以看到模型的输入/输出的名称、数据类型、shape以及方法名称。
保存参数:
from sklearn.externals import joblib
joblib.dump((centres, des_list,img_features), "imgs_features.pkl", compress=3)
读取参数:
centres, des_list, img_features = joblib.load("imgs_features.pkl") #读取保存的特征
【1】
TensorFlow 模型保存/载入的两种方法blog.csdn.net【2】
【tensorflow】保存模型、再次加载模型等操作_I am what i am-CSDN博客_保存模型blog.csdn.net【3】
TensorFlow saved_model 模块blog.csdn.net【4】
Tensorflow学习笔记(二)模型的保存与加载(一 )blog.csdn.net更多精彩内容请访问FlyAI-AI竞赛服务平台;为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台;每周免费提供项目开源算法样例,支持算法能力变现以及快速的迭代算法模型。
挑战者,都在FlyAI!!!