Tensorflow学习笔记--模型保存与调取

注:本文主要通过莫烦的python学习视频记录的内容,如果喜欢请支持莫烦python。谢谢


目前tf的模型保存其实只是参数保存,所以保存文件时你特别要主要以下几点:

1、一定要设定好参数的数据类型!

2、设定参数的名称,并且一一对应!

3、读取参数时,需要设定好模型图!


下面做一个简单的demo,供各位参考:

保存模型:

import tensorflow as tf
import numpy as np

## Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[2,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[2,2,3]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "my_net/save_net2.ckpt")
    print("Save to path: ", save_path)

提取模型:

import tensorflow as tf
import numpy as np
# 先建立 W, b 的容器
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
#init = tf.global_variables_initializer()

saver = tf.train.Saver()
with tf.Session() as sess:
    #sess.run(init)
    # 提取变量
    saver.restore(sess, "my_net/save_net2.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))

PS:

如果你读取2次模型变量,你会发现以下错误:

NotFoundError (see above for traceback): Key weights_2 not found in checkpoint
	 [[Node: save_3/RestoreV2_11 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_3/Const_0, save_3/RestoreV2_11/tensor_names, save_3/RestoreV2_11/shape_and_slices)]]

原因:你可以看到再次读取模型时,权重名称已经变成了weights_2。所以就会报错误。


你可能感兴趣的:(深度学习)