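TensorFlow version used: 1.4.1
TensorFlow provides the Saver class for saving and restoring models. It is defined in tensorflow/python/training/saver.py.
The constructor of the Saver class, with its default arguments, is: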
__init__(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)
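Since every constructor argument has a default value, a Saver object is usually created simply with tf.train.Saver().
The most commonly used parameters are:
var_list: the variables to save and restore, given as a list or as a dict mapping checkpoint names to variables. If None (the default), all saveable variables in the graph are used.
max_to_keep: the maximum number of recent checkpoints to keep (default 5). As new checkpoints are written, older ones beyond this limit are deleted. If None or 0, all checkpoint files are kept.
keep_checkpoint_every_n_hours: in addition to the max_to_keep most recent checkpoints, keep one checkpoint for every N hours of training. The default of 10000.0 effectively disables this.
write_version: the checkpoint format. The default, tf.train.SaverDef.V2, stores each checkpoint as an .index file plus .data files.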
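To make the interplay of max_to_keep and keep_checkpoint_every_n_hours concrete, here is a small pure-Python simulation of the retention rule. It is a simplified model written for illustration, not TensorFlow's own code: the function name surviving_checkpoints and its arguments are hypothetical, and TensorFlow's exact boundary handling (for example, whether a checkpoint written exactly at the N-hour mark is kept) may differ.

```python
def surviving_checkpoints(saves, max_to_keep=5, keep_every_n_hours=10000.0,
                          start_time=0.0):
    """Simplified simulation of Saver's checkpoint retention (hypothetical,
    not TensorFlow's implementation).

    saves: list of (step, wall_clock_seconds) pairs in save order.
    start_time: wall-clock time (seconds) at which the Saver was created.
    Returns the steps whose checkpoint files would remain on disk.
    """
    recent = []      # newest checkpoints, at most max_to_keep of them
    preserved = []   # evicted from `recent` but kept by the N-hours rule
    next_keep_time = start_time + keep_every_n_hours * 3600
    for step, t in saves:
        recent.append((step, t))
        # max_to_keep of None or 0 means nothing is ever evicted
        if max_to_keep and len(recent) > max_to_keep:
            old_step, old_t = recent.pop(0)
            if old_t > next_keep_time:
                # written after the next N-hour mark: keep it permanently
                preserved.append(old_step)
                next_keep_time += keep_every_n_hours * 3600
            # otherwise this checkpoint's files would be deleted
    return preserved + [step for step, _ in recent]
```

For the MNIST training loop in this post (a save every 100 steps, 2000 steps in total, default arguments), `surviving_checkpoints([(i, float(i)) for i in range(0, 2000, 100)])` leaves only steps 1500 through 1900, which is why the restore example loads mnist_slp-1900.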
Models are saved with the save method:
save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True
)
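Its parameters:
sess: the session in which the variables to be saved live.
save_path: the path prefix of the checkpoint files, for example './model/mnist_slp'.
global_step: if given, the step number is appended to save_path, giving prefixes such as './model/mnist_slp-100'.
latest_filename: the name of the state file that records the recent checkpoints. It defaults to 'checkpoint' and is stored in the same directory as the checkpoints.
meta_graph_suffix: the file extension of the MetaGraphDef file (default 'meta').
write_meta_graph: whether to also write the graph structure to the .meta file.
write_state: whether to update the checkpoint state file.
save returns the path prefix of the checkpoint it has just written.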
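The naming rules of save can be summarized with a small TensorFlow-free helper that predicts which files one call writes. This is a hypothetical sketch (checkpoint_files is not a TensorFlow function), assuming the default V2 format and a single, non-sharded Saver, which writes exactly one data shard; pad_step_number is the constructor argument that zero-pads the step to 8 digits.

```python
import os


def checkpoint_files(save_path, global_step=None, pad_step_number=False,
                     latest_filename=None, meta_graph_suffix="meta",
                     write_meta_graph=True, write_state=True):
    """Hypothetical helper: predict the files one Saver.save() call writes
    (default V2 format, single non-sharded saver).

    Returns (prefix, files); `prefix` is what save() returns and what
    restore() expects as save_path.
    """
    if global_step is None:
        prefix = save_path
    elif pad_step_number:
        prefix = "%s-%08d" % (save_path, global_step)  # e.g. mnist_slp-00000100
    else:
        prefix = "%s-%d" % (save_path, global_step)    # e.g. mnist_slp-100
    # variable values: an index file plus one data shard
    files = [prefix + ".index", prefix + ".data-00000-of-00001"]
    if write_meta_graph:
        files.append(prefix + "." + meta_graph_suffix)  # MetaGraphDef
    if write_state:
        # state file listing the recent checkpoints, next to the checkpoints
        files.append(os.path.join(os.path.dirname(save_path),
                                  latest_filename or "checkpoint"))
    return prefix, files
```

For example, `checkpoint_files('./model/mnist_slp', global_step=1900)` returns the prefix './model/mnist_slp-1900' together with its .index, .data-00000-of-00001 and .meta files and the shared './model/checkpoint' state file.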
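Models are loaded with the restore method. It takes the session to restore the variables into and the checkpoint prefix (the value returned by save, for example './model/mnist_slp-1900', without any file extension):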
restore(
sess,
save_path
)
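Because restore needs the exact checkpoint prefix, it is convenient to look up the newest one instead of hard-coding the step. TensorFlow 1.x provides tf.train.latest_checkpoint(checkpoint_dir) for this, which reads the checkpoint state file. As a TensorFlow-free sketch, assuming the default "<save_path>-<step>" naming produced by save with global_step, the hypothetical function find_latest_checkpoint below simply scans the directory for .index files.

```python
import os
import re


def find_latest_checkpoint(ckpt_dir, basename):
    """Hypothetical helper: return the checkpoint prefix with the highest
    step, e.g. './model/mnist_slp-1900', or None if nothing matches.

    It scans for '<basename>-<step>.index' files; the real
    tf.train.latest_checkpoint reads the 'checkpoint' state file instead.
    """
    pattern = re.compile(r"^%s-(\d+)\.index$" % re.escape(basename))
    best = None  # (numeric step, step digits as written on disk)
    for name in os.listdir(ckpt_dir):
        match = pattern.match(name)
        if match:
            digits = match.group(1)
            # compare numerically so that step 1900 beats step 900
            if best is None or int(digits) > best[0]:
                best = (int(digits), digits)
    if best is None:
        return None
    # keep the original digits so zero-padded step numbers still resolve
    return os.path.join(ckpt_dir, "%s-%s" % (basename, best[1]))
```

The returned prefix can be passed directly as the save_path argument of restore.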
Below is an example of saving a model, again based on the "single-layer perceptron for MNIST digit classification" example:
# -*- coding: utf-8 -*-
import tensorflow as tf
from input_data import read_data_sets
import os
# don't show INFO
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
# read mnist
mnist = read_data_sets('MNIST_data', one_hot=True)
# single layer perceptron: y = wx + b
# input
x = tf.placeholder(tf.float32, [None, 784])
# weights
W = tf.Variable(tf.random_normal([784,10], stddev=0.1))
# bias
b = tf.Variable(tf.zeros([10]))
# softmax
y = tf.nn.softmax(tf.matmul(x,W) + b)
# output
y_ = tf.placeholder(tf.float32, [None, 10])
# cross_entropy loss
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
# optimization with gradient descent, the learning rate is set to 0.02
train_step = tf.train.GradientDescentOptimizer(0.02).minimize(cross_entropy)
# initialize all variables
init = tf.global_variables_initializer()
# start a new session
sess = tf.Session()
sess.run(init)
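# create the Saver (by default it keeps the 5 most recent checkpoints)
m_saver = tf.train.Saver()
# make sure the output directory exists before saving
if not os.path.exists('./model'):
    os.makedirs('./model')
# 2000 iterations
for i in range(2000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    # save every 100 steps; global_step=i appends "-i" to the file prefix
    if i % 100 == 0:
        m_saver.save(sess, './model/mnist_slp', global_step=i)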
# compute the accuracy
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
# close the session
sess.close()
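After training, the saved checkpoints can be found in the model folder.
Notes:
Although a checkpoint is written at steps 0, 100, ..., 1900, max_to_keep defaults to 5, so only the five most recent ones (mnist_slp-1500 to mnist_slp-1900) are left on disk.
With the default V2 format, each checkpoint consists of three files, for example mnist_slp-1900.index and mnist_slp-1900.data-00000-of-00001, which hold the variable values, and mnist_slp-1900.meta, which holds the graph structure (MetaGraphDef).
The text file named checkpoint records the latest checkpoint and the list of retained ones.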
To load the model, rebuild the same graph and call restore:
# -*- coding: utf-8 -*-
import tensorflow as tf
from input_data import read_data_sets
import os
# don't show INFO
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
# read mnist
mnist = read_data_sets('MNIST_data', one_hot=True)
# single layer perceptron: y = wx + b
# input
x = tf.placeholder(tf.float32, [None, 784])
# weights
W = tf.Variable(tf.random_normal([784,10], stddev=0.1))
# bias
b = tf.Variable(tf.zeros([10]))
# softmax
y = tf.nn.softmax(tf.matmul(x,W) + b)
# output
y_ = tf.placeholder(tf.float32, [None, 10])
# cross_entropy loss
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
# optimization with gradient descent, the learning rate is set to 0.02
train_step = tf.train.GradientDescentOptimizer(0.02).minimize(cross_entropy)
# start a new session
sess = tf.Session()
m_saver = tf.train.Saver()
# load the model (restore assigns the saved values, so no variable initializer is needed)
m_saver.restore(sess, './model/mnist_slp-1900')
# compute the accuracy
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
# close the session
sess.close()