tensorflow学习——tf.get_collection(), tf.identity()

0、将训练好的神经网络导出为pb文件,并用测试程序加载验证
以下是程序的关键代码,详细内容见链接

# 从训练好的ckpt中,导出pb文件
import fully_conected as model
import tensorflow as tf
def export_graph(model_name):
    """Freeze the newest checkpoint in ``log`` into a binary GraphDef file.

    Builds a fresh graph with a named input placeholder and a named softmax
    output, restores the weights from the latest checkpoint found in the
    ``log`` directory, converts all variables reachable from the output into
    constants and serializes the result to ``log/<model_name>``.

    Args:
        model_name: File name of the exported graph inside ``log``
            (for example ``'mnist.pb'``).

    Raises:
        ValueError: If ``log`` contains no checkpoint to restore from.
    """
    graph = tf.Graph()
    with graph.as_default():
        # The placeholder is named so the frozen graph's input can later be
        # fetched as "inputdata:0".
        input_image = tf.placeholder(
            tf.float32, shape=[None, 28 * 28], name='inputdata')
        # The network has to be rebuilt here on top of the new placeholder.
        logits = model.inference(input_image)
        # Only the node name is needed below; the op is looked up by name
        # when the graph is frozen.
        tf.nn.softmax(logits, name='outputdata')
        restore_saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        latest_ckpt = tf.train.latest_checkpoint('log')
        if latest_ckpt is None:
            raise ValueError("no checkpoint found in directory 'log'")
        # Restoring assigns every saved variable, so running a separate
        # initializer first is unnecessary.
        restore_saver.restore(sess, latest_ckpt)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ['outputdata'])

        # Honor the requested file name instead of a hard-coded 'mnist.pb'.
        with tf.gfile.GFile('log/' + model_name, "wb") as f:
            f.write(output_graph_def.SerializeToString())
export_graph('mnist.pb')
# 测试调用保存的pb 文件
from __future__ import absolute_import, unicode_literals
from datasets_mnist import read_data_sets
import tensorflow as tf

# Load the MNIST splits; only the test split is used below.
train, validation, test = read_data_sets("datasets/", one_hot=True)

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    output_graph_path = 'log/mnist.pb'

    # Read the serialized frozen graph, then import it into the default graph.
    # name="" keeps the original node names (no "import/" prefix).
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        # No variable initialization is needed: the weights of a frozen graph
        # are constants (the deprecated tf.initialize_all_variables() call
        # was a no-op here).
        input_x = sess.graph.get_tensor_by_name("inputdata:0")
        output = sess.graph.get_tensor_by_name("outputdata:0")

        # Run the imported network on the test images.
        y_conv_2 = sess.run(output, {input_x: test.images})
        print("y_conv_2", y_conv_2)

        # Test trained model: compare predicted class against the label class.
        y__2 = test.labels
        correct_prediction_2 = tf.equal(
            tf.argmax(y_conv_2, 1), tf.argmax(y__2, 1))
        # These two prints show the symbolic tensors, not their values.
        print("correct_prediction_2", correct_prediction_2)
        accuracy_2 = tf.reduce_mean(tf.cast(correct_prediction_2, "float"))
        print("accuracy_2", accuracy_2)

        # eval() uses the session installed as default by the with-block.
        print("check accuracy %g" % accuracy_2.eval())

1、tf.get_collection获取训练变量

#    train_vars=tf.trainable_variables()
#    g_vars=[var for var in train_vars if var.name.startswith('generator')]
#    d_vars=[var for var in train_vars if var.name.startswith('discriminator')]
    # Collect the trainable variables of each sub-network by scope name; this
    # replaces the manual name-prefix filtering kept (commented out) above.
    g_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    d_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

2、 tf.identity()

import tensorflow as tf
# Demo: fetching an op that was created inside tf.control_dependencies.
x = tf.Variable(1.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    # Plain Python aliasing: no new graph op is created by this line.
    y = x
    # tf.identity creates a new op inside the context, so that op carries
    # the control dependency on x_plus_1.
    z=tf.identity(x,name='x')
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(5):
        # Fetching z; the article reports the output 2,3,4,5,6, i.e. the
        # increment runs before every read.
        print(sess.run(z))

输出是:2,3,4,5,6(tf.identity 在 control_dependencies 的作用域内创建了一个新的操作节点,因此每次求值 z 之前都会先执行 x_plus_1)

import tensorflow as tf
# Demo: fetching a plain Python alias defined inside tf.control_dependencies.
x = tf.Variable(1.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    # Plain Python aliasing: y is the same object as x and no new graph op
    # is created, so nothing here picks up the control dependency.
    y = x
    # This op does carry the dependency, but it is never fetched below.
    z=tf.identity(x,name='x')
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(5):
        # Fetching y; the article reports the output 1,1,1,1,1, i.e. the
        # increment never runs.
        print(sess.run(y))

输出是:1,1,1,1,1(y = x 只是 Python 层面的赋值,并没有在图中创建新的操作节点,控制依赖不起作用,x_plus_1 不会被执行)

你可能感兴趣的:(tensorflow学习——tf.get_collection(), tf.identity())