"微信公众号"
思考一个问题:
我们搭建好一个神经网络,用大量的数据训练好之后,肯定希望保存神经网络里面的参数,用于下次加载。那我们该怎么做呢?
TensorFlow为我们提供了Saver来保存和加载神经网络的参数。
一、保存
(1)import所需的模块,然后建立神经网络当中的W和b,并初始化变量。
import tensorflow as tf
import numpy as np
# Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name="weights")
b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases")
init = tf.global_variables_initializer()
(2)保存时,首先要建立一个tf.train.Saver()用来保存,提取变量。再创建一个名为my_net的文件夹,用这个saver来保存变量到这个目录“my_net/save_net.ckp”。
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess,"F:/my_net/save_net.ckpt")
print("Save to path:",save_path)
(3)效果图:
(4)给出保存参数的完整代码。
import tensorflow as tf
import numpy as np
# Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name="weights")
b = tf.Variable([[1,2,3]],dtype=tf.float32,name="biases")
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess,"F:/my_net/save_net.ckpt")
print("Save to path:",save_path)
二、提取
(1)提取时,先建立临时的W和b容器。找到文件目录,并用saver.restore()提取变量。
#conding:utf-8
import tensorflow as tf
import numpy as np
# restore variables
# 先建立W,b的容器
# redefine the same shape and same type for your variables
W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights")
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases")
# not need init step
saver = tf.train.Saver()
with tf.Session() as sess:
# 提取变量
saver.restore(sess,"F:/my_net/save_net.ckpt")
print("weights:",sess.run(W))
print("biases:",sess.run(b))
观看视频笔记:https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-06-save/