TensorFlow1 -1保存模型---tf.train.Saver,ckpt模型和PB 模型

Tensorflow 1 加载预训练模型和保存模型 (csdn非官方,补充)

文章目录

  • 1 Tensorflow模型文件结构
    • meta文件:计算图的结构,没有变量的值
    • ckpt文件
      • data包含所有变量的值,没有结构,index:内部需要的某种索引来正确映射前两个文件(meta和data),它通常不是必需的
    • checkpoint文件
  • 2 保存Tensorflow模型:session中保存,传入sess
    • 实例化 tf.train.Saver(),调用save()方法
    • 例子:
    • saver.restore函数
    • 存模型,不存图
      • 不存图
      • 保存最近的几个max_to_keep,每几小时保存一次keep_checkpoint_every_n_hours
      • 传入list/dict指定变量 ,以保存部分变量
  • 3 导入训练好的模型
    • 3.1 构造网络图tf.train.import_meta_graph()
    • 3.2 加载参数
      • tf.train.get_checkpoint_state()
      • tf.train.latest_checkpoint('./checkpoint_dir')
  • 4 使用恢复的模型
    • 4.1 使用已训练的模型,获取中间结果
    • 先创建一个模型
    • graph.get_tensor_by_name()
    • 执行步骤
    • 4.2已训练好模型.加入op和layers,训练新模型
    • 4.3恢复图的一部分,加入op,用于fine-tuning
  • 补充 with tf.Graph().as_default()
  • 测试
  • 5 查看模型的所有层的输入输出的tensor name
    • 典型代码,获取模型结构和图的函数
  • 6 TensorFlow二进制模型加载方法
  • .pb文件:不能够进行训练,仅能向前传播,而且我们在保存时需要指定节点名称
  • 从图上读取张量
    • 从当前图中获取对应张量:
    • 从图中获取节点信息

整篇文章主要是这里的内容,重新排版,并且也补充了一些东西.
参考1
参考2

1 Tensorflow模型文件结构

在checkpoint_dir目录下保存的文件结构如

|--checkpoint_dir
|    |--checkpoint
|    |--MyModel.meta
|    |--MyModel.data-00000-of-00001
|    |--MyModel.index

meta文件:计算图的结构,没有变量的值

MyModel.meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等。

ckpt文件

ckpt文件是二进制文件,保存了所有的weights、biases、gradients等变量。在tensorflow 0.11之前,保存在**.ckpt**文件中。0.11后,通过两个文件保存,如

data包含所有变量的值,没有结构,index:内部需要的某种索引来正确映射前两个文件(meta和data),它通常不是必需的

.data文件保存了当前参数名和值
MyModel.data-00000-of-00001

.index文件保存了辅助索引信息,正确映射前两个文件(meta和data)
MyModel.index

.data文件可以查询到参数名和参数值,使用下面的命令可以查询保存在文件中的全部变量{名:值}对,

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(os.path.join(savedir,savefile),None,True)

checkpoint文件

我们还可以看,checkpoint_dir目录下还有checkpoint文件,该文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model

2 保存Tensorflow模型:session中保存,传入sess

tensorflow 提供了tf.train.Saver类来保存模型,值得注意的是,在tensorflow中,变量是存在于Session环境中,也就是说,只有在Session环境下才会存有变量值,因此,保存模型时需要传入session:

实例化 tf.train.Saver(),调用save()方法

saver = tf.train.Saver()
saver.save(sess,"./checkpoint_dir/MyModel")

例子:

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './checkpoint_dir/MyModel')

执行后,在checkpoint_dir目录下创建模型文件如下:

checkpoint
MyModel.data-00000-of-00001
MyModel.index
MyModel.meta

saver.restore函数

给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载

saver.restore(sess,'./model/model.ckpt-0')
saver.restore(sess,ckpt.model_checkpoint_path)

如果想要在1000次迭代后,再保存模型,只需设置global_step参数即可:

saver.save(sess, './checkpoint_dir/MyModel',global_step=1000)

保存的模型文件名称会在后面加-1000,如下:

checkpoint
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-1000.meta

存模型,不存图

不存图

在实际训练中,我们可能会在每1000次迭代中保存一次模型数据,但是由于图是不变的,没必要每次都去保存,可以通过如下方式指定不保存图:

saver.save(sess, './checkpoint_dir/MyModel',global_step=step,write_meta_graph=False)

保存最近的几个max_to_keep,每几小时保存一次keep_checkpoint_every_n_hours

另一种比较实用的是,如果你希望每2小时保存一次模型,并且只保存最近的5个模型文件:

tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

注意:tensorflow默认只会保存最近的5个模型文件,如果你希望保存更多,可以通过max_to_keep来指定

传入list/dict指定变量 ,以保存部分变量

如果我们不对tf.train.Saver指定任何参数,默认会保存所有变量。如果你不想保存所有变量,而只保存一部分变量,可以通过指定variables/collections。在创建tf.train.Saver实例时,通过将需要保存的变量构造list或者dictionary,传入到Saver中:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './checkpoint_dir/MyModel',global_step=1000)# 一千次迭代后保存

3 导入训练好的模型

tensorflow将图和变量数据分开保存为不同的文件。因此,在导入模型时,也要分为2步:构造网络图和加载参数

3.1 构造网络图tf.train.import_meta_graph()

一个比较笨的方法是,手敲代码,实现跟模型一模一样的图结构。其实,我们既然已经保存了图,那就没必要在去手写一次图结构代码。

直接加载图

saver=tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')

3.2 加载参数

仅仅有图并没有用,更重要的是,我们需要前面训练好的模型参数(即weights、biases等),本文第2节提到过,变量值需要依赖于Session,因此在加载参数时,先要构造好Session:

第一步,在session中加载图和最近一次保存的模型

tf.train.get_checkpoint_state()

参考1
参考2

ckpt = tf.train.get_checkpoint_state('./model/')    # 获得保存节点文件的状态
ckpt含有两个键值对,类似字典型,第一个键值对为:
model_checkpoint_path:./model/mnist_model-49001”
第二个键值对为:
all_model_checkpoint_paths:[./model/mnist_model-45001,./model/mnist_model-46001,./model/mnist_model-47001,./model/mnist_model-48001,./model/mnist_model-49001]
第二个键值对的value是一个列表,可以使用ckpt.all_model_checkpoint_paths[index]访问。

例子
在上面代码中,通过tf.train.get_checkpoint_state函数得到的相关模型文件名如下:
TensorFlow1 -1保存模型---tf.train.Saver,ckpt模型和PB 模型_第1张图片

对所有模型进行测试,得到:

TensorFlow1 -1保存模型---tf.train.Saver,ckpt模型和PB 模型_第2张图片

with tf.Session() as sess:            
            ckpt=tf.train.get_checkpoint_state('Model/')
            print(ckpt)
            if ckpt and ckpt.all_model_checkpoint_paths:
                #加载模型
                #这一部分是有多个模型文件时,对所有模型进行测试验证
                for path in ckpt.all_model_checkpoint_paths:
                    saver.restore(sess,path)                
                    global_step=path.split('/')[-1].split('-')[-1]
                    accuracy_score=sess.run(accuracy,feed_dict=validate_feed)
                    print("After %s training step(s),valisation accuracy = %g"%(global_step,accuracy_score))
                '''
                #对最新的模型进行测试验证
                saver.restore(sess,ckpt.model_checkpoint_paths)                
                global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                accuracy_score=sess.run(accuracy,feed_dict=validate_feed)
                print("After %s training step(s),valisation accuracy = %g"%(global_step,accuracy_score))
                '''
            else:
                print('No checkpoint file found')
                return
        #time.sleep(eval_interval_secs)
        return

# 连同图结构一同加载
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
    saver.restore(sess,ckpt.model_checkpoint_path)

# 只加载数据,不加载图结构,可以在新图中改变batch_size等的值
# 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')
    saver.restore(sess,ckpt.model_checkpoint_path)

tf.train.latest_checkpoint(’./checkpoint_dir’)

加载最新的模型

import tensorflow as tf
with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
  new_saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))

此时,W1和W2加载进了图,并且可以被访问:

import tensorflow as tf
with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./checkpoint_dir'))
    print(sess.run('w1:0'))
##Model has been restored. Above statement will print the saved value

执行后,打印如下:

[ 0.51480412 -0.56989086]

4 使用恢复的模型

4.1 使用已训练的模型,获取中间结果

前面我们理解了如何保存和恢复模型,很多时候,我们希望使用一些已经训练好的模型,如prediction、fine-tuning以及进一步训练等。这时候,我们可能需要获取训练好的模型中的一些中间结果值,可以通过==graph.get_tensor_by_name(‘w1:0’)==来获取,注意w1:0是tensor的name。

假设我们有一个简单的网络模型,代码如下:

先创建一个模型

import tensorflow as tf


w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias") 

#定义一个op,用于后面恢复
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#创建一个Saver对象,用于保存所有变量
saver = tf.train.Saver()

#通过传入数据,执行op
print(sess.run(w4,feed_dict ={w1:4,w2:8}))
#打印 24.0 ==>(w1+w2)*b1

#现在保存模型
saver.save(sess, './checkpoint_dir/MyModel',global_step=1000)

graph.get_tensor_by_name()

执行步骤

  1. 用get_tensor_by_name()访问placeholder变量
  2. 定义feed_dict
  3. get_tensor_by_name()访问你要执行的op
  4. sess.run(op,feed_dict)

接下来我们使用graph.get_tensor_by_name()方法来操纵这个保存的模型。

import tensorflow as tf

sess=tf.Session()
#先加载图和参数变量
saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))


# 访问placeholders变量,并且创建feed-dict来作为placeholders的新值
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#接下来,访问你想要执行的op
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore,feed_dict))
#打印结果为60.0==>(13+17)*2

注意: 保存模型时,只会保存变量Variable的值,placeholder里面的值不会被保存

4.2已训练好模型.加入op和layers,训练新模型

如果你不仅仅是用训练好的模型,还要加入一些op,或者说加入一些layers并训练新的模型,可以通过一个简单例子来看如何操作:

import tensorflow as tf

sess = tf.Session()
# 先加载图和变量
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# 访问placeholders变量,并且创建feed-dict来作为placeholders的新值
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

#接下来,访问你想要执行的op
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

# 在当前图中能够加入op
add_on_op = tf.multiply(op_to_restore, 2)

print (sess.run(add_on_op, feed_dict))
# 打印120.0==>(13+17)*2*2


4.3恢复图的一部分,加入op,用于fine-tuning

如果只想恢复图的一部分,并且再加入其它的op用于fine-tuning。只需通过graph.get_tensor_by_name()方法获取需要的op,并且在此基础上建立图,看一个简单例子,假设我们需要在训练好的VGG网络使用图,并且修改最后一层,将输出改为2,用于fine-tuning新数据:

......
......
saver = tf.train.import_meta_graph('vgg.meta')
# 访问图
graph = tf.get_default_graph() 
 
#访问用于fine-tuning的output
fc7= graph.get_tensor_by_name('fc7:0')
 
#如果你想修改最后一层梯度,需要如下
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()

num_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()

get_tensor_by_name()
冒号后面的数字编号表示这个张量是计算节点上的第几个结果

冒号后面加0 指的是 该张量的第几个输出分支。 我们在通过 graph.get_tensor_by_name来获取这个张量值(也就是它的输出)的时候,一定要加冒号和数字的,只不过大部分tensor只有一个output,所以我们看到的大部分都是 :0
  所以,我们在定义tensor的时候,应该也需要点名name 比如 a = tensor.Variable( tf. normal…[] , name = ‘a’)

补充 with tf.Graph().as_default()

tensorflow下的Graph中 tf.Operation是一个node节点,而tf.Tensor是一个edge边

以tf.constant(0.0)为例,调用tf.constant(0.0)创建一个单独的tf.Operation,生成值42.0,将其添加到默认图形,并返回一个表示常数的值的tf.Tensor

在没有特别说明的情况下,程序中定义的tf.Operation均是添加进入default_graph中。若使用

graph1=tf.Graph()	# 声明tf.Graph()的一个类实例,即获取一个graph
with graph1.as_default():
	#完成graph1这个计算图中的tf.Operation的定义,即将在这个with所调用的上下文管理器中定义的tf.Operation添加进入声明的tf.Graph实例graph1中
	pass
 

将graph设置为default_graph,所以可以认为,所有的tf.Operation均是添加到default_graph进行的,而这个default_graph是可以设置的。

再说说这个with graph1.as_default():
with会启动一个上下文管理器。所谓上下文管理器,就是在程序执行前将上文中当前所需要的资源准备好,并在结束时被系统回收。

with graph1.as_default()的含义:
是 在这个with启动的上下文管理器中将graph1设置为default_graph,在with启动的上下文管理器内部所定义的tf.Operation则会添加进入当前的default_graph中,也就是graph1中。在整个with代码块结束后,default_graph将会重新设置为之前的,属于全局的graph。

测试

# -*- coding:utf-8 -*-
import tensorflow as tf
 
v1 = tf.Variable([0.0], name='v1')
v2 = tf.constant([1.0], name='v2')
add = tf.add(v1,v2, name='add')
 
graph1 = tf.Graph()
 
with graph1.as_default():
    v4 = tf.Variable([3.0], name='v4')
 
with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    print sess1.run(add)
 
with tf.Session(graph=tf.get_default_graph()) as sess2:
    sess2.run(tf.global_variables_initializer())
    print sess2.run(v1)
 
g_op = graph1.get_operations()
print 'the operations in graph1>>> \n', g_op
print '\n'
d_op = tf.get_default_graph().get_operations()
print 'the operations in default graph>>> \n',d_op

输出结果如下:

[ 1.]
[ 0.]
the operations in graph1>>> 
[<tf.Operation 'v4/initial_value' type=Const>, <tf.Operation 'v4' type=VariableV2>, <tf.Operation 'v4/Assign' type=Assign>, <tf.Operation 'v4/read' type=Identity>]
 
the operations in default graph>>> 
[<tf.Operation 'v1/initial_value' type=Const>, <tf.Operation 'v1' type=VariableV2>, <tf.Operation 'v1/Assign' type=Assign>, <tf.Operation 'v1/read' type=Identity>, <tf.Operation 'v2' type=Const>, <tf.Operation 'add' type=Add>, <tf.Operation 'init' type=NoOp>, <tf.Operation 'init_1' type=NoOp>]
 

5 查看模型的所有层的输入输出的tensor name

典型代码,获取模型结构和图的函数

TensorFlow1 -1保存模型---tf.train.Saver,ckpt模型和PB 模型_第3张图片

import os
import re
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
model_exp = "20180402-114759"
 
def get_model_filenames(model_dir):
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if len(meta_files)==0:
        raise load_modelValueError('No meta file found in the model directory (%s)' % model_dir)
    elif len(meta_files)>1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]
    ckpt = tf.train.get_checkpoint_state(model_dir)     # 通过checkpoint文件找到模型文件名
    if ckpt and ckpt.model_checkpoint_path:
        # ckpt.model_checkpoint_path表示模型存储的位置,不需要提供模型的名字,它回去查看checkpoint文件
        ckpt_file = os.path.basename(ckpt.model_checkpoint_path)
        return meta_file, ckpt_file
 
    meta_files = [s for s in files if '.ckpt' in s]
    max_step = -1
    for f in files:
        step_str = re.match(r'(^model-[\w\- ]+.ckpt-(\d+))', f)
        if step_str is not None and len(step_str.groups())>=2:
            step = int(step_str.groups()[1])
            if step > max_step:
                max_step = step
                ckpt_file = step_str.groups()[0]
    return meta_file, ckpt_file
 
 
meta_file, ckpt_file = get_model_filenames(model_exp)
 
print('Metagraph file: %s' % meta_file)
print('Checkpoint file: %s' % ckpt_file)
reader = pywrap_tensorflow.NewCheckpointReader(os.path.join(model_exp, ckpt_file))
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # print(reader.get_tensor(key))
 
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file))
    saver.restore(tf.get_default_session(),
                  os.path.join(model_exp, ckpt_file))
    print(tf.get_default_graph().get_tensor_by_name("Logits/weights:0"))

后续文章参考

6 TensorFlow二进制模型加载方法

这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作

# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
    # 二进制读取模型文件
    with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
        # 新建GraphDef文件,用于临时载入模型中的图
        graph_def = tf.GraphDef()
        # GraphDef加载模型中的图
        graph_def.ParseFromString(f.read())
        # 在空白图中加载GraphDef中的图
        tf.import_graph_def(graph_def,name='')
        # 在图中获取张量需要使用graph.get_tensor_by_name加张量名
        # 这里的张量可以直接用于session的run方法求值了
        # 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
        self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
        self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name   in self.layer_operation_names]

这节是关于tensorflow的Freezing,字面意思是冷冻,可理解为整合合并;整合什么呢,就是将模型文件和权重文件整合合并为一个文件,主要用途是便于发布。

tensorflow在训练过程中,通常不会将权重数据保存的格式文件里(这里我理解是模型文件),反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,我们可以将tf的图和参数文件整合进一个后缀为pb的二进制文件中,由于整合过程回将变量转化为常量,所以我们在日后读取模型文件时不能够进行训练,仅能向前传播,而且我们在保存时需要指定节点名称。

.pb文件:不能够进行训练,仅能向前传播,而且我们在保存时需要指定节点名称

可以保存您的整个图表(元+数据)

转换后的graph_def对象转换为二进制数据(graph_def.SerializeToString())后,写入pb即可。

import tensorflow as tf
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2
 
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './tmodel/test_model.ckpt')
    gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), ['add'])
with tf.gfile.GFile('./tmodel/model.pb', 'wb') as f:
    f.write(gd.SerializeToString())

直接查看gd文件

node {
  name: "v1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
        }
        float_val: 1.0
      }
    }
  }
}
……
node {
  name: "add"
  op: "Add"
  input: "v1/read"
  input: "v2/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}

从图上读取张量

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'  # 瓶颈层输出张量名称
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'  # 输入层张量名称
MODEL_DIR = './inception_dec_2015'  # 模型存放文件夹
MODEL_FILE = 'tensorflow_inception_graph.pb'  # 模型名
 
 
# 加载模型
# with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),'rb') as f:   # 阅读器上下文
with open(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:  # 阅读器上下文
    graph_def = tf.GraphDef()  # 生成图
    graph_def.ParseFromString(f.read())  # 图加载模型
# 加载图上节点张量(按照句柄理解)
bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(  # 从图上读取张量,同时导入默认图
    graph_def,
    return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

从当前图中获取对应张量:

这个就是很普通的情况,从我们当前操作的图中获取某个张量,用于feed啦或者用于输出等操作,API也很简单,用法如下:

g.get_tensor_by_name('import/pool_3/_reshape:0')

g表示当前图句柄,可以简单的使用 g = tf.get_default_graph() 获取。

从图中获取节点信息

g = tf.get_default_graph()
print(g.as_graph_def().node)

这个操作将返回图的构造结构。从这里,对比前面的代码,我们也可以了解到:graph_def 实际就是图的结构信息存储形式,我们可以将之还原为图(二进制模型加载代码中展示了),也可以从图中将之提取出来(本部分代码)。

TensorFlow迁移学习_他山之石,可以攻玉

『cs231n』通过代码理解风格迁移

你可能感兴趣的:(tensorflow入门到删库,TensorFlow,保存模型,python,tensorflow,深度学习)