Tensorflow模型持久化与恢复

 Tensorflow模型
简单点说,一个tensorflow模型包含了神经网络的结构(graph)和通过训练得到的一系列神经网络的参数。
神经网络的结构(graph)即神经网络的节点(nodes)及其流图(flow),节点是一系列张量,每个张量中运行一个op,流图是这些节点之间的运算关系,神经网络的参数包括训练得到的神经网络各层的权重(weights)和偏置(biases)以及程序中用到的其他变量。
要了解如何保存(持久化,frezee)一个tensorflow模型,首先要了解一下两个概念:

Protocol buffers 协议缓冲区

说到Tensorflow中model的保存,就不得不提Protocol buffers ,Protocol buffers 是tensorflow中保存数据的协议,所有TensorFlow的文件格式都是基于Protocol Buffers的,概括说来,Protocol buffers 是一种语言中立的,平台中立的,可扩展的串行化结构化数据的方式,Protocol buffers实现了这样一种协议,你可以在文本文件中定义数据结构,使用Protocol buffers编译器编译之后,可以把文本文件中定义的数据结构生成C,Python,Java和其他语言的类。Protocol buffers文本文件结构跟xml,json等文件结构类似,其优点是文件更小,序列化读取,保存速度更快,操作更简单,官方文档中说,实现同等功能Protocol buffers文件比xml文件小3~10倍,序列化读取,保存速度快20~100倍。
通过在.proto文件中Protocol buffers message类型来指定希望将序列化信息结构化的方式。每个Protocol buffers message是包含一系列名称 - 值对的信息的小逻辑记录。下面是一个.proto文件的一个非常基本的例子,它定义了一个包含有关人员信息的message:

syntax = "proto2";  
  
package tutorial;  
  
message Person {  
  required string name = 1;  
  required int32 id = 2;  
  optional string email = 3;  

  enum PhoneType {  
    MOBILE = 0;  
    HOME = 1;  
    WORK = 2;  
  }  
  
  message PhoneNumber {  
    required string number = 1;  
    optional PhoneType type = 2 [default = HOME];  
  }  
  
  repeated PhoneNumber phones = 4;  
}  
  
message AddressBook {  
  repeated Person people = 1;  
}

通过Protocol buffers编译器编译成C++,会生成类似下面的代码:

class Person  
{  
// name  
  inline bool has_name() const;  
  inline void clear_name();  
  inline const ::std::string& name() const;  
  inline void set_name(const ::std::string& value);  
  inline void set_name(const char* value);  
  inline ::std::string* mutable_name();  
  
  // id  
  inline bool has_id() const;  
  inline void clear_id();  
  inline int32_t id() const;  
  inline void set_id(int32_t value);  
  
  // email  
  inline bool has_email() const;  
  inline void clear_email();  
  inline const ::std::string& email() const;  
  inline void set_email(const ::std::string& value);  
  inline void set_email(const char* value);  
  inline ::std::string* mutable_email();  
  
  // phones  
  inline int phones_size() const;  
  inline void clear_phones();  
  inline const ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >& phones() const;  
  inline ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >* mutable_phones();  
  inline const ::tutorial::Person_PhoneNumber& phones(int index) const;  
  inline ::tutorial::Person_PhoneNumber* mutable_phones(int index);  
  inline ::tutorial::Person_PhoneNumber* add_phones();  
}

Protocol buffers 文件有txt和二进制两种保存格式,文件后缀分别为.pbtxt和.pb,以txt格式保存的文件是可读的。

MetaGraph(元图)

MetaGraph是一个Protocol buffers,tensorflow通过MetaGraph来记录计算图中的节点信息以及运行计算图中节点所需要的元数据,通俗点说就是,MetaGraph包含了这个神经网络的结构和设计这个神经网络所用到的所有变量。
MetaGraph包含一个GraphDef和所有与graph中计算相关的元数据,用于图形的长期存储。 MetaGraph包含继续训练,执行评估或在以前训练过的graph上运行推断所需的信息。
MetaGraph中包含的信息用MetaGraphDef协议缓冲区来表示。简而言之,一个MetaGraph包含了神经网络的结构--graph,也包含了运行这个graph相关的参数(权重,偏置和其他变量),它包含以下字段:
--MetaInfoDef 用于元信息,如版本和其他用户信息。
--GraphDef 用于描述graph。
--SaverDef Saver定义相关信息。
--CollectionDef 进一步描述了模型其他组件的map,如Variables,tf.train.QueueRunner等。为了使Python对象能够从MetaGraphDef序列化,Python类必须实现to_proto()和from_proto()方法 ,并使用register_proto_function将其注册到系统。


因此,要保存一个tensorflow模型,需要保存一下两方面的信息:

①神经网络的结构---MetaGraph
MetaGraph以.meta格式的文件保存,它保存了神经网络的结构(节点及其流图)和神经网络中用到的所有参数,我们称之为神经网络变量(variables),注意:这里所说的变量是变量名,而没有变量值
②神经网络参数值---checkpoint文件
如上所述,这些参数包括训练得到的神经网络各层的权重(weights)和偏置(biases)以及程序中用到的其他变量的值。
tensorflow用checkpoint文件来保存神经网络的所有变量值,文件后缀名.cpkt。0.11之前的版本,通常只有一个.cpkt文件,新版tensorflow使用了两个文件:
model.cpkt.data
model.cpkt.index
其中model.cpkt.data文件中保存了神经网络的所有变量值。

那么,tensorflow中怎么保存得到上述两种文件呢?
最简单的方法,是使用tensorflow的tf.train.Saver类。使用这个类,默认情况下,它会保存tensorflow程序的所有信息。

tf.train.Saver

1.定义:tensorflow/python/training/saver.py
2.功能:保存tensorflow模型graph结构,将变量保存到checkpoint,或者从checkpoint恢复变量,
3.checkpoint 检查点
checkpoint 是一系列以特定格式保存将变量名映射到张量值的二进制文件。
Saver可以使用提供的计数器自动对检查点文件名进行编号。 这可以让您在训练模型时,在不同的训练次数中保留多个检查点。 例如,用训练次数值给检查点文件名编号。 为避免磁盘溢出,savers自动管理检查点文件。 例如,他们只能保留N个最近的文件,或者每N个训练时间保留一个检查点。
4.属性与方法
tf.train.Saver.last_checkpoints  返回排序后的检查点文件列表
构造函数:

__init__(
    var_list=None,#要保存或恢复的变量字典或列表,如果为空,所有变量都会保存和恢复
    reshape=False,
    sharded=False,
    max_to_keep=5,#checkpoint最大保存数量
    keep_checkpoint_every_n_hours=10000.0, #每隔多长时间保存一次checkpoint 
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)
下面是构造一个saver对象的代码示例:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')


# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})


# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})


导出保存MetaGraphDef到指定文件路径:
tf.train.Saver.export_meta_graph(
    filename=None,
    collection_list=None,
    as_text=False,
    export_scope=None,
    clear_devices=False,
    clear_extraneous_savers=False
)

tf.train.Saver.recover_last_checkpoints(checkpoint_paths):训练过程中运行崩溃时,恢复last_checkpoints属性状态

tf.train.Saver.restore(sess,save_path):此方法运行构造函数添加的用于恢复变量的ops。 它需要一个会话,其中启动了待恢复变量关联的graph。 要恢复的变量不必被初始化,因为恢复本身就是初始化变量的一种方式。

tf.train.Saver.save(
    sess,
    save_path,
    global_step=None,#用来创建checkpoint相关文件的名称,一般传入当前训练次数
    latest_filename=None,#最新checkpoint文件名,默认checkpoint
    meta_graph_suffix='meta',#MetaGraphDef文件名后缀,默认.meta
    write_meta_graph=True,#是否保存MetaGraph为.cpkt.meta文件
    write_state=True
)

这个函数是tf.train.Saver类用于保存tensorflow模型的一个最主要的函数。
使用时,先实例化一个tf.train.Saver对象,然后在一个会话(Session)中调用的这个tf.train.Saver对象的Save()方法。
下面是一个示例:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

这个函数默认会保存5个文件:
①checkpoint  保存了目录下所有的cpkt文件列表和最新的cpkt文件,即所有的检查点和最新的检查点
②graph.pbtxt  Protocol buffers的txt格式,里面存储神经网络所有的节点和参数信息。
tf.train.Saver.save()保存的.pbtxt文件存储神经网络所有的节点和参数信息,下面是这个文件中一个段落示例:
node {  
  name: "global_step/Initializer/zeros"  
  op: "Const"  
  attr {  
    key: "_class"  
    value {  
      list {  
        s: "loc:@global_step"  
      }  
    }  
  }  
  attr {  
    key: "_output_shapes"  
    value {  
      list {  
        shape {  
        }  
      }  
    }  
  }  
  attr {  
    key: "dtype"  
    value {  
      type: DT_INT64  
    }  
  }  
  attr {  
    key: "value"  
    value {  
      tensor {  
        dtype: DT_INT64  
        tensor_shape {  
        }  
        int64_val: 0  
      }  
    }  
  }  
}


可以看出,这个段落定义了神经网络中的一个节点,一个神经网络节点包含了以下五个部分的内容:
--name  节点名称,是节点的唯一标识。
--op   节点上所运行的操作,例如"Add", "MatMul","Conv2D"等。
--input  节点的输入,是一个字符串列表,每个字符串是一个其他节点的名称,后面带一个冒号和输出端口号。如:["some_node_name:0", "another_node_name:0"]
--device  节点op所运行的设备,GPU或CPU等。
--attr   节点属性,是一个字典,这些是节点的永久属性,在运行时不改变的事物,例如卷积的过滤器的大小,或常量运算的值。因为可能有许多不同类型的属性值,从字符串到int,到张量值的数组,都有一个单独的protobuf文件来定义保存它们的数据结构。
③model.ckpt.meta 可选参数,默认True,保存了graph的结构
④model.cpkt.index  作用未知
⑤model.cpkt.data  这个文件保存了TensorFlow程序中每一个权重和变量的取值

要恢复一个模型,主要使用tf.train.import_meta_graph()函数载入模型的MetaGraph,这个函数同时返回一个tf.train.Saver对象,调用这个对象的restore()方法,可以恢复模型的参数。
下面是一个示例:


import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))

# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved

# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 

上述代码还展示了如何在调用saver.restore()方法后,通过tf.get_default_graph()得到graph对象,使用这个对象的get_tensor_by_name()方法可以得到恢复的任意变量。
这有时候非常有用,比如有时候我们只想从模型中恢复部分网络结构,而重新设计其他结构或者对其他结构进行微调,下面的例子展示了从保存的模型中载入一个vgg模型,并改变输出层的结构的示例:

......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning 
 
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
 
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
 
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

更多的使用场合是迁移学习,利用已经训练好的模型,只导入模型的卷积层特征提取部分,而重新训练全连接层和输出层。这能非常方便的将别人训练好的模型用于测试自己的数据。


5.SaverDef 
tensorflow中的很多类,都有与之对应的*Def类,例如,tf.train.Saver类,与之对应有一个tf.train.SaverDef类,还有tf.MetaGraph,tf.Graph对应的tf.MetaGraphDef,tf.GraphDef等,这些类的定义都是Protocol buffers,即文件名后缀是.proto,官方文档解释是Def类是相应类的Protocol buffers配置类,它是python类与Protocol buffers转换的桥梁。tf.train.Saver类还提供了一个方法tf.train.Saver.as_saver_def()来从一个Saver对象生成SaverDef。

也可以通过调用tf.train.Saver.export_meta_graph()和tf.train.import_meta_graph()单独导出和保存MetaGraph:

①使用tf.train.Saver.export_meta_graph保存model为MetaGraph:

# Build the model
...
with tf.Session() as sess:
  # Use the model
  ...
# Export the model to /tmp/my-model.meta.
meta_graph_def = tf.train.export_meta_graph(filename='/tmp/my-model.meta')


注意:调用tf.train.Saver.save()时,也会自动保存MetaGraph为.cpkt.meta文件


②使用tf.train.Saver.import_meta_graph载入model的MetaGraph:

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # tf.get_collection() returns a list. In this example we only want the
  # first one.
  train_op = tf.get_collection('train_op')[0]
  for step in xrange(1000000):
    sess.run(train_op)

tf.train.Saver.export_meta_graph不仅可以保存graph相关的所有信息,还可以将你自定义的任意graph保存为.cpkt.meta,然后通过import_meta_graph载入并使用。
下面的例子定义并保存了一个前向传播graph:


# Creates an inference graph.
# Hidden 1
images = tf.constant(1.2, tf.float32, shape=[100, 28])
with tf.name_scope("hidden1"):
  weights = tf.Variable(
      tf.truncated_normal([28, 128],
                          stddev=1.0 / math.sqrt(float(28))),
      name="weights")
  biases = tf.Variable(tf.zeros([128]),
                       name="biases")
  hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
# Hidden 2
with tf.name_scope("hidden2"):
  weights = tf.Variable(
      tf.truncated_normal([128, 32],
                          stddev=1.0 / math.sqrt(float(128))),
      name="weights")
  biases = tf.Variable(tf.zeros([32]),
                       name="biases")
  hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
# Linear
with tf.name_scope("softmax_linear"):
  weights = tf.Variable(
      tf.truncated_normal([32, 10],
                          stddev=1.0 / math.sqrt(float(32))),
      name="weights")
  biases = tf.Variable(tf.zeros([10]),
                       name="biases")
  logits = tf.matmul(hidden2, weights) + biases
  tf.add_to_collection("logits", logits)

init_all_op = tf.global_variables_initializer()

with tf.Session() as sess:
  # Initializes all the variables.
  sess.run(init_all_op)
  # Runs to logit.
  sess.run(logits)
  # Creates a saver.
  saver0 = tf.train.Saver()
  saver0.save(sess, 'my-save-dir/my-model-10000')
  # Generates MetaGraphDef.
  saver0.export_meta_graph('my-save-dir/my-model-10000.meta')

然后可以载入这个graph,加上其他自定义的损失函数和优化器,用于训练。

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # Addes loss and train.
  labels = tf.constant(0, tf.int32, shape=[100], name="labels")
  batch_size = tf.size(labels)
  labels = tf.expand_dims(labels, 1)
  indices = tf.expand_dims(tf.range(0, batch_size), 1)
  concated = tf.concat([indices, labels], 1)
  onehot_labels = tf.sparse_to_dense(
      concated, tf.stack([batch_size, 10]), 1.0, 0.0)
  logits = tf.get_collection("logits")[0]
  cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
      labels=onehot_labels, logits=logits, name="xentropy")
  loss = tf.reduce_mean(cross_entropy, name="xentropy_mean")


  tf.summary.scalar('loss', loss)
  # Creates the gradient descent optimizer with the given learning rate.
  optimizer = tf.train.GradientDescentOptimizer(0.01)


  # Runs train_op.
  train_op = optimizer.minimize(loss)
  sess.run(train_op)

使用tf.train.Saver.save()会保存运行tensorflow程序所需要的全部信息(graph结构,变量值,检查点列表信息),然而有时候并不需要上述所有信息,例如在测试或者离线预测时,其实我们只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要类似变量初始化,模型保存等辅助节点的信息,另外,将变量取值和图的结构分成不同的文件保存有时候也不方便,尤其是当我们需要将训练好的model从一个平台部署到另外一个平台时,例如从PC端部署到android,解决这个问题分为两种情况:

①如果我们已经有了model的分开保存的文件,可以采用tensorflow安装目录/home/zhaixingzhe/tensorflow/tensorflow/python/tools下的freeze_graph.py脚本提供的方法,将前面保存的cpkt文件和.pb文件(.pbtxt)或者.meta文件统一到一起生成一个单一的文件;

②如果想在保存model时将graph结构,变量值等保存为一个统一的.pb文件,这主要用到tf.graph_util.convert_variables_to_constants()函数用相同值的常量替换图中的所有变量,如果有一个包含变量操作的训练图,可以将它们全部转换为持有相同值的Const操作,这样可以用一个GraphDef文件完全描述网络,并允许删除与加载和保存变量相关的大量操作。

convert_variables_to_constants(
    sess,
    input_graph_def,
    output_node_names,
    variable_names_whitelist=None,
    variable_names_blacklist=None
)

下面是一个示例:

import fully_conected as model
import tensorflow as tf
def export_graph(model_name):
  graph = tf.Graph()
  with graph.as_default():
    input_image = tf.placeholder(tf.float32, shape=[None,28*28], name='inputdata')

    logits = model.inference(input_image)
    y_conv = tf.nn.softmax(logits,name='outputdata')
    restore_saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    latest_ckpt = tf.train.latest_checkpoint('log')
    restore_saver.restore(sess, latest_ckpt)
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), ['outputdata'])


    with tf.gfile.GFile('log/mnist.pb', "wb") as f:  
        f.write(output_graph_def.SerializeToString())  


下面是利用上面保存的model进行测试的一个示例,通过读取.pb文件得到graph_def之后,可以用tf.import_graph_def()函数载入graph_def中定义的graph,并设置成默认graph: 
 

from __future__ import absolute_import, unicode_literals
from datasets_mnist import read_data_sets
import tensorflow as tf

train,validation,test = read_data_sets("datasets/", one_hot=True)

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    output_graph_path = 'log/mnist.pb'
#    sess.graph.add_to_collection("input", mnist.test.images)

    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:

        tf.initialize_all_variables().run()
        input_x = sess.graph.get_tensor_by_name("inputdata:0")        

        output = sess.graph.get_tensor_by_name("outputdata:0")

        y_conv_2 = sess.run(output,{input_x:test.images})
        print( "y_conv_2", y_conv_2)

        # Test trained model
        #y__2 = tf.placeholder("float", [None, 10])
        y__2 = test.labels
        correct_prediction_2 = tf.equal(tf.argmax(y_conv_2, 1), tf.argmax(y__2, 1))
        print ("correct_prediction_2", correct_prediction_2 )
        accuracy_2 = tf.reduce_mean(tf.cast(correct_prediction_2, "float"))
        print ("accuracy_2", accuracy_2)

        print ("check accuracy %g" % accuracy_2.eval())

tf.import_graph_def()函数返回来自导入图的Ops和/或Tensor对象的列表,与return_elements中的名称相对应。


import_graph_def(
    graph_def,
    input_map=None,  #将graph_def中的输入名称(如字符串)映射到Tensor对象的字典。 导入图中的张量的值将被重新映射到相应的Tensor值。
    return_elements=None, #包含将作为操作对象返回的graph_def中的操作名称的字符串列表; 和/或将作为Tensor对象返回的graph_def中的张量名称,这在只需要载入graph中部分变量时非常有用。
    name=None,
    op_dict=None,
    producer_op_list=None
)




你可能感兴趣的:(深度学习)