TensorFlow中Saver保存读取

"微信公众号"

TensorFlow中Saver保存读取_第1张图片

思考一个问题:

我们搭建好一个神经网络,用大量的数据训练好之后,肯定希望保存神经网络里面的参数,用于下次加载。那我们该怎么做呢?

TensorFlow为我们提供了Saver来保存和加载神经网络的参数。

一、保存

(1)import所需的模块,然后建立神经网络当中的W和b,并初始化变量。

import tensorflow as tf
import numpy as np

# Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name="weights")
b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases")

init = tf.global_variables_initializer()

(2)保存时,首先要建立一个tf.train.Saver()用来保存,提取变量。再创建一个名为my_net的文件夹,用这个saver来保存变量到这个目录“my_net/save_net.ckp”。

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess,"F:/my_net/save_net.ckpt")
    print("Save to path:",save_path)

(3)效果图:

TensorFlow中Saver保存读取_第2张图片

(4)给出保存参数的完整代码。

import tensorflow as tf
import numpy as np

# Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name="weights")
b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases")

init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess,"F:/my_net/save_net.ckpt")
    print("Save to path:",save_path)

二、提取

(1)提取时,先建立临时的W和b容器。找到文件目录,并用saver.restore()提取变量。

#conding:utf-8
import tensorflow as tf
import numpy as np

# restore variables
# 先建立W,b的容器
# redefine the same shape and same type for your variables
W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights")
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases")

# not need init step

saver = tf.train.Saver()
with tf.Session() as sess:
    # 提取变量
    saver.restore(sess,"F:/my_net/save_net.ckpt")
    print("weights:",sess.run(W))
    print("biases:",sess.run(b))
观看视频笔记:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-06-save/

你可能感兴趣的:(深度学习,TensorFlow学习笔记,Tensorflow,神经网络,深度学习,机器学习,Saver)