TensorFlow应用.pb文件保存和加载模型方法及相关注意事项

其他参考链接:https://blog.csdn.net/guyuealian/article/details/82218092

 

一、.ckpt转.pb用于模型上线

.ckpt转.pb主要应用于将训练模型发布上线,.pb模型的跨平台和跨框架性能更好。这里由于在保存.pb模型前需要将模型变量freezing。在应用tensorflow训练模型时,输入数据的batch_size>1,直接保存.pb模型时会在inference阶段出现问题,所以需要从.ckpt转为.pb。在加载.ckpt时可以重新定义输入数据的batch_size=1,以解决该问题。应用步骤主要分为:

1.加载.ckpt并且.ckpt转.pb:

(1)定义图模型,在inference阶段加载.ckpt文件:

sample_graph = tf.Graph()
with sample_graph.as_default():
    input_data = tf.placeholder(tf.float32, shape=input_shape, name='input_data_gt')  # input_shape的batch_size维度为1
    output = sample_net(input_data)
    net_saver = tf.train.Saver()
sess = tf.Session(graph = sample_graph)
net_saver.restore(sess, model_path)

首先定义输入数据变量placehold数据类型,这里定义输入数据变量时将batch_size=1,再加载图模型,最后restore .ckpt模型。

(2).ckpt转.pb:

tf.train.write_graph(sess.graph_def, pb_dir, pb_name)
freeze_graph.freeze_graph(pb_path, '', False, model_path, nodes_to_be_saved, save/restore_all, 'save/Const:0', pb_path, False, '')

首先,tf.train.write_graph将图结构保存.pb文件中,再调用tensorflow中包装好的接口(freeze_graph)保存模型的输出变量。其中,模型的输入变量会根据图结构自动回溯保存到.pb文件中。 注:这种方式主要应用.ckpt转.pb,由于.pb已经将变量freeze化,这里需要将input的batch_size定义为1。 2.应用.pb加载模型及参数做inference:

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
        gnet_output = sess.graph.get_tensor_by_name(node_to_be_loaded)
        out_put = sess.run(gnet_output, feed_dict={input_img: input_img})

通过get_tensor_by_name获取模型的输入变量和输出变量,输入变量在feed_dict中加载数据,输出变量用于接受模型结果。

二、.pb用于finetune

由于tensorflow没有提供类似于加载.ckpt文件的restore接口,在做.pb文件用于模型finetune时,需要将模型中的trainable variables全部保存下来,并且在加载.pb文件时需要根据变量名称将变量值一一赋值到模型中。 应用步骤分为保存trainable variables到.pb文件和从.pb文件加载trainable variables:

1.保存trainable variables到.pb文件:

var_list = tf.trainable_variables()
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, [var_list[i].name[:-2] for i in range(len(var_list))])
with tf.gfile.FastGFile(pb_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())

finetune需要加载可训练的参数即可,因此这里只要保存trainable variables即可。

2.从.pb文件加载trainable variables:

pb_para_dic = {} # 用于存储pb文件变量信息的字典结构 key-变量名,value-变量值
with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
        for item_var in var_list:
            pb_para_dic[item_var.name] = sess.run(sess.graph.get_tensor_by_name(item_var.name))
with tf.Session(graph=graph, config=tfconfig) as sess:
    var_list = tf.trainable_variables()
    for item_var in var_list:
         sess.run(tf.assign(item_var, pb_para_dic[item_var.name]))  # 将字典中的变量信息赋值图结构中

加载模型时需要重新定义一个新的图结构用于加载.pb文件中的权重,并且将权重keys和values放到一个字典中,然后在默认图结构中根据keys给session中的变量赋值。 注:在加载.pb模型时需要重新定义一个临时的graph域空间和临时的session,避免和图模型的定义空间冲突。

三、合并训练和测试模型到.pb文件

在合并训练和测试模型时,只需要将模型的trainable variables保存到.pb文件中,然后根据训练/测试代码建立的模型一一加载.pb文件中的trainable variables恢复模型。其中,应用方法与finetune的步骤基本一致。

应用步骤分为保存trainable variables到.pb文件,以及在训练和测试阶段,从.pb文件中一一加载trainable variables:

1.保存trainable variables到.pb文件:

var_list = tf.trainable_variables()
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, [var_list[i].name[:-2] for i in range(len(var_list))])
with tf.gfile.FastGFile(pb_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())

finetune需要加载可训练的参数即可,因此这里只要保存trainable variables即可。

2.训练/测试阶段从.pb文件加载trainable variables:

pb_para_dic = {}
with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
        for item_var in var_list:
            pb_para_dic[item_var.name] = sess.run(sess.graph.get_tensor_by_name(item_var.name))
with tf.Session(graph=graph, config=tfconfig) as sess:
    var_list = tf.trainable_variables()
    for item_var in var_list:
         sess.run(tf.assign(item_var, pb_para_dic[item_var.name]))

加载模型时需要重新定义一个新的图结构用于加载.pb文件中的权重,并且将权重keys和values放到一个字典中,然后在默认图结构中根据keys给session中的变量赋值。

注:在加载.pb模型时需要重新定义一个临时的graph域空间和临时的session,避免和图模型的定义空间冲突。

你可能感兴趣的:(tensorflow,tensorflow)