在模型保存为model.ckpt时,生成了以下文件,其中的checkpoint文件、meta文件都能用来读取变量
在对话sess中使用tf.train.saver.save(sess,save_path)进行保存
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
network_shape=[1,5,10,1]
learning_rate=0.1
display_step=500
num_steps=1000
x_dot=np.linspace(1,2,300,dtype=np.float32)[:,np.newaxis]
y_dot=2*np.power(x_dot,3)+np.power(x_dot,2)+np.random.normal(0,0.5,x_dot.shape)
X_p=tf.placeholder(dtype=tf.float32,shape=[None,network_shape[0]],name="input")
Y_p=tf.placeholder(dtype=tf.float32,shape=[None,network_shape[-1]],name="output")
w={"w1":tf.Variable(tf.random_normal([network_shape[0],network_shape[1]]),name='w1'),
"w2":tf.Variable(tf.random_normal([network_shape[1],network_shape[2]]),name='w2'),
"out":tf.Variable(tf.random_normal([network_shape[2],network_shape[3]]),name='out')}
b={"b1":tf.Variable(tf.random_normal([network_shape[1]]),name='b1'),
"b2": tf.Variable(tf.random_normal([network_shape[2]]),name='b2'),
"out": tf.Variable(tf.random_normal([network_shape[3]]),name='out')}
def network(x):
with tf.name_scope('layer_1'):
layer1=tf.nn.relu(tf.matmul(x,w['w1'])+b['b1'])
with tf.name_scope('layer_2'):
layer2=tf.nn.relu(tf.matmul(layer1,w['w2'])+b['b2'])
with tf.name_scope('out'):
output=tf.matmul(layer2,w['out'])+b['out']
return output
prediction=network(X_p)
loss = tf.reduce_mean(tf.reduce_sum(tf.square(Y_p-prediction), reduction_indices=[1]))
train_step=tf.train.AdamOptimizer(learning_rate).minimize(loss)
saver=tf.train.Saver()
init=tf.global_variables_initializer()
with tf.Session()as sess:
sess.run(init)
Plt=plt.figure().add_subplot(1, 1, 1)
Plt.scatter(x_dot,y_dot)
plt.ion()#使matplotlib的显示模式转换为交互(interactive)模式。即使在脚本中遇到plt.show(),代码还是会继续执行
plt.show()
for i in range(1,num_steps+1):
_,Loss=sess.run([train_step,loss], feed_dict={X_p: x_dot, Y_p: y_dot})
if i%display_step ==0 or i ==1:
print("echo : ",i,"loss = ",Loss)
prediction_value=sess.run(prediction,feed_dict={X_p:x_dot})#shape=(300,1)
if i !=1:
Plt.lines.remove(lines[0])#删去上次画的图
# try:
# Plt.lines.remove(lines[0])
# except Exception:
# pass
lines=Plt.plot(x_dot,prediction_value)#
plt.pause(1)# 为防止matplotlib画图过快,画完图后自动关闭图像窗口
saver.save(sess=sess,save_path='./ckpt_files/model.ckpt')
tf.summary.FileWriter('./log',tf.get_default_graph())
# plt.waitforbuttonpress()#使最后一张图打开状态,不马上结束程序运行
在定义命名空间时,使用with tf.name_scope('namescope'):
保存events文件时,使用tf.summary.FileWriter(log_dir,tf.get_default_graph())
获得浏览器地址时,使用tensorboard --logdir XXX
有多种方法可以restore保存的变量的数据:
方法一:使用变量名获得变量
需要知道模型在训练的时候是如何定义的,在取出时也定义一个同样大小类型的变量,restore之后run变量
方法二:使用meta图文件
可对图进行操作,restore之后利用
方法三:使用checkpoint文件
可reader这个checkpoint文件,在restore之后通过tensor名称获得变量,这个方法可以获得检查点中所有的变量名
##########################模型的恢复(一):利用变量名############
import tensorflow as tf
network_shape=[1,5,10,1]
date=tf.Variable(initial_value=tf.random_normal([network_shape[0],network_shape[1]]),dtype=tf.float32,name='w1')
saver=tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess=sess,save_path='./ckpt_files/model.ckpt')
da=sess.run(date)
print(da)
##########################模型的恢复(二):利用meta文件############
# import tensorflow as tf
# saver=tf.train.import_meta_graph(meta_graph_or_file='./ckpt_files/model.ckpt.meta')
# with tf.Session() as sess:
# saver.restore(sess,save_path='./ckpt_files/model.ckpt')
# graph=tf.get_default_graph()
# da=graph.get_tensor_by_name(name='w1:0')# Tensor names must be of the form ":"
# date=sess.run(da)
# print(date)
##########################模型的恢复(三):利用checkpoint文件############
# from tensorflow.python import pywrap_tensorflow
# checkpoint_path = './ckpt_files/model.ckpt'
# reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# var_to_shape_map = reader.get_variable_to_shape_map()
# print(reader.get_tensor('w1'))
# print(var_to_shape_map)
# for key in var_to_shape_map:
# print("tensor_name: ", key)
# print(reader.get_tensor(key))