1.训练模型的迭代方法
首先对权重w和偏差b进行初始猜测,然后反复调整这些猜测。直到获得损失可能最低的权重和偏差为止。
2.收敛
#在Jupyter中,用matplotlib显示图像需要设置为inline模式,否则不会出现图像
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# 设置随机种子
np.random.seed(5)
# 直接采用np生成等差数列的方法,生成100个点,每个点的取值在-1~1之间
x_data=np.linspace(-1,1,100)
# y=2x+1+噪声,其中,噪声的维度与x_data一致
y_data=2*x_data+1.0+np.random.randn(*x_data.shape)*0.4
# 画出随机生成数据的散点图
plt.scatter(x_data,y_data)
# 画出想要学习到的线性函数
plt.plot(x_data,2*x_data+1.0,color='red',linewidth=3)
# 定义训练数据的占位符,x是特征值,y是标签值
x=tf.placeholder("float",name='x')
y=tf.placeholder('float',name='y')
def model(x,w,b):
return tf.multiply(x,w)+b
# 创建变量,斜率w,截距b,pred预测值
w=tf.Variable(1.0,name='w0')
b=tf.Variable(0.0,name='b0')
pred=model(x,w,b)
# 设置训练参数
# 迭代次数(训练轮数)
train_epochs=10
#学习率
learning_rate=0.05
# 定义损失函数
# 采用均方差作为损失函数
loss_function = tf.reduce_mean(tf.pow((y-pred),2))
# 梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
# 创建会话和初始化
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
# 训练
step=0
loss_list=[] #用于保存loss的列表,之后显示loss
for epoch in range(train_epochs):
for xs,ys in zip(x_data, y_data):
_, loss=sess.run([optimizer,loss_function], feed_dict={x: xs, y: ys})
# display_step:控制报告的粒度,即每训练几次输出一次损失值
# 与超参数不同,修改display_step不会改变模型所学习的规律
loss_list.append(loss)
step=step+1
display_step=10
if step%display_step==0:
print("Train Epoch:",'%02d'% (epoch+1),"Step:%03d"%(step),"loss=",\
"{:.9f}".format(loss))
b0temp=b.eval(session=sess)
w0temp=w.eval(session=sess)
plt.plot (x_data, b0temp + w0temp * x_data )# 画图
print ("w:", sess.run(w)) # w的值应该在2附近
print ("b:", sess.run(b)) # b的值应该在1附近
# 从上图可以看出,由于本案例所拟合的模型较简单,训练3次之后已经接近收敛。
# 对于复杂模型,需要更多次训练才能收敛。
# 模型可视化
plt.scatter(x_data,y_data,label='Original data')
plt.plot (x_data, x_data * sess.run(w) + sess.run(b),label='Fitted line')
plt.legend(loc=2)# 通过参数loc指定图例位置
# 预测
x_test = 12.0
output = sess.run(w) * x_test + sess.run(b)
print("预测值:%f" % output)
target = 2 * x_test + 1.0
print("目标值:%f" % target)
# 图形化显示损失值
plt.subplot(1,2,1)
plt.plot(loss_list)
plt.subplot(1,2,2)
plt.plot(loss_list,'r+')