免责声明:本文仅代表个人观点,如有错误,请读者自己鉴别;如果本文不小心含有别人的原创内容,请联系我删除;本人心血制作,若转载请注明出处
1、保存模型为.ckpt文件
saver = tf.train.Saver()
save_path = saver.save(sess, model_path)# 保存模型 其中model_path为模型保存的文件
保存后的模型有四个文件 checkpoint、SFCN.ckpt.data-00000-of-00001、SFCN.ckpt.index、SFCN.ckpt.meta
2、保存event, event为事件的保存路径
summary_writer = tf.summary.FileWriter(event, graph=sess.graph)
在命令窗口中输入 tensorboard --logdir==事件路径
在浏览器中可以可视化tensorboard,可以可视化tensorboard,查看图中有哪些节点
3、将ckpt文件转化为.pb文件,注意output_node_names = "Placeholder,Placeholder_2,keep_probabilty,conv2d_transpose_3" 这一行要写入所有点节点,注意节点之间不可以加空格
import tensorflow as tf from tensorflow.python.framework import graph_util def freeze_graph(input_checkpoint, output_graph): ''' :param input_checkpoint: :param output_graph: PB模型保存路径 :return: ''' # 指定输出的节点名称,该节点名称必须是原模型中存在的节点 # 直接用最后输出的节点,可以在tensorboard中查找到,tensorboard只能在linux中使用 output_node_names = "Placeholder,Placeholder_2,keep_probabilty,conv2d_transpose_3" saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) graph = tf.get_default_graph() # 获得默认的图 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图 with tf.Session() as sess: saver.restore(sess, input_checkpoint) # 恢复图并得到数据 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定 sess=sess, input_graph_def=input_graph_def, # 等于:sess.graph_def output_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开 with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型 f.write(output_graph_def.SerializeToString()) # 序列化输出 print("%d ops in the final graph." % len(output_graph_def.node)) # 得到当前图有几个操作节点 input_checkpoint = "./checkpoint/SFCN.ckpt" # 输入的ckpt文件位置 output_graph = "node.pb" # 输出节点的文件名 freeze_graph(input_checkpoint, output_graph)
4、调用.pb文件,注意y_in = sess.graph.get_tensor_by_name("Placeholder:0") images_in = sess.graph.get_tensor_by_name("Placeholder_2:0") keep_probability_in = sess.graph.get_tensor_by_name("keep_probabilty:0") logits_out = sess.graph.get_tensor_by_name("conv2d_transpose_3:0")
这四个节点要与第3步中的节点保持一致,但是要在后面加入“:0”,如第3步中节点为“Placeholder”,第4步要写为“Placeholder:0”
global graph graph = tf.get_default_graph() with graph.as_default(): output_graph_def = tf.GraphDef() with open(model_path, "rb") as f: output_graph_def.ParseFromString(f.read()) tf.import_graph_def(output_graph_def, name="") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) y_in = sess.graph.get_tensor_by_name("Placeholder:0") images_in = sess.graph.get_tensor_by_name("Placeholder_2:0") keep_probability_in = sess.graph.get_tensor_by_name("keep_probabilty:0") logits_out = sess.graph.get_tensor_by_name("conv2d_transpose_3:0") batch_size = 1 testTime = 0 predictLabel = tf.zeros(trainLabel.shape) predictLabel = sess.run(predictLabel) for i in range(0, smallImage): realbatch_array, real_labels, real_index = getNext_batch(trainData, trainLabel, trainIndex, i) testStart = time.time() yy = sess.run(logits_out, feed_dict={images_in: realbatch_array, y_in: real_labels, keep_probability_in: 1.0}) predictLabel[i, ...] = yy testEnd = time.time() testTime1 = testEnd - testStart testTime += testTime1