工作中经常使用的是.ckpt,最近在研究tensorflow serving所以需要将模型转化为SavedModel格式。而有时模型平台调用又需要.pb模型,所以对这三种文件进行了解。
checkpoint文件:b包含最新的和所有的文件地址
.data文件:包含训练变量的文件
.index文件:描述variable中key和value的对应关系
.meta文件:保存完整的网络图结构
使用这种方法保存模型时会保存成上面这四个文件,重新加载模型时通常只会用到.meta文件恢复图结构然后用.data文件把各个变量的值再加进去。
saver=tf.train.Saver(max_to_keep=5) #表示保存最近的几个模型,设置为None或者0 就是保存全部的模型。此处max_to_keep=5意思就是保存最近的5个模型
saver.save(sess,'D:/model',global_step=epoch)
创建一个saver,调用save方法将当前sess会话中的图和变量等信息保存到指定路径,global_step代表当前的轮数,设置之后会在文件名后面缀一个"-epco"
saver=tf.train.import_meta_graph('model/model-0720-4.meta') #恢复计算图结构
saver.restore(sess, tf.train.latest_checkpoint("model/")) #恢复所有变量信息
#现在sess中已经恢复了网络结构和变量信息了,接下来可以直接用节点的名称来调用:
print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})
#或者采用:
graph = tf.get_default_graph()
input_x = graph.get_tensor_by_name('x:0')
input_y=graph.get_tensor_by_name('y:0')
op=graph.get_tensor_name('op:0')
print(sess.run(op,feed_dict={input_x:2,input_y:3)
.ckpt方式保存模型,这种模型文件是依赖 TensorFlow 的,只能在其框架下使用
.pb文件里面保存了图结构+数据,加载模型时只需要这一个文件就好。
constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op'])
with tf.gfile.FastGFile('D:/pycharm files/model.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
with tf.gfile.FastGFile(pb_file_path, 'rb') as f:
graph_def = tf.GraphDef() # 生成图
graph_def.ParseFromString(f.read()) # 图加载模型
tf.import_graph_def(graph_def, name='')
#接下来与前面的相同可以直接用节点的名称来调用:
print(sess.run('op:0',feed_dict={'x:0':2,'y:0':3})
#或者采用:
graph = tf.get_default_graph()
input_x = graph.get_tensor_by_name('x:0')
input_y=graph.get_tensor_by_name('y:0')
op=graph.get_tensor_name('op:0')
print(sess.run(op,feed_dict={input_x:2,input_y:3)
谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。另外的好处是保存为 PB 文件时候,模型的变量都会变成固定的,导致模型的大小会大大减小。
加载一个pb文件之后再对其进行微调(也就是将这个pb文件的网络作为自己网络的一部分),然后再保存成pb文件,后一个pb网络会包含前一个pb网络。
在传入的目录下会有一个pb文件和一个variables文件夹:
builder = tf.saved_model.builder.SavedModelBuilder(path)
builder.add_meta_graph_and_variables(sess,['cpu_server_1'])
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ['cpu_server_1'], pb_file_path+'savemodel')
#接下来可以直接使用名字或者get_tensor_by_name后再进行使用
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
op = sess.graph.get_tensor_by_name('op:0')
ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
saved_model模块主要用于TensorFlow Serving,目的是要实现inference的代码统一。详细可点击参考https://blog.csdn.net/thriving_fcl/article/details/75213361
参考https://www.jianshu.com/p/06548e3e8f4b
参考https://www.jianshu.com/p/451c46bd9287
参考https://www.cnblogs.com/biandekeren-blog/p/11876032.html
参考文献
https://www.cnblogs.com/biandekeren-blog/p/11876032.html