TensorFlow训练参数存为npy格式并调用——线性回归

模型训练并保存

# -*- coding: utf-8 -*-
"""
Created on Sun Mar 15 10:27:32 2020

@author: weixifei
"""
import tensorflow as tf
import numpy as np

# In[]
x = tf.random_normal([100,1],mean=1.75,stddev=0.5,name="my_data")
y_true = tf.matmul(x,[[0.7]])+0.8  
weight = tf.Variable(tf.random_normal([1,1],mean=0.1,stddev=1.0),name="weight")
bias = tf.Variable(0.1,name="bias")
y_predict = tf.matmul(x,weight)+bias
loss = tf.reduce_mean(tf.square(y_true-y_predict))
train_op=tf.train.GradientDescentOptimizer(0.1).minimize(loss)
init_op=tf.global_variables_initializer()

# In[]   


# In[]
with tf.Session() as sess:
    #初始化变量
    sess.run(init_op)
    #打印随机初始化的权重和偏置
    #print("随机初始化的权重参数为:%f,偏置为:%f" %(weight.eval(),bias.eval()))
    #循环训练,运行优化
    for i in range(50):
        sess.run(train_op)
    print("第%d次训练权重参数为:%f,偏置为:%f" %(i,weight.eval(),bias.eval()))
#模型保存
    data = {
     }
    #权值和偏置保存为npy格式,最后一次训练的值
    #每一次重新训练保存需要重启控制台,不然上次的模型参数也会被保存
    for var in tf.trainable_variables():
        print(var.name) #打印变量名字
        data[var.name] = sess.run(var)
    np.save('./out.npy', data) 

加载npy参数文件测试

#加载npy文件测试
data=np.load('./out.npy',allow_pickle=True).item() #读入npy文件
#将字典中的某个值以张量的形式赋给网络中的某个权重和偏置(得知道键)
#trainable决定你是否要固定权重,False代表固定权重
w = tf.Variable(data['weight:0'], dtype=tf.float32, trainable=False)
sess=tf.Session()  
sess.run(tf.global_variables_initializer())
print(sess.run(w))

你可能感兴趣的:(深度学习,深度学习,tensorflow,python)