tensorflow将ckpt模型转为pb模型

获取原网络中的所有节点

在训练代码中定义好图之后加入以下代码:

for node in tf.get_default_graph().as_graph_def().node:

    print(node.name)

主要是要查看最后一个节点的名字

模型转化

不再重新建图时, 使用tf.train.import_meta_graph

def freeze_graph(input_checkpoint,output_graph):

    '''

    :param input_checkpoint:ckpt模型路径

    :param output_graph: pb模型保存路径

    '''

    output_node_names = " " # 填入第一步得到的最后一个节点名

    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:

        saver.restore(sess, input_checkpoint) #恢复图并得到数据

        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定

            sess=sess,

            input_graph_def=sess.graph_def,# 等于:sess.graph_def

            output_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型

            f.write(output_graph_def.SerializeToString()) #序列化输出

        print("%d ops in the final graph." % len(output_graph_def.node)) # 统计图中总的操作节点数

或者修改前传代码,使用tf.train.Saver()

在前传代码里,restore模型

restorer = tf.train.Saver(tf.global_variables())

ckpt = tf.train.get_checkpoint_state(' ') # 填入ckpt模型所在文件夹路径

model_path = ckpt.model_checkpoint_path # 读取checkpoint文件里的第一行

with tf.Session() as sess:

    # Create a saver.

    sess.run(tf.local_variables_initializer())

    sess.run(tf.global_variables_initializer())

    try:

        restorer.restore(sess, model_path)

        print(model_path.split('/')[-1] + " restored!")

    except IOError:

        print("checkpoints not found.")

    output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定

        sess=sess,

        input_graph_def=sess.graph_def,  # 等于:sess.graph_def

        output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

    with tf.gfile.GFile(out_pb_path, "wb") as f:  # 保存模型

        f.write(output_graph_def.SerializeToString())  # 序列化输出

    print("%d ops in the final graph." % len(output_graph_def.node))

  # 统计图中总的操作节点数

从pb模型中读取节点

#coding:utf-8

import tensorflow as tf

from tensorflow.python.framework import graph_util

tf.reset_default_graph()  # 重置计算图

output_graph_path = 'model/model_tfnew.pb'

with tf.Session() as sess:

    tf.global_variables_initializer().run()

    output_graph_def = tf.GraphDef()

    # 获得默认的图

    graph = tf.get_default_graph()

    with open(output_graph_path, "rb") as f:

        output_graph_def.ParseFromString(f.read())

        _ = tf.import_graph_def(output_graph_def, name="")

        # 得到当前图有几个操作节点

        print("%d ops in the final graph." % len(output_graph_def.node))

        tensor_name = [tensor.name for tensor in output_graph_def.node]

        print(tensor_name)

        print('---------------------------')

        # 在log_graph文件夹下生产日志文件,可以在tensorboard中可视化模型

        summaryWriter = tf.summary.FileWriter('log_graph/', graph)

        for op in graph.get_operations():

            # print出tensor的name和值

            print(op.name, op.values())


参考:https://blog.csdn.net/u010397980/article/details/84889174

           https://blog.csdn.net/guyuealian/article/details/82218092

你可能感兴趣的:(tensorflow将ckpt模型转为pb模型)