Extracting Specific Parameter Values from a TensorFlow Model

Code
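import tensorflow as tf
from tensorflow.python import pywrap_tensorflow  # internal module, TensorFlow 1.x-style API

# Directory that contains the saved checkpoint files
model_dir = "/home/model"

# Look up the most recent checkpoint recorded in that directory
ckpt = tf.train.get_checkpoint_state(model_dir)
ckpt_path = ckpt.model_checkpoint_path

# Open the checkpoint and list every variable name with its shape
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
param_dict = reader.get_variable_to_shape_map()
print(param_dict)

# get_tensor returns the stored values as a NumPy array; here only its shape is printed
print(reader.get_tensor('batch_normalization/gamma').shape)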
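The snippet above prints only the shape of one tensor. To collect the actual values of several parameters at once, a small helper along the following lines may be useful. This is a minimal sketch, not part of the original post: the function name `extract_params` and its `name_filter` argument are hypothetical, and it assumes only that the reader object exposes `get_variable_to_shape_map()` and `get_tensor()`, as the checkpoint reader used above does.

```python
def extract_params(reader, name_filter=None):
    """Return {variable_name: values} for variables stored in a checkpoint.

    `reader` is assumed to provide get_variable_to_shape_map() and
    get_tensor(name), like the checkpoint reader in the snippet above.
    `name_filter` is a hypothetical convenience argument: when given,
    only variables whose name contains that substring are returned.
    """
    shape_map = reader.get_variable_to_shape_map()
    params = {}
    for name in sorted(shape_map):
        if name_filter is not None and name_filter not in name:
            continue  # skip variables that do not match the filter
        params[name] = reader.get_tensor(name)  # stored values of the variable
    return params
```

With the `reader` from the snippet above, a call might look like `extract_params(reader, "gamma")`, which would return the values of every variable whose name contains `gamma`. In newer TensorFlow releases, `tf.train.load_checkpoint(model_dir)` should give a reader with the same two methods.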
