tensorflow存储和加载模型

存储模型

saver=tf.train.Saver()
with tf.Session() as sess:
    saver.save(sess, "./model/crack_capcha.model-1390")
保存模型文件结构图

保存模型后会在相应的目录下生成四个文件,文件的名字中crack_capcha.model-1390为保存模型是设置的文件名。

  1. checkpoint 文本文件,记录了模型文件的路径信息列表
  2. mcrack_capcha.model-1390.data-00000-of-00001网络权重信息
  3. crack_capcha.model-1390.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
  4. crack_capcha.model-1390.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)

加载模型

model_path = "./model"
saver = tf.train.import_meta_graph(model_path + '/crack_capcha.model-1390.meta')# 加载图结构
gragh = tf.get_default_graph()# 获取当前图,为了后续训练时恢复变量
# tensor_name_list = [tensor.name for tensor in gragh.as_graph_def().node]# 得到当前图中所有变量的名称
x = gragh.get_tensor_by_name('Placeholder:0')# 获取输入变量(占位符,由于保存时未定义名称,tf自动赋名称“Placeholder”)
keep_prob = gragh.get_tensor_by_name('Placeholder_2:0')# 获取dropout的保留参数

pred = gragh.get_tensor_by_name('Add_1:0')# 获取网络输出值
predict = tf.argmax(tf.reshape(pred, [-1, max_captcha, char_set_len]), 2)#对输出的值做进一步的处理,可替换成满足自身条件的任何处理

model_path = "./model"
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(model_path))# 加载变量值
    print('finish loading model!')   
    text = sess.run(predict, feed_dict = {x:[image], keep_prob:1})#执行run方法,自动的将返回的结果pred进行计算获取predict

加载模型遇到问题

在加载模型的过程中,执行gragh.get_tensor_by_name('Add_1:0')方法时,一直没有找准这个Add_1:0的名字,所以一直出现下面的一个错误。

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_37' with dtype float and shape [?,9600]
     [[node Placeholder_37 (defined at :27) ]]
     [[node ArgMax_59 (defined at :36) ]]

说一下怎么找到的这最后一个名字,因为我最后输出的时候使用的是ADD方法因此找带有ADD的名字,最后在save操作之前找到了这个最后的ADD名字。

你可能感兴趣的:(tensorflow存储和加载模型)