tensorflow 中导出/恢复模型Graph数据Saver

不得不说,在tensorflow中,这个问题一直困扰我好几天了,没有弄清graph个saver的关系。
下面我就记录一下两者的用法以及应用场景:

Graph

图是tensorflow的核心,所有的操作都是基于图进行的,图中有很多的op,一个op又有一个或则多个的Tensor构成。

Saver

在训练的中可以保存数据比如得到一个Weights值后,需要保存下来,以便下次再使用。

应用场景

graph 和saver可以相互配合使用。可以说graph提供模型,saver提供数据。下面通过训练手写字识别来进行保存graph和saver:

#coding=utf-8
#保存soft.ph和soft.ckpt
#created by tengxing on 2017.2.22
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np

mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)

#create model
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32,[None,784],name='x_input')
    y_ = tf.placeholder(tf.float32,[None,10],name='y_input')
with tf.name_scope('layer'):
    with tf.name_scope('W'):
        #tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
        W = tf.Variable(tf.zeros([784,10]),name='Weights')
    with tf.name_scope('b'):
        b = tf.Variable(tf.zeros([10]),name='biases')
    with tf.name_scope('W_p_b'):
        Wx_plus_b = tf.add(tf.matmul(x, W), b, name='Wx_plus_b')

    y = tf.nn.softmax(Wx_plus_b, name='final_result')
    print y

#define loss and optimizer
with tf.name_scope('loss'):
    loss = -tf.reduce_sum(y_ * tf.log(y))
with tf.name_scope('train_step'):
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    print train_step
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
# important step
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
sess.run(init)
writer = tf.summary.FileWriter("logs/", sess.graph)
#train
for step in range(100):
    batch_xs,batch_ys =mnist.train.next_batch(100)
    train_step.run({x:batch_xs,y_:batch_ys})
    print step
    variables = tf.all_variables()
    saver = tf.train.Saver(variables)
    print len(variables)
    print sess.run(b)
    #print W.get_shape(),b.get_shape()
    saver.save(sess, "data/soft.ckpt")
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
a = accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})
print '最终的测试正确率:{0}'.format(a)
tf.train.write_graph(sess.graph_def,'graph','soft.ph',False)

通过以上就可以保存起来了,我的代码可能有点乱,自行整理吧,下面开始恢复

#coding=utf-8
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)

# 加载Graph
def loadGraph(dir):
    f = tf.gfile.FastGFile(dir,'rb')
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    persisted_graph =tf.import_graph_def(graph_def,name='')
    return persisted_graph

graph = loadGraph('graph/soft.ph')


with tf.Session(graph=graph) as sess:
    #sess.run(tf.initialize_all_variables())
    #sess.run(init) #加载时候不需要进行初始化
    softmax_tensor = sess.graph.get_tensor_by_name('layer/final_result:0')
    x = sess.graph.get_tensor_by_name('input/x_input:0')
    y_ = sess.graph.get_tensor_by_name('input/y_input:0')
    name = sess.graph.get_tensor_by_name('tengxing:0')
    Weights = sess.graph.get_tensor_by_name('layer/W/Weights:0')
    biases = sess.graph.get_tensor_by_name('layer/b/biases:0')

    #W = tf.Variable(tf.zeros([784, 10]), name='Weights')
    #b = tf.Variable(tf.zeros([10]), name='biases')
    tf.add_to_collection(tf.GraphKeys.VARIABLES, name)
    tf.add_to_collection(tf.GraphKeys.VARIABLES, Weights)
    tf.add_to_collection(tf.GraphKeys.VARIABLES, biases)
    try:
        saver = tf.train.Saver(tf.global_variables())  # 'Saver' misnomer! Better: Persister!
    except:
        pass
    print("load data")
    #print sess.run(name) 此时才有一个Tensor获取变量还要进行赋值
    saver.restore(sess, "./data/soft.ckpt")  # now OK creted by tengxing
    #test
    correct_prediction = tf.equal(tf.argmax(softmax_tensor, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

通过以上两个代码可以实现训练模型的保存和继续使用。大家使用时候有问题发我邮箱:[email protected]
后记:这篇文章写的时间不多,但是确实解决了我的很多问题,我相信这这种问题使我们在开发过程必然面临的。所以我才会花时间取解决。总的来说,结果还是令人满意的,毕竟弄出来了,但是我的代码比较乱,稍后我会整理上传。

你可能感兴趣的:(TensorFlow系列)