由于神经网络训练比较复杂所以可能需要先保存训练好的模型,然后再需要的时候进行调用,下面介绍两种保存模型的方法:
保存代码,该方法保存的模型比较全,只要定义的变量均可获取,导入的模型与当前生成几乎具有一样的能力:
#定义占位符,具有名称的变量可以被在导入模型后获取
x = tf.placeholder(tf.float32, [None, 784], name ='x')
y_ = tf.placeholder(tf.int64, [None], name='y_')
#有些变量名难以定义,可以通过下面的方法保存
tf.add_to_collection('pred_network', y_conv)
tf.add_to_collection('pred_network', keep_prob)
#保存模型,目录model,前缀mnist_model
saver = tf.train.Saver()
with tf.Session() as sess:
#训练部分,省略
saver.save(sess, './model/mist_model')
加载:
with tf.Session() as sess:
model = tf.train.import_meta_graph('./model/mist_model.meta')
model.restore(sess, './model/mist_model')
#加载变量,注意变量名必须是定义过的:
#这部分因为是将变量存入一个集合中,所以需要注意顺序
y_conv = tf.get_collection('pred_network')[0]
keep_prob = tf.get_collection('pred_network')[1]
graph = tf.get_default_graph()
x = graph.get_operation_by_name('x').outputs[0]
y_ = graph.get_operation_by_name('y_').outputs[0]
保存代码,该方法只能保存一种类型的网络,函数根据定义的输出节点来确定网络的类型是评估或者分类,也就是根据输出节点,往前倒推,有关联的变量才存储。另外,不同的图之间不能运算,也就是说,加载的图中变量不能用于新的计算。
注:如果当前网络为分类网络,那么,即使之前训练的网络中包含用于评估的标签变量,该分类网络也不能导入该标签变量。
冻结模型也有两种方法:
from tensorflow.python.framework import graph_util
#变量名定义
x = tf.placeholder(tf.float32, [None, 784], name ='x')
with tf.Session() as sess:
#训练部分,省略
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['classifier/predicition'])
with tf.gfile.FastGFile("./model/outModel1.pb", mode='wb') as f:
f.write(constant_graph.SerializeToString())
该方法应该主要以命令行方式将存储的一系列模型文件转换成pb格式模型
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import graph_io
# 分别保存图和变量
saver = tf.train.Saver()
checkpoint_path = saver.save(sess, "./model/mist_model")
# 以文本方式存储所有节点信息
graph_io.write_graph(sess.graph, "./model/", "model.pb")
# 定义冻结图方法的参数
input_graph_path = os.path.join("./model/", "model.pb")
input_saver_def_path = ""
input_binary = False
output_node_names = "classifier/predicition"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join("./model/", "outModel2.pb")
clear_devices = False
input_meta_graph = "./model/mist_model.meta"
# 冻结图
freeze_graph.freeze_graph(
input_graph_path,
input_saver_def_path,
input_binary,
checkpoint_path,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph_path,
clear_devices,
"",
"",
input_meta_graph,
checkpoint_version=1)
加载:
graph = tf.Graph()
graph_def = tf.GraphDef()
with open('./model/outModel1.pb', "rb") as f:
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def)
#变量加载,只能加载与本网络功能有关的变量。
#注意:变量名前必须包含import,否则,报错:“The name 'x' refers to an Operation not in the graph”
x = graph.get_operation_by_name('import/x').outputs[0]
keep_prob = graph.get_operation_by_name('import/dropout/keep_prob').outputs[0]
pred = graph.get_operation_by_name('import/classifier/predicition').outputs[0]
with tf.Session(graph=graph) as sess:
#进行计算