tensorflow-saver保存读取

1.保存变量到文件中

import tensorflow as tf
import numpy as np
# 当保存变量的时候记得他们设置相同的类型,否则报错
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')

init= tf.initialize_all_variables()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "my_net/save_net.ckpt")#my_net提前定义的文件夹,ckpt格式是官方网站介绍的
    print("Save to path: ", save_path)

2.保存变量到文件中

import tensorflow as tf
import numpy as np
# 变量要定义相同的类型
#2行3列,6个数字
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

# 不需要init变量
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))

你可能感兴趣的:(tensorflow)