TF 低价API之手动训练回归模型

TensorFlow 低阶API,手动训练一个小型回归模型。

定义数据

我们首先来定义一些输入值 x,以及每个输入值的预期输出值 y_true:

x = tf.constant([[1], [2], [3], [4]], dtype=tf.float32)
y_true = tf.constant([[0], [-1], [-2], [-3]], dtype=tf.float32)

定义模型

接下来,建立一个简单的线性模型,其输出值只有 1 个:

linear_model = tf.layers.Dense(units=1)

y_pred = linear_model(x)

您可以如下评估预测值:

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

print(sess.run(y_pred))

该模型尚未接受训练,因此这里的‘预测’值并不理想。

损失

要优化模型,您首先需要定义损失,我们将使用均方误差,这是回归问题的标准损失。

loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)

print(sess.run(loss))

这里会生成一个损失值:

2.23962

训练

TensorFlow 提供了执行标准优化算法的优化器。这些优化器被实现为 tf.train.Optimizer 的子类。它们会逐渐改变每个变量,以便将损失最小化。最简单的优化算法是梯度下降法,由 tf.train.GradientDescentOptimizer 实现。它会根据损失相对于变量的导数大小来修改各个变量。例如:

optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

该代码构建了优化所需的所有图组件,并返回一个训练指令。该训练指令在运行时会更新图中的变量。您可以按以下方式运行该指令:

for i in range(100):
    _, loss_value = sess.run((train, loss))
    print(loss_value)

由于 train 是一个指令而不是张量,因此它在运行时不会返回一个值。为了查看训练期间损失的进展,我们会同时运行损失张量,生成如下所示的输出值:

1.35659
1.00412
0.759167
0.588829
0.470264
0.387626
0.329918
0.289511
0.261112
0.241046
...

完整程序如下:

from __future__ import absolute_import, division, print_function
import tensorflow as tf

# TF 手动训练一个小型回归模型

# 1.  定义数据
x = tf.constant([[1], [2], [3], [4]],dtype = tf.float32)
y_true = tf.constant([[0], [-1], [-2], [-3]],dtype = tf.float32)

# 2.  定义模型
linear_model = tf.layers.Dense(units = 1)   # 定义一个简单的线性模型,只有1个输出值
y_pred = linear_model(x)

# 3.  损失
loss = tf.losses.mean_squared_error(labels = y_true, predictions = y_pred)

# 4.  训练
optimizer = tf.train.GradientDescentOptimizer(0.01)  # 学习率为0.01的梯度下降优化器
train = optimizer.minimize(loss)

init = tf.global_variables_initializer()  # 层包含的变量必须先初始化,然后才能使用
sess = tf.Session()   # 创建会话:要评估张量,需要实例化一个 tf.Session 对象
sess.run(init)
for i in range(100):
    _, loss_value = sess.run((train, loss))  # 执行层
    print('Loss at step {} : {:.3f}'.format(i,loss_value))

print(sess.run(y_pred))

你可能感兴趣的:(TF 低价API之手动训练回归模型)