Saving and restoring TensorFlow models, and reading variables from a checkpoint

    Create an NN

import tensorflow as tf   # TensorFlow 1.x API (placeholder, Session, tf.layers)
import numpy as np
# fake data
x = np.linspace(-1, 1, 100)[:, np.newaxis]  # shape (100, 1)
noise = np.random.normal(0, 0.1, size=x.shape)
y = np.power(x, 2) + noise      # shape (100, 1), y = x^2 + noise
tf_x = tf.placeholder(tf.float32, x.shape)  # input x
tf_y = tf.placeholder(tf.float32, y.shape)  # output y
l = tf.layers.dense(tf_x, 10, tf.nn.relu)   # hidden layer
o = tf.layers.dense(l, 1)                   # output layer
loss = tf.losses.mean_squared_error(tf_y, o)   # compute loss
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss)
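Written out as equations, the graph above is a one-hidden-layer regression network trained by plain gradient descent. The symbols W_1, b_1, W_2, b_2 are labels introduced here for readability; they are not names used in the code.

```latex
\begin{aligned}
y_i &= x_i^2 + \varepsilon_i, \quad \varepsilon_i \sim \mathcal{N}(0,\, 0.1^2), \quad i = 1, \dots, 100 \\
h_i &= \mathrm{ReLU}(x_i W_1 + b_1), \quad W_1 \in \mathbb{R}^{1 \times 10},\ b_1 \in \mathbb{R}^{1 \times 10} \\
\hat{y}_i &= h_i W_2 + b_2, \quad W_2 \in \mathbb{R}^{10 \times 1},\ b_2 \in \mathbb{R} \\
L &= \frac{1}{100} \sum_{i=1}^{100} (y_i - \hat{y}_i)^2 \\
\theta &\leftarrow \theta - 0.5\, \nabla_\theta L, \quad \theta \in \{W_1, b_1, W_2, b_2\}
\end{aligned}
```

These four parameters are what the saver writes out. With TF1's default layer naming they should appear in the checkpoint as dense/kernel, dense/bias, dense_1/kernel and dense_1/bias, which matters for the name-matching problem in step 3.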

    1. Save the model with save

import os
os.makedirs('params', exist_ok=True)   # Saver.save needs the target directory to exist
sess = tf.Session()
sess.run(tf.global_variables_initializer())     # initialize variables in the graph
saver = tf.train.Saver()    # define a saver for saving and restoring
for step in range(100):         # train
    sess.run(train_op, {tf_x: x, tf_y: y})
saver.save(sess, 'params/params.ckpt', write_meta_graph=False)  # meta_graph is not recommended here

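    This produces three files in params/: checkpoint, params.ckpt.data-00000-of-00001 and params.ckpt.index (there is no .meta file because write_meta_graph=False).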
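The checkpoint file is a small text file recording which prefix is the latest; the values themselves live in the .index/.data-* pair. TF1 offers tf.train.latest_checkpoint(ckpt_dir) for this lookup. As a rough illustration of what it involves, here is a standard-library sketch. find_latest_checkpoint is a hypothetical helper, and it assumes the text format model_checkpoint_path: "params.ckpt" that the Saver normally writes.

```python
import glob
import os
import re


def find_latest_checkpoint(ckpt_dir):
    """Return the latest checkpoint prefix recorded in ckpt_dir/checkpoint, or None.

    Hypothetical helper. Assumes the `checkpoint` file is the text file written
    by TF1's Saver, containing a line such as:
        model_checkpoint_path: "params.ckpt"
    """
    state_file = os.path.join(ckpt_dir, 'checkpoint')
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"', f.read(), re.MULTILINE)
    if match is None:
        return None
    prefix = match.group(1)
    # A relative prefix is interpreted relative to the checkpoint directory
    if not os.path.isabs(prefix):
        prefix = os.path.join(ckpt_dir, prefix)
    # The prefix is only usable if its .index file and at least one .data shard exist
    has_index = os.path.exists(prefix + '.index')
    has_data = bool(glob.glob(glob.escape(prefix) + '.data-*'))
    return prefix if has_index and has_data else None
```

For the example above, find_latest_checkpoint('params') would be expected to return params/params.ckpt, which is the same prefix you pass to saver.restore.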

    2. Load the model back with restore

    When restoring, you have to define the model structure again and then load the saved parameters into it.

# build the entire net again and restore
tf.reset_default_graph()   # fresh graph, so the layers are named dense and dense_1 again even in the same script as step 1
tf_x = tf.placeholder(tf.float32, x.shape)
tf_y = tf.placeholder(tf.float32, y.shape)
l_ = tf.layers.dense(tf_x, 10, tf.nn.relu)
o_ = tf.layers.dense(l_, 1)
loss_ = tf.losses.mean_squared_error(tf_y, o_)

sess = tf.Session()
# don't need to initialize variables, just restoring trained variables
saver = tf.train.Saver()    # define a saver for saving and restoring
saver.restore(sess, './params/params.ckpt')
print(sess.run(loss_, {tf_x: x, tf_y: y}))   # loss computed with the restored parameters
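Restoring works by name: for each variable in the rebuilt graph, the saver looks up an entry with the same name in the checkpoint and copies its value in, which is why the structure (and therefore the names and shapes) must match. The plain-Python sketch below mimics that lookup with dictionaries. restore_by_name and the dict-based "checkpoint" are hypothetical stand-ins for illustration, not TensorFlow APIs.

```python
def restore_by_name(graph_vars, checkpoint):
    """Copy saved values into the rebuilt graph's variables, matching by name.

    Hypothetical stand-in for Saver.restore.
    graph_vars: dict mapping variable name -> shape tuple (the rebuilt graph)
    checkpoint: dict mapping variable name -> (shape tuple, value) (what was saved)
    Returns a dict mapping variable name -> restored value.
    """
    restored = {}
    for name, shape in graph_vars.items():
        if name not in checkpoint:
            # The graph asks for a name the checkpoint never stored
            raise KeyError('Key %s not found in checkpoint' % name)
        saved_shape, value = checkpoint[name]
        if saved_shape != shape:
            # Same name but a different layer size in the rebuilt structure
            raise ValueError('Shape mismatch for %s: graph %s vs checkpoint %s'
                             % (name, shape, saved_shape))
        restored[name] = value
    # Checkpoint entries the graph does not ask for are simply ignored
    return restored


saved = {
    'dense/kernel': ((1, 10), 'W1'),
    'dense/bias': ((10,), 'b1'),
    'dense_1/kernel': ((10, 1), 'W2'),
    'dense_1/bias': ((1,), 'b2'),
}
rebuilt = {'dense/kernel': (1, 10), 'dense/bias': (10,),
           'dense_1/kernel': (10, 1), 'dense_1/bias': (1,)}
values = restore_by_name(rebuilt, saved)   # names and shapes line up, so this succeeds
```

If the layers were instead rebuilt under new names (for example dense_2 and dense_3 after building twice in one graph), the lookup fails on the first unknown name, which is the situation step 3 deals with.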

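    3. Sometimes restore fails with an error such as Not found: b1 not found in checkpoint

    It means the rebuilt graph contains a variable (here b1) for which the checkpoint has no entry with the same name. In that case we want to know what was actually saved in the file, i.e. we need to read the tensors stored in the checkpoint.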

import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('params', 'params.ckpt')
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Write each tensor's name, shape and value to params.txt
with open('params.txt', 'w') as f:
    for key in sorted(var_to_shape_map):
        print(key, var_to_shape_map[key], file=f)
        print(reader.get_tensor(key), file=f)

    Running this produces a params.txt file in which you can see the model's parameters.
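With the checkpoint's variable names in hand, the usual next step is to compare them with the names the rebuilt graph asks for. A small sketch: diff_variable_names is a hypothetical helper, and its inputs are assumed to come from the checkpoint reader (the keys of var_to_shape_map) and from the rebuilt graph (for example v.op.name for v in tf.global_variables()).

```python
def diff_variable_names(graph_names, checkpoint_names):
    """Compare the variable names a graph expects with those stored in a checkpoint."""
    graph_set, ckpt_set = set(graph_names), set(checkpoint_names)
    return {
        # Restore fails on these: the graph wants them but the checkpoint lacks them
        'missing_in_checkpoint': sorted(graph_set - ckpt_set),
        # Ignored by restore, but often a hint that the names have shifted
        'unused_in_checkpoint': sorted(ckpt_set - graph_set),
    }


ckpt_names = ['dense/bias', 'dense/kernel', 'dense_1/bias', 'dense_1/kernel']
# Hypothetical graph where the net was built twice without resetting the graph
graph_names = ckpt_names + ['dense_2/kernel', 'dense_2/bias', 'dense_3/kernel', 'dense_3/bias']
report = diff_variable_names(graph_names, ckpt_names)
print(report)
```

Here every name listed under missing_in_checkpoint is one that Saver.restore cannot find, which points straight at the layers that were created under new names.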


Referenced blog post: "Tensorflow: reading tensors from a checkpoint file"

