Linear Regression: A Concise Implementation in TensorFlow 2.1.0

Concise implementation of linear regression

xiaoyao | Dive into Deep Learning | TensorFlow 2.1.0

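As deep learning frameworks have matured, building deep learning applications has become increasingly convenient. In practice, we can usually implement the same model with much more concise code than in the previous section. In this section, we show how to use the Keras API recommended by TensorFlow 2.1.0 to train a linear regression model more conveniently.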

Generating the dataset

We generate the same dataset as in the previous section: features holds the training features and labels holds the corresponding labels.
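Written out, the generating process is the following; ε stands for the small Gaussian noise (standard deviation 0.01) that the code adds to each label.

```latex
\mathbf{y} = \mathbf{X}\mathbf{w} + b + \boldsymbol{\epsilon},\qquad
\mathbf{X} \in \mathbb{R}^{1000 \times 2},\; X_{ij} \sim \mathcal{N}(0, 1),\;
\mathbf{w} = [2, -3.4]^\top,\; b = 4.2,\;
\epsilon_i \sim \mathcal{N}(0, 0.01^2)
```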

import tensorflow as tf

num_inputs = 2
num_examples = 1000

true_w = [2, -3.4]
true_b = 4.2

features = tf.random.normal(shape=(num_examples, num_inputs), stddev=1)
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += tf.random.normal(labels.shape, stddev=0.01)

Reading the data

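Although in TensorFlow 2.1.0 a linear regression model can be fitted directly with Keras, without splitting the dataset into mini-batches by hand, we still walk through how to read the data in mini-batches.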

# For a thorough shuffle, shuffle's buffer_size should be at least the number of examples; batch() sets the mini-batch size
from tensorflow import data as tfdata

batch_size = 10
# Pair each example's features with its label
dataset = tfdata.Dataset.from_tensor_slices((features, labels))
# Read random mini-batches
dataset = dataset.shuffle(buffer_size=num_examples)
dataset = dataset.batch(batch_size)
data_iter = iter(dataset)
# An iterator created with iter(dataset) can traverse the dataset only once
for X, y in data_iter:
    print(X, y)
    break
tf.Tensor(
[[-1.825709   -0.41736308]
 [-0.6260657  -0.37043497]
 [-1.2240889  -0.4179984 ]
 [-2.252687    0.32804537]
 [ 0.00852243 -0.11625145]
 [ 0.42531878 -0.9496812 ]
 [ 0.4022167  -0.07259909]
 [-1.0691589  -0.18955724]
 [-0.20947874  1.566279  ]
 [ 1.7726566   1.5784163 ]], shape=(10, 2), dtype=float32) tf.Tensor(
[ 1.9595385  4.213116   3.1751869 -1.417277   4.620007   8.291456
  5.2536917  2.709371  -1.5466335  2.3810005], shape=(10,), dtype=float32)
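The remark about iter(dataset) is easy to trip over. The minimal sketch below uses a toy dataset of 25 random examples (the sizes are made up for illustration) to show that looping over the Dataset object itself starts a fresh pass every time, while an explicit iterator is exhausted after one pass. This is why a multi-epoch training loop should iterate over dataset rather than over data_iter.

```python
import tensorflow as tf

# Toy data: 25 examples with 2 features (illustrative sizes only)
features = tf.random.normal(shape=(25, 2))
labels = tf.random.normal(shape=(25,))

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.shuffle(buffer_size=25).batch(10)  # batches of 10, 10, 5

# Looping over the Dataset object creates a new iterator each time,
# so every loop is a complete pass (one epoch) over the data.
pass1 = sum(1 for _ in dataset)
pass2 = sum(1 for _ in dataset)

# An explicit iterator is consumed as it goes; a second loop yields nothing.
data_iter = iter(dataset)
iter_pass1 = sum(1 for _ in data_iter)
iter_pass2 = sum(1 for _ in data_iter)

# The last batch is smaller because 25 is not a multiple of 10
last_batch_size = [X.shape[0] for X, _ in dataset][-1]

print(pass1, pass2, iter_pass1, iter_pass2, last_batch_size)  # 3 3 3 0 5
```

If every batch must have exactly batch_size examples, dataset.batch(batch_size, drop_remainder=True) discards the short final batch instead.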

Defining the model

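TensorFlow 2.x recommends defining networks with Keras, so we use Keras here. We first define a model variable model, which is a Sequential instance. In Keras, a Sequential instance can be viewed as a container that chains layers together.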

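When constructing the model, we add layers to this container in order. Given input data, each layer in the container computes its output in turn and passes it to the next layer as input. Importantly, in Keras we do not need to specify the input shape of each layer; it is inferred from the data the first time the model is called.
Since linear regression connects the inputs directly and fully to a single output, we define just one layer: a fully connected layer, keras.layers.Dense().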
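To see that a single Dense layer with one output unit computes exactly the linear-regression model ŷ = Xw + b, and that its weight shape is only fixed once data first flows through it, consider the sketch below. The batch of 4 random inputs is made up for illustration.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

dense = layers.Dense(1)
model = keras.Sequential([dense])

# No input shape was given, so the layer has no weights yet
built_before = dense.built
n_weights_before = len(dense.weights)
print(built_before, n_weights_before)  # False 0

# Calling the model on a (4, 2) batch builds the layer: kernel (2, 1), bias (1,)
X = tf.random.normal(shape=(4, 2))
y_hat = model(X)
print(tuple(dense.kernel.shape), tuple(dense.bias.shape))  # (2, 1) (1,)

# The layer's output is the affine map X @ w + b
manual = X.numpy() @ dense.kernel.numpy() + dense.bias.numpy()
print(np.allclose(y_hat.numpy(), manual, atol=1e-5))  # True
```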

In Keras, the kernel_initializer and bias_initializer arguments set how the weights and the bias are initialized, respectively.

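Here we import the initializers module from TensorFlow and pass RandomNormal(stddev=0.01), so that each element of the weight parameter is sampled at initialization from a normal distribution with mean 0 and standard deviation 0.01. The bias is initialized to zero by default.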
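One way to check what RandomNormal(stddev=0.01) and the default zero bias actually produce is to build a layer with many inputs and look at its weight statistics. This sketch uses tf.keras.initializers.RandomNormal, assumed here to be the same class the text reaches through tensorflow.initializers, and a made-up input width of 10,000 so that the sample statistics are stable.

```python
import numpy as np
import tensorflow as tf

num_inputs = 10000  # large, illustrative input width so the statistics are stable
layer = tf.keras.layers.Dense(
    1, kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01))
layer.build((None, num_inputs))  # create the weights without feeding any data

w = layer.kernel.numpy()
b = layer.bias.numpy()
# Sample mean should be close to 0 and sample std close to 0.01
print(w.shape, float(w.mean()), float(w.std()))
print(b)  # [0.]
```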

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow import initializers as init
model = keras.Sequential()  # a container that chains layers in sequence
model.add(layers.Dense(1, kernel_initializer=init.RandomNormal(stddev=0.01)))

Defining the loss function and optimization algorithm

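We now define the loss function and the optimizer: the loss is mean squared error (MSE), and the optimizer is stochastic gradient descent (SGD).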

In Keras, once the model is defined, calling its compile() method configures the model's loss function and optimization algorithm.

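To define the loss, we only need to pass the loss argument. Keras provides many loss functions, and we directly use its squared loss, MSE, as the model's loss.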
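Concretely, Keras's MeanSquaredError averages the squared differences, with no 1/2 factor in front. The sketch below compares it with the formula computed directly on three made-up targets and predictions; note that Keras loss objects take their arguments in the order (y_true, y_pred).

```python
import tensorflow as tf

y_true = tf.constant([1.0, 2.0, 3.0])
y_pred = tf.constant([1.5, 2.0, 2.0])

mse = tf.keras.losses.MeanSquaredError()
keras_value = float(mse(y_true, y_pred))

# mean((y_true - y_pred)^2) = (0.25 + 0.0 + 1.0) / 3
manual_value = float(tf.reduce_mean(tf.square(y_true - y_pred)))

print(keras_value, manual_value)
```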

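Nor do we need to implement mini-batch SGD ourselves: we only pass the optimizer argument. Keras provides many optimization algorithms, and here we directly use mini-batch SGD with a learning rate of 0.03, tf.keras.optimizers.SGD(0.03). Because the training loop below is written by hand, the code creates the loss and optimizer objects directly instead of passing them to compile().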
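Each time the SGD optimizer receives gradients, it applies the plain update w ← w - η·∂l/∂w with η = 0.03. A minimal check, using a made-up two-element variable and the loss l = w₁² + w₂² (whose gradient is 2w):

```python
import numpy as np
import tensorflow as tf

w = tf.Variable([1.0, 2.0])
trainer = tf.keras.optimizers.SGD(learning_rate=0.03)

with tf.GradientTape() as tape:
    l = tf.reduce_sum(w ** 2)  # gradient is 2 * w = [2.0, 4.0]

grads = tape.gradient(l, [w])
trainer.apply_gradients(zip(grads, [w]))

# Expected: [1 - 0.03 * 2, 2 - 0.03 * 4] = [0.94, 1.88]
print(w.numpy())
```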

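from tensorflow import losses
loss = losses.MeanSquaredError()  # squared loss (MSE)
from tensorflow.keras import optimizers
trainer = optimizers.SGD(learning_rate=0.03)  # mini-batch SGD with learning rate 0.03
loss_history = []  # records the training loss of every mini-batch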

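When training with Keras, we can simply call the fit method of the model instance to iterate over the data. fit only needs the inputs x and the targets y, the number of passes over the data (epochs), and the mini-batch size used for each gradient update (batch_size), for example epochs=3 and batch_size=10. With fit, we do not even need to split the dataset into mini-batches ourselves.
In this section, however, we write the training loop by hand.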
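For comparison, the fit-based workflow described above could look like the following sketch. It regenerates the same synthetic data, configures the loss and optimizer through compile(), and lets fit handle shuffling and batching. The random seed and verbose=0 are additions for reproducibility and quiet output, and the labels are reshaped into a column so that they match the model's (N, 1) output.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

tf.random.set_seed(0)  # added for reproducibility

# Same synthetic data as in "Generating the dataset"
num_inputs, num_examples = 2, 1000
true_w, true_b = [2, -3.4], 4.2
features = tf.random.normal(shape=(num_examples, num_inputs), stddev=1)
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += tf.random.normal(labels.shape, stddev=0.01)

model = keras.Sequential([
    layers.Dense(1, kernel_initializer=keras.initializers.RandomNormal(stddev=0.01))
])
# compile() attaches the loss function and the optimizer to the model
model.compile(loss=keras.losses.MeanSquaredError(),
              optimizer=keras.optimizers.SGD(learning_rate=0.03))

# fit() shuffles the data and splits it into mini-batches of batch_size itself
history = model.fit(features, tf.reshape(labels, (-1, 1)),
                    epochs=3, batch_size=10, verbose=0)

w, b = model.get_weights()
print(history.history['loss'])  # one average training loss per epoch
print(w.ravel(), b)
```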

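When training the model directly with TensorFlow, we use tf.GradientTape to record the dynamic (eager) computation graph, then call tape.gradient to obtain the gradient of the loss with respect to each variable in it.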

We get the variables to be updated from model.trainable_variables and call trainer.apply_gradients to update the weights, which completes one training step.
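To make the tape.gradient step less of a black box, the sketch below computes the MSE gradients of a single Dense layer with GradientTape and compares them with the closed-form gradients of L = (1/n)·Σ(ŷᵢ - yᵢ)², namely ∂L/∂w = (2/n)·Xᵀ(ŷ - y) and ∂L/∂b = (2/n)·Σ(ŷᵢ - yᵢ). The batch of 10 random examples is made up for illustration. The sketch also shows that model.trainable_variables lists the kernel before the bias.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([layers.Dense(1)])
loss = keras.losses.MeanSquaredError()

X = tf.random.normal(shape=(10, 2))
y = tf.random.normal(shape=(10, 1))

with tf.GradientTape() as tape:
    y_hat = model(X, training=True)
    l = loss(y, y_hat)

# trainable_variables is [kernel (2, 1), bias (1,)]
grad_w, grad_b = tape.gradient(l, model.trainable_variables)

# Closed-form gradients of mean((y_hat - y)^2)
n = X.shape[0]
residual = (y_hat - y).numpy()
expected_w = 2.0 / n * X.numpy().T @ residual
expected_b = 2.0 / n * residual.sum(axis=0)

print(np.allclose(grad_w.numpy(), expected_w, atol=1e-5))  # True
print(np.allclose(grad_b.numpy(), expected_b, atol=1e-5))  # True
```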

num_epochs = 3
for epoch in range(1, num_epochs + 1):
    for (batch, (X, y)) in enumerate(dataset):
        with tf.GradientTape() as tape:
            # Keras losses take (y_true, y_pred)
            l = loss(y, model(X, training=True))

        loss_history.append(l.numpy().mean())
        grads = tape.gradient(l, model.trainable_variables)
        trainer.apply_gradients(zip(grads, model.trainable_variables))

    # Loss over the full dataset after each epoch
    l = loss(labels, model(features))
    print('epoch %d, loss: %f' % (epoch, l))

epoch 1, loss: 0.000264
epoch 2, loss: 0.000097
epoch 3, loss: 0.000097

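Next, we compare the learned model parameters with the true ones. The model's get_weights() method returns its weight and bias. The learned parameters are very close to the true parameters.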

true_w, model.get_weights()[0]
([2, -3.4],
 array([[ 1.9998281],
        [-3.3996763]], dtype=float32))
true_b, model.get_weights()[1]
(4.2, array([4.1998463], dtype=float32))
loss_history
[29.501896,
 37.45631,
 12.249255,
 23.72889,
 23.883945,
 25.070272,
 14.251101,
 8.442382,
 24.766382,
 9.228335,
 8.04291,
 8.583006,
 6.9523644,
 6.9970107,
 5.7393394,
 6.7562685,
 2.006997,
 3.3466537,
 3.010506,
 1.8910837,
 2.9811425,
 2.9470952,
 2.7346947,
 2.5683753,
 1.0880806,
 0.71038055,
 1.3765603,
 1.2225089,
 1.125397,
 1.136457,
 1.0656222,
 0.6368358,
 1.0103394,
 0.81613255,
 0.45046028,
 0.633396,
 0.2740888,
 0.44052514,
 0.20187739,
 0.23083887,
 0.19622864,
 0.17404571,
 0.15724395,
 0.39956665,
 0.13184759,
 0.13588975,
 0.0413301,
 0.062211554,
 0.09542455,
 0.06948571,
 0.121049464,
 0.1404176,
 0.07027206,
 0.02035113,
 0.10618506,
 0.06540239,
 0.03850427,
 0.044746242,
 0.037409224,
 0.037087567,
 0.013585197,
 0.04274003,
 0.020035543,
 0.014686924,
 0.018439168,
 0.030150274,
 0.023141002,
 0.019083317,
 0.012115336,
 0.012250148,
 0.010110767,
 0.00612779,
 0.0148302885,
 0.0054951767,
 0.003688395,
 0.0063000335,
 0.0067952946,
 0.0037225746,
 0.0011148332,
 0.0016755849,
 0.002579968,
 0.0022298498,
 0.0027520158,
 0.0021182017,
 0.0010050359,
 0.0019038839,
 0.0011049738,
 0.0013840701,
 0.0010081959,
 0.0004165701,
 0.0009860347,
 0.00060588756,
 0.00046795295,
 0.00030214773,
 0.0005622429,
 0.0006436542,
 0.00032493853,
 0.00063880545,
 0.00042860032,
 0.00018070132,
 0.00015794327,
 0.00017725705,
 0.00026884335,
 0.00028985454,
 0.0001893751,
 8.273552e-05,
 8.2549916e-05,
 0.00013522906,
 6.562472e-05,
 0.00011805694,
 0.00014822869,
 0.00018188413,
 0.00010688017,
 0.00011095459,
 0.00019555617,
 0.00019057601,
 0.0003080869,
 7.3299874e-05,
 8.4678955e-05,
 0.00011555682,
 0.00012923064,
 7.315063e-05,
 5.8265996e-05,
 0.00012395837,
 0.00013559048,
 9.3044066e-05,
 8.4587366e-05,
 5.7960708e-05,
 5.7924295e-05,
 0.00012980713,
 9.7370845e-05,
 6.330477e-05,
 0.00010059988,
 7.232769e-05,
 0.00017936503,
 6.452073e-05,
 5.009457e-05,
 0.00010594791,
 0.00012093749,
 0.00013548261,
 0.000107912696,
 0.0001587457,
 6.858254e-05,
 0.0001724594,
 0.00010172928,
 7.6469034e-05,
 7.6007054e-05,
 7.583733e-05,
 9.580182e-05,
 5.8986305e-05,
 5.4275395e-05,
 6.976486e-05,
 4.3399854e-05,
 0.00014459722,
 0.00018001617,
 0.00013258224,
 0.00031393423,
 0.00010372,
 5.736463e-05,
 9.139093e-05,
 9.799221e-05,
 8.2846906e-05,
 9.64843e-05,
 0.00014751268,
 8.349354e-05,
 5.8543672e-05,
 0.00012027039,
 0.00011267074,
 3.542353e-05,
 0.00014143434,
 0.00012744889,
 0.00015769311,
 4.4014298e-05,
 0.000116863215,
 9.867393e-05,
 9.499614e-05,
 0.000118109936,
 4.329575e-05,
 7.521584e-05,
 0.0001241296,
 4.275844e-05,
 8.648134e-05,
 0.00011301902,
 0.000101929276,
 0.00010192163,
 6.985559e-05,
 0.00010751579,
 7.195994e-05,
 2.9877838e-05,
 8.252472e-05,
 0.00021170666,
 0.000114028866,
 4.07525e-05,
 0.00011056512,
 0.00015362678,
 6.4155414e-05,
 0.00010491493,
 0.000110198525,
 0.0001302041,
 0.00013186826,
 0.00016527154,
 0.00015286378,
 6.084417e-05,
 5.6655193e-05,
 4.8877053e-05,
 5.363222e-05,
 6.1288825e-05,
 5.74289e-05,
 0.00012154386,
 3.2718228e-05,
 6.969248e-05,
 0.000104646824,
 0.00014144731,
 6.1936196e-05,
 3.7562757e-05,
 7.326159e-05,
 0.00010985002,
 9.588372e-05,
 0.00023255777,
 0.00011218952,
 0.00014342464,
 0.00012717072,
 3.6798574e-05,
 7.485154e-05,
 7.93941e-05,
 0.0001249698,
 0.00019434367,
 0.00011884035,
 0.00013018816,
 6.532644e-05,
 6.15924e-05,
 8.129996e-05,
 0.00012252374,
 0.00014110973,
 0.00010313366,
 4.4449225e-05,
 3.055489e-05,
 9.272004e-05,
 8.4361076e-05,
 9.4692965e-05,
 0.00012557449,
 7.8463054e-05,
 0.00012208376,
 8.491871e-05,
 6.938853e-05,
 0.00012711977,
 0.00017110733,
 0.00029210007,
 0.00015827871,
 0.0001660751,
 9.0286114e-05,
 0.000115873314,
 0.00013234252,
 6.201891e-05,
 2.3510238e-05,
 5.5823904e-05,
 0.00011468558,
 6.126233e-05,
 0.00015700776,
 0.00016621803,
 4.3632343e-05,
 9.0545145e-05,
 0.00014167516,
 0.00010468601,
 3.7364236e-05,
 0.00013142396,
 0.00013766726,
 9.6800606e-05,
 6.343221e-05,
 6.1979656e-05,
 0.00013079047,
 6.305989e-05,
 7.536479e-05,
 7.072952e-05,
 7.8100755e-05,
 0.00015733825,
 5.7136553e-05,
 0.0001431292,
 4.0489856e-05,
 9.89647e-05,
 3.5244804e-05,
 5.200087e-05,
 6.809345e-05,
 7.249845e-05,
 7.157237e-05,
 4.426187e-05,
 7.577443e-05,
 0.00016322176,
 0.0002448729,
 0.00012856603,
 7.970275e-05,
 8.254009e-05,
 9.36201e-05,
 2.938913e-05,
 3.1724147e-05,
 0.00012240729,
 0.00010769217,
 7.548153e-05,
 0.00014087862,
 0.00011540208]

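With TensorFlow, the model can be implemented concisely: the tensorflow.data module provides data-processing tools, tensorflow.keras.layers defines a large number of neural-network layers, tensorflow.initializers defines various initialization methods, tensorflow.losses provides loss functions, and tensorflow.optimizers provides various optimization algorithms.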

