tensorflow 学习笔记10 网络模型的保存与提取

参数的保存与提取关键点就是前后参数的shape,name,dtype都必须一致:

参数的保存:

import  tensorflow as tf
w = tf.Variable(tf.constant(1.0, shape=[1]), name="w")
b = tf.Variable(tf.constant(2.0, shape=[1]), name="b")
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, "Model/model.ckpt") 

参数的提取:

import tensorflow as tf
w = tf.Variable(tf.constant(0.0, shape=[1]), name="w")
b = tf.Variable(tf.constant(0.0, shape=[1]), name="b")
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "./Model/model.ckpt")
    print("w,b:",sess.run(w),sess.run(b))
结果:

你可能感兴趣的:(tensorflow)