将 TensorFlow checkpoint 中保存的变量值导出到 txt 或 npy 文件的几种简单可行的方法:
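1. 最简单的方法:在有模型定义代码(model)的情况下,直接用 tf.train.Saver 进行 restore,就像 cifar10_eval.py 中那样。然后,在 sess 中直接 run 对应的变量(或它在图中的张量名,例如 'conv1/weights:0'),就可以得到 checkpoint 中保存的值。注意:variables_to_restore() 返回的是“checkpoint 中的名字 → 图中变量”的映射,下面打印出来的带 ExponentialMovingAverage 后缀的名字是 checkpoint 中的名字。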
with tf.Graph().as_default() as g:
    # Rebuild the model graph the same way cifar10_eval.py does, so that the
    # variables stored in the checkpoint have matching variables to load into.
    images, labels = cifar10.inputs(eval_data=eval_data)
    logits = cifar10.inference(images)
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    # variables_to_restore() yields a mapping whose keys are the names used
    # inside the checkpoint; handing it to the Saver makes restore() load the
    # checkpointed values into the corresponding graph variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # List every checkpoint key that will be restored.
    for name in variables_to_restore:
        print(name)
softmax_linear/biases/ExponentialMovingAverage
conv2/biases/ExponentialMovingAverage
local4/biases/ExponentialMovingAverage
local3/biases/ExponentialMovingAverage
softmax_linear/weights/ExponentialMovingAverage
conv1/biases/ExponentialMovingAverage
local4/weights/ExponentialMovingAverage
local3/weights/ExponentialMovingAverage
conv2/weights/ExponentialMovingAverage
conv1/weights/ExponentialMovingAverage
# Bind the session to graph `g` explicitly so the restore runs against the
# graph that `saver` was built in, whatever the current default graph is.
with tf.Session(graph=g) as sess:
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        # Restores from checkpoint
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Fetch through the Variable object rather than a bare string: the
        # key below is the name used in the checkpoint, and a string fetch
        # without an output index (":0") is not resolved as a tensor.
        conv1_w = sess.run(
            variables_to_restore['conv1/weights/ExponentialMovingAverage'])
    else:
        # Nothing to restore; say so instead of silently doing nothing.
        print('No checkpoint file found')
"""A simple script for inspect checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import tensorflow as tf
# Handle to the command-line flag values; main() reads file_name and
# tensor_name from it.
FLAGS = tf.app.flags.FLAGS
# Both flags default to the empty string; main() treats an empty file_name as
# a usage error, and an empty tensor_name selects the summary listing.
tf.app.flags.DEFINE_string("file_name", "", "Checkpoint filename")
tf.app.flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")
def print_tensors_in_checkpoint_file(file_name, tensor_name):
    """Print the contents of a checkpoint file to stdout.

    With an empty `tensor_name`, the reader's debug string (tensor names and
    shapes) is printed. Otherwise the value of the named tensor is printed.

    Any exception raised while opening or reading the checkpoint is caught
    and its message printed instead of propagating, so this function never
    raises for a bad file or tensor name.

    Args:
      file_name: Name of the checkpoint file.
      tensor_name: Name of the tensor in the checkpoint file to print, or an
        empty string to list everything.
    """
    try:
        reader = tf.train.NewCheckpointReader(file_name)
        if tensor_name:
            print("tensor_name: ", tensor_name)
            print(reader.get_tensor(tensor_name))
        else:
            print(reader.debug_string().decode("utf-8"))
    except Exception as e:  # pylint: disable=broad-except
        # Deliberately best-effort: report the problem and return normally.
        message = str(e)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
def main(unused_argv):
    """Script entry point: validate the flags and inspect the checkpoint.

    Prints a usage message and exits with status 1 when --file_name is empty;
    otherwise prints the checkpoint contents selected by the flags.

    Args:
      unused_argv: Positional command-line arguments; ignored.
    """
    if not FLAGS.file_name:
        print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
              "[--tensor_name=tensor_to_print]")
        sys.exit(1)
    # sys.exit() raises SystemExit, so this line is only reached when a file
    # name was supplied.
    print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name)
if __name__ == "__main__":
    # Hand control to TensorFlow's app runner, which parses the command-line
    # flags and then calls main().
    tf.app.run()