tensorflow(四)实战——基于全连接网络的模型保存,读取,tensorboard可视化

一、简要说明

在模型保存为model.ckpt时,生成了以下文件,其中的checkpoint文件、meta文件都能用来读取变量

tensorflow(四)实战——基于全连接网络的模型保存,读取,tensorboard可视化_第1张图片

二、模型保存

在对话sess中使用tf.train.saver.save(sess,save_path)进行保存

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

network_shape=[1,5,10,1]
learning_rate=0.1
display_step=500
num_steps=1000

x_dot=np.linspace(1,2,300,dtype=np.float32)[:,np.newaxis]
y_dot=2*np.power(x_dot,3)+np.power(x_dot,2)+np.random.normal(0,0.5,x_dot.shape)

X_p=tf.placeholder(dtype=tf.float32,shape=[None,network_shape[0]],name="input")
Y_p=tf.placeholder(dtype=tf.float32,shape=[None,network_shape[-1]],name="output")


w={"w1":tf.Variable(tf.random_normal([network_shape[0],network_shape[1]]),name='w1'),
   "w2":tf.Variable(tf.random_normal([network_shape[1],network_shape[2]]),name='w2'),
   "out":tf.Variable(tf.random_normal([network_shape[2],network_shape[3]]),name='out')}

b={"b1":tf.Variable(tf.random_normal([network_shape[1]]),name='b1'),
   "b2": tf.Variable(tf.random_normal([network_shape[2]]),name='b2'),
   "out": tf.Variable(tf.random_normal([network_shape[3]]),name='out')}


def network(x):
    with tf.name_scope('layer_1'):
        layer1=tf.nn.relu(tf.matmul(x,w['w1'])+b['b1'])
    with tf.name_scope('layer_2'):
        layer2=tf.nn.relu(tf.matmul(layer1,w['w2'])+b['b2'])
    with tf.name_scope('out'):
        output=tf.matmul(layer2,w['out'])+b['out']
    return output
prediction=network(X_p)

loss = tf.reduce_mean(tf.reduce_sum(tf.square(Y_p-prediction), reduction_indices=[1]))

train_step=tf.train.AdamOptimizer(learning_rate).minimize(loss)
saver=tf.train.Saver()
init=tf.global_variables_initializer()
with tf.Session()as sess:
    sess.run(init)
    Plt=plt.figure().add_subplot(1, 1, 1)
    Plt.scatter(x_dot,y_dot)
    plt.ion()#使matplotlib的显示模式转换为交互(interactive)模式。即使在脚本中遇到plt.show(),代码还是会继续执行
    plt.show()
    for i in range(1,num_steps+1):
        _,Loss=sess.run([train_step,loss], feed_dict={X_p: x_dot, Y_p: y_dot})
        if i%display_step ==0 or i ==1:
            print("echo : ",i,"loss = ",Loss)
            prediction_value=sess.run(prediction,feed_dict={X_p:x_dot})#shape=(300,1)
            if i !=1:
                Plt.lines.remove(lines[0])#删去上次画的图
            # try:
            #     Plt.lines.remove(lines[0])
            # except Exception:
            #     pass
            lines=Plt.plot(x_dot,prediction_value)#
            plt.pause(1)# 为防止matplotlib画图过快,画完图后自动关闭图像窗口
    saver.save(sess=sess,save_path='./ckpt_files/model.ckpt')
    tf.summary.FileWriter('./log',tf.get_default_graph())

    # plt.waitforbuttonpress()#使最后一张图打开状态,不马上结束程序运行

三、可视化

在定义命名空间时,使用with tf.name_scope('namescope'): 

保存events文件时,使用tf.summary.FileWriter(log_dir,tf.get_default_graph())

获得浏览器地址时,使用tensorboard --logdir XXX

四、读取保存的ckpt文件

有多种方法可以restore保存的变量的数据:

方法一:使用变量名获得变量

 需要知道模型在训练的时候是如何定义的,在取出时也定义一个同样大小类型的变量,restore之后run变量

方法二:使用meta图文件

可对图进行操作,restore之后利用:取出tensor,run这个tensor就能获得变量

方法三:使用checkpoint文件

可reader这个checkpoint文件,在restore之后通过tensor名称获得变量,这个方法可以获得检查点中所有的变量名

##########################模型的恢复(一):利用变量名############
import tensorflow as tf
network_shape=[1,5,10,1]
date=tf.Variable(initial_value=tf.random_normal([network_shape[0],network_shape[1]]),dtype=tf.float32,name='w1')
saver=tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess=sess,save_path='./ckpt_files/model.ckpt')
    da=sess.run(date)
    print(da)


##########################模型的恢复(二):利用meta文件############
# import tensorflow as tf
# saver=tf.train.import_meta_graph(meta_graph_or_file='./ckpt_files/model.ckpt.meta')
# with tf.Session() as sess:
#     saver.restore(sess,save_path='./ckpt_files/model.ckpt')
#     graph=tf.get_default_graph()
#     da=graph.get_tensor_by_name(name='w1:0')# Tensor names must be of the form ":"
#     date=sess.run(da)
#     print(date)

##########################模型的恢复(三):利用checkpoint文件############
# from tensorflow.python import pywrap_tensorflow
# checkpoint_path = './ckpt_files/model.ckpt'
# reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# var_to_shape_map = reader.get_variable_to_shape_map()
# print(reader.get_tensor('w1'))
# print(var_to_shape_map)
# for key in var_to_shape_map:
#     print("tensor_name: ", key)
#     print(reader.get_tensor(key))

 

参考链接:

 

 

你可能感兴趣的:(tensorflow)