Tensorflow———模型持久化(保存与使用)

TensorFlowt提供了一个非常简单的接口来保存和还原一个神经网络的模型,即tf.train.Saverl类。

 

具体实例:

import tensorflow as tf
from numpy.random import RandomState
import os

# 设置标志位  来操作是训练还是进行模型的使用
isTrain = False


# 除错误外 其他信息不显示
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# 一次喂给神经网络多少数据
batch_size = 8

# 设定两个输入节点
x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input' )

# 设定一个输出点
y_ = tf.placeholder(tf.float32,shape=(None, 1), name='y_input')


w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)

loss_less = 10
loss_more = 1

# 定义损失函数 后向传播方法
# tf.reduce_sum()为求和函数
# 自定义的损失函数 用tf.where 代替 tf.select
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y-y_)*loss_more, (y_-y)*loss_less))


# tf.train.AdagradOptimizer(0.001) 构造一个使用Adadelta算法的优化器,学习率为0.001 即梯度下降中的步长为0.001
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)


# 通过随机数生成一个模拟的数据集
# rdm.rand(data_size,2) 产生128行 2列的伪随机数(0 1之间)
rdm = RandomState(1)
data_size = 128
X = rdm.rand(data_size, 2)

#rdm.rand()/10.0 是(0,1)随机数
Y = [[x1 + x2 + rdm.rand()/10.0-0.05] for (x1, x2) in X]

saver = tf.train.Saver()

# 创建会话 开始训练神经网络
with tf.Session() as sees:
    if isTrain:
        # tf.global_variables_initializer()返回一个计算图中所有变量的对象
        init_op = tf.global_variables_initializer()
        sees.run(init_op)

        # 设置训练次数为5000
        STEPS = 5000
        for i in range(STEPS):
            start = (i*batch_size) % data_size
            end = min(start+batch_size, data_size)
            sees.run(train_step,  feed_dict={x: X[start:end], y_: Y[start:end]})
            print("第"+str(i+1)+"步:")
            print(sees.run(w1))
        # 模型的保存
        model_path="./One_node/model.ckpt"
        save_path =saver.save(sees,model_path)
    else:
        # 模型的使用
        # 复原模型到会话中
        saver.restore(sees, "./One_node/model.ckpt")
        # 得到权重
        print(sees.run(w1))

        # 喂入数据,得出结果
        data = rdm.rand(1, 2)
        print("data:",end="")
        print(data)
        result = sees.run(y, feed_dict={x: data})
        print(result)

 

你可能感兴趣的:(Tensorflow)