tensorflow:加载预训练网络及获取operation、tensor值

下文只是个人笔记,不保证正确和详尽。

tensorflow的计算图包括tensor、operation。保存一张完整的计算图意味着同时保存两者,只需要一句话。但加载网络有痛处,在于无法当你不是作者的时候,无法快速得知网络里到底有哪些tensor、operation。

先说重点

计算图里的tensor和operation几乎结伴出现(我没验证过)。定义一个variable,则它也配备一个operation;设计一个operation,则它也有对应的输出作为tensor。所有调用格式都是:

eg
tensor name:index w:0
operation name w

index怎么取的还没搞懂。
如果没有名字(这取决于一开始训练有没有取名字),那么只能用其他接口查阅有哪些名字,再调用 ,下文会提。

准备

这是我测试用的网络。注释掉的部分可以用来做对比。
tf早期的图、权重保存方式和现在的不一样,现在是分3个文件保存的:.data .meta .index,另外还有一个记录文件checkpoint。详情谷歌。

def first():
    W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')
    b = tf.Variable([[0,1,2]],dtype = tf.float32)
    c = tf.Variable([[0,1,2]],dtype = tf.float32)
    #d = tf.add(b,c)
    #e = tf.placeholder(tf.float32, [2, 2],name='holder')
    init = tf.initialize_all_variables()
    saver = tf.train.Saver()
    with tf.Session() as sess:
            sess.run(init)
            #sess.run(init,feed_dict ={e:[[1, 2], [3, 4]]})
            save_path = saver.save(sess,"_save/model.ckpt")

获取tensor——方法一

#-*- coding:utf-8 -*-
import tensorflow as tf

def second():
    #tf-gpu,我在evga1080上遇到了不支持的数据类型,要加一个config
    config = tf.ConfigProto(allow_soft_placement = True)
    sess = tf.Session(config = config)
    #导入图结构。
    saver = tf.train.import_meta_graph('_save/model.ckpt.meta')
    #恢复权重值。model.ckpt是我的3个文件的前缀。本来应该从checkpoint导入前缀,我的checkpoint出了点问题。
    saver.restore(sess,'_save/model.ckpt')
    var = tf.global_variables()#全部调用
    for i in var:
        print i

以上代码会打印计算图中的所有variable,不知道为什么是重复打印两次。

获取tensor——方法二

#-*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

def third():
    checkpoint_path = "_save/model.ckpt"
    reader = pywrap_tensorflow.NewCheckpointReader("_save/model.ckpt")
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key)
        print(reader.get_tensor(key))#按名字调用

以上代码打印的依然是variable,只打印一遍。

补充
  • 如果把first()函数里的add和placeholder也加入网络,两种打印结果都不变,因为add和placeholder都是operation。
  • 以类似格式书写,可以恢复operation:
    给add/placeholder加上名字:d = tf.add(b,c,name='add'),再在加载网络后按以下两种方式调用:
    d = graph.get_operation_by_name("add")
    print d
    op = graph.get_operations()
    for i in op:
        print i

会发现都可以打印出add。

  • 事实上,如果你把get系列函数名称中的operation替换成tensor,再把用到名字的地方按开头的表格修改格式,会得到对应另一种输出。也就是“定义一个variable,则它也配备一个operation;设计一个operation,则它也有对应的输出作为tensor”。原理和tensorflow的设计有关, 建议查看权威资料。

你可能感兴趣的:(tensorflow)