[TOC]
https://blog.csdn.net/lovelyaiq/article/details/78646401 还不错
保存过程
import tensorflow as tf
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")
#这里将b1改为placeholder,让用户输入,而不是写死
#b1= tf.Variable(2.0,name="bias")
b1= tf.placeholder(tf.float32, name='bias')
w3 = tf.add(w1,w2)
#记住要定义name,后面需要用到
out = tf.multiply(w3,b1,name="out")
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 这里需要填入输出tensor的名字
graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)
使用过程
import tensorflow as tf
with tf.Session() as sess:
with open('./checkpoint_dir/graph.pb', 'rb') as f:
graph_def = tf.GraphDef() #计算的图,和固化的时候的sess.graph_def是一个
graph_def.ParseFromString(f.read())
#下面"执行pb"中同样有类似做法
output = tf.import_graph_def(graph_def, input_map={'bias:0':4.}, return_elements=['out:0'])
print(sess.run(output))
你是否可以在之前图的结构上构建新的网络?当然,您可以通过graph.get_tensor_by_name()方法访问适当的操作,并在此基础上构建图。这是一个真实的例子。在这里,我们使用元图加载一个vgg预训练的网络,并在最后一层中将输出的数量更改为2,以便对新数据进行微调。
......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
# Now, you run this with fine-tuning data in sess.run()
def snapshot(self, sess, iter):
net = self.net
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
# Store the model snapshot
filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
filename = os.path.join(self.output_dir, filename)
self.saver.save(sess, filename)
print('Wrote snapshot to: {:s}'.format(filename))
# Also store some meta information, random state, etc.
nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
nfilename = os.path.join(self.output_dir, nfilename)
# current state of numpy random
st0 = np.random.get_state()
# current position in the database
cur = self.data_layer._cur
# current shuffled indexes of the database
perm = self.data_layer._perm
# current position in the validation database
cur_val = self.data_layer_val._cur
# current shuffled indexes of the validation database
perm_val = self.data_layer_val._perm
# Dump the meta info
with open(nfilename, 'wb') as fid:
pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)
return filename, nfilename
def from_snapshot(self, sess, sfile, nfile):
print('Restoring model snapshots from {:s}'.format(sfile))
self.saver.restore(sess, sfile)
print('Restored.')
# Needs to restore the other hyper-parameters/states for training, (TODO xinlei) I have
# tried my best to find the random states so that it can be recovered exactly
# However the Tensorflow state is currently not available
with open(nfile, 'rb') as fid:
st0 = pickle.load(fid)
cur = pickle.load(fid)
perm = pickle.load(fid)
cur_val = pickle.load(fid)
perm_val = pickle.load(fid)
last_snapshot_iter = pickle.load(fid)
np.random.set_state(st0)
self.data_layer._cur = cur
self.data_layer._perm = perm
self.data_layer_val._cur = cur_val
self.data_layer_val._perm = perm_val
return last_snapshot_iter
#查找!!!ckpt文件在哪
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
1.用tensorboard
python tensorflow/python/tools/import_pb_to_tensorboard.py \
--model_dir resnetv1_50.pb --log_dir /tmp/tensorboard
2.或者这么用
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename ='PATH_TO_PB.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def)
LOGDIR='YOUR_LOG_LOCATION'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
.meta文件保存了当前图结构
.index文件保存了当前参数名
.data文件保存了当前参数值
meta文件保存图结构,weights等参数保存在data文件中。也就是说,图和参数数据时分开保存的。说的更直白一点,就是meta文件中没有weights等数据。但是,值得注意的是,meta文件会保存常量。我们只需将data文件中的参数转为meta文件中的常量即可!
1.加载图结构和参数,也就是加载ckpt.meta 和ckpt.data
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
saver.restore(sess,.ckpt)
for op in tf.get_default_graph().get_operations():
print(op.name,op.values())
print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0,feed_dict={'Placeholder:0': imgs}))
2.只加载数据
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('./model/')
saver.restore(sess,ckpt.model_checkpoint_path)
# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
# 二进制读取模型文件
with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
# 新建GraphDef文件,用于临时载入模型中的图
graph_def = tf.GraphDef()
# GraphDef加载模型中的图,即用这个来解析读入的二进制模型文件
graph_def.ParseFromString(f.read())
# 在空白图中加载GraphDef中的图,相当于把pb加载到了一个新定义的图中,所以需要新建一个tf.graphDef()
tf.import_graph_def(graph_def,name='')
# 在图中获取张量需要使用graph.get_tensor_by_name加张量名
# 这里的张量可以直接用于session的run方法求值了
# 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]
# tensorflow object_detection api中这样用过!!!!!!
ops=tf.get_default_graph().get_operations()
#op.outputs是每个操作节点的所有输出,op.outpupts.name是相应的名字
all_tensor_names={output.name for op in ops for output in op.outputs}
tensor_dict={}
for key in [ 'num_detections', 'detection_boxes', 'detection_scores','detection_classes', 'detection_masks'
]:
tensor_name=key+':0'#第一个输出变量的名字
if tensor_name in all_tensor_names:
tensor_dict[key]=tf.get_default_graph().get_tensor_by_name(tensor_name)
执行PB文件
import tensorflow as tf
import numpy as np
import PIL.Image as Image
from skimage import io, transform
def recognize(jpg_path, pb_file_path):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_file_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")#import到当前图
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
input_x = sess.graph.get_tensor_by_name("input:0")#当前图中获取输入输出tensor名字
print input_x
out_softmax = sess.graph.get_tensor_by_name("softmax:0")
print out_softmax
out_label = sess.graph.get_tensor_by_name("output:0")
print out_label
img = io.imread(jpg_path)
img = transform.resize(img, (224, 224, 3))
#执行
img_out_softmax = sess.run(out_softmax, feed_dict={input_x:np.reshape(img, [-1, 224, 224, 3])})
print "img_out_softmax:",img_out_softmax
prediction_labels = np.argmax(img_out_softmax, axis=1)
print "label:",prediction_labels
recognize("vgg16/picture/dog/dog3.jpg", "vgg16/vggs.pb")
或者简单些:
sess = tf.Session()
#将保存的模型文件解析为GraphDef
model_f = gfile.FastGFile("model.pb",'rb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_f.read())
#在读取模型文件获取变量的值的时候,我们需要指定的是张量的名称而不是节点的名称
c = tf.import_graph_def(graph_def,return_elements=["add:0"])
print(sess.run(c))
#[array([ 11.], dtype=float32)]
4.保存为PB格式模型
定义运算过程
通过 get_default_graph().as_graph_def() 得到当前图的计算节点信息
通过 graph_util.convert_variables_to_constants 将相关节点的values固定
通过 tf.gfile.GFile 进行模型持久化
# coding=UTF-8
import tensorflow as tf
import shutil
import os.path
from tensorflow.python.framework import graph_util
output_graph = "model/pb/add_model.pb"
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
# predictions = tf.greater(_y, 50, name="predictions") #比50大返回true,否则返回false
predictions = tf.add(_y, 10,name="predictions") #做一个加法运算
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]})
graph_def = tf.get_default_graph().as_graph_def() #得到当前的图的 GraphDef 部分,通过这个部分就可以完成重输入层到输出层的计算过程,不是新建一个tf.graphDef(),而是用当前图的。
#convert_variables_to_constants函数,会将计算图中的变量取值以常量的形式保存。在保存模型文件的时候,我们只是导出了GraphDef部分,GraphDef保存了从输入层到输出层的计算过程。在保存的时候,通过convert_variables_to_constants函数来指定保存的节点名称而不是张量的名称,“add:0”是张量的名称而"add"表示的是节点的名称。
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess,
graph_def,
["predictions"] #指定的pb文件的输出
)
with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型
f.write(output_graph_def.SerializeToString()) # 序列化输出
#或者用:
tf.train.write_graph(graph, '.', './checkpoint_dir/graph.pb', as_text=False)
print("%d ops in the final graph." % len(output_graph_def.node))
print (predictions)
# for op in tf.get_default_graph().get_operations(): 打印模型节点信息
# print (op.name)
4 ckpt转换成PB格式
通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化
# coding=UTF-8
import tensorflow as tf
import os.path
import argparse
from tensorflow.python.framework import graph_util
MODEL_DIR = "model/pb"
MODEL_NAME = "frozen_model.pb"
if not tf.gfile.Exists(MODEL_DIR): #创建目录
tf.gfile.MakeDirs(MODEL_DIR)
def freeze_graph(model_folder):
checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
output_graph = os.path.join(MODEL_DIR, MODEL_NAME) #PB模型保存路径
output_node_names = "predictions" #原模型输出操作节点的名字
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #得到图、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import.
graph = tf.get_default_graph() #获得默认的图
input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前的图
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #恢复图并得到数据
print "predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}) # 测试读出来的模型是否正确,注意这里传入的是输出 和输入 节点的 tensor的名字,不是操作节点的名字
output_graph_def = graph_util.convert_variables_to_constants( #模型持久化,将变量值固定
sess,
input_graph_def,
output_node_names.split(",") #如果有多个输出节点,以逗号隔开
)
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
for op in graph.get_operations():
print(op.name, op.values())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("model_folder", type=str, help="input ckpt model dir") #命令行解析,help是提示符,type是输入的类型,
# 这里运行程序时需要带上模型ckpt的路径,不然会报 error: too few arguments
aggs = parser.parse_args()
freeze_graph(aggs.model_folder)
# freeze_graph("model/ckpt") #模型目录