xiaoyao: 动手学深度学习 (Dive into Deep Learning), TensorFlow 2.1.0
As deep learning frameworks have developed, building deep learning applications has become increasingly convenient. In practice, we can usually implement the same model with far less code than in the previous section. In this section, we show how to use the Keras API recommended by TensorFlow 2.1.0 to implement linear regression training more conveniently.
We generate the same dataset as in the previous section, where features holds the training inputs and labels holds the targets.
import tensorflow as tf

num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
# Draw 1000 examples from a standard normal distribution
features = tf.random.normal(shape=(num_examples, num_inputs), stddev=1)
# Labels follow the true linear model plus Gaussian noise (stddev 0.01)
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += tf.random.normal(labels.shape, stddev=0.01)
Although TensorFlow 2.1.0 can fit linear regression directly, without us splitting the data into mini-batches ourselves, it is still worth learning how to read data.
# shuffle's buffer_size should be no smaller than the number of examples; batch sets the size of each mini-batch
from tensorflow import data as tfdata

batch_size = 10
# Combine the features and labels of the training data
dataset = tfdata.Dataset.from_tensor_slices((features, labels))
# Shuffle, then read random mini-batches
dataset = dataset.shuffle(buffer_size=num_examples)
dataset = dataset.batch(batch_size)

data_iter = iter(dataset)
# An iterator obtained with iter(dataset) can traverse the dataset only once
for X, y in data_iter:
    print(X, y)
    break
tf.Tensor(
[[-1.825709 -0.41736308]
[-0.6260657 -0.37043497]
[-1.2240889 -0.4179984 ]
[-2.252687 0.32804537]
[ 0.00852243 -0.11625145]
[ 0.42531878 -0.9496812 ]
[ 0.4022167 -0.07259909]
[-1.0691589 -0.18955724]
[-0.20947874 1.566279 ]
[ 1.7726566 1.5784163 ]], shape=(10, 2), dtype=float32) tf.Tensor(
[ 1.9595385 4.213116 3.1751869 -1.417277 4.620007 8.291456
5.2536917 2.709371 -1.5466335 2.3810005], shape=(10,), dtype=float32)
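Although the iterator above is exhausted after a single pass, iterating over dataset itself starts a fresh pass each time (and shuffle reshuffles by default), which is what the training loop later in this section relies on. A minimal sketch:

# Each `for ... in dataset` begins a new pass over the data
for epoch in range(2):
    for X, y in dataset:
        pass  # process one mini-batch per step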
To define the model: TensorFlow 2.x recommends building networks with Keras, so that is what we use. We first define a model variable model, which is a Sequential instance. In Keras, a Sequential instance can be viewed as a container that chains layers together: when constructing the model, we add layers to the container in order, and given input data, each layer computes its output in turn and passes it to the next layer as input. Importantly, in Keras we do not need to specify the input shape of each layer.
Since this is linear regression, with the input fully connected to the output, we define a single layer: the fully connected layer keras.layers.Dense().
In Keras, the kernel_initializer and bias_initializer options set how the weights and bias are initialized, respectively.
Here we import the initializers module from tensorflow: RandomNormal(stddev=0.01) specifies that each element of the weight parameter is randomly sampled at initialization from a normal distribution with mean 0 and standard deviation 0.01. The bias parameter is initialized to zero by default.
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow import initializers as init

model = keras.Sequential()  # a container that chains layers together
model.add(layers.Dense(1, kernel_initializer=init.RandomNormal(stddev=0.01)))
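Because Keras infers input shapes lazily, the layer's weight and bias variables are only created the first time the model is called on data. A minimal sketch to confirm the shapes (the forward pass on a slice of features is purely for illustration):

_ = model(features[:2])  # the first call builds the layer from the input shape
print([w.shape for w in model.get_weights()])  # [(2, 1), (1,)]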
Next we define the loss function and the optimizer: the loss is the squared loss (MSE) and the optimizer is SGD, stochastic gradient descent.
In Keras, once the model is defined, calling the compile() method is one way to configure the model's loss function and optimization method; here we instead construct the loss and optimizer objects directly and use them in a hand-written training loop below.
For the loss we need not implement anything ourselves: Keras defines various loss functions, and we directly use its squared loss, MeanSquaredError, as the model's loss.
Likewise, we need not implement mini-batch stochastic gradient descent: Keras defines various optimization algorithms, and we directly specify mini-batch SGD with a learning rate of 0.03, tf.keras.optimizers.SGD(0.03), as the optimizer.
from tensorflow import losses
loss = losses.MeanSquaredError()

from tensorflow.keras import optimizers
trainer = optimizers.SGD(learning_rate=0.03)

loss_history = []  # record every mini-batch loss for inspection later
To train a model with Keras, we could call the model instance's fit function to iterate over the data. fit only needs the inputs x and targets y, plus epochs, the number of passes over the data, and batch_size, the number of examples per gradient update; here that would be epochs=3 and batch_size=10. With Keras we would not even need to batch the dataset ourselves.
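For reference, a minimal sketch of that high-level route (alt_model is a hypothetical second model built the same way, so the manual loop below still trains the original model from scratch):

alt_model = keras.Sequential()
alt_model.add(layers.Dense(1, kernel_initializer=init.RandomNormal(stddev=0.01)))
alt_model.compile(optimizer=optimizers.SGD(learning_rate=0.03),
                  loss=losses.MeanSquaredError())
alt_model.fit(features, labels, epochs=3, batch_size=10, verbose=0)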
Here, instead, we train with TensorFlow directly: tf.GradientTape records the forward computation as a dynamic graph, tape.gradient then returns the gradients of the recorded variables, model.trainable_variables gives the variables that need updating, and trainer.apply_gradients applies the gradients to the weights, completing one training step.
num_epochs = 3
for epoch in range(1, num_epochs + 1):
    for (batch, (X, y)) in enumerate(dataset):
        with tf.GradientTape() as tape:
            # MSE is symmetric, so the (prediction, target) argument order is harmless here
            l = loss(model(X, training=True), y)
        loss_history.append(l.numpy().mean())
        grads = tape.gradient(l, model.trainable_variables)
        trainer.apply_gradients(zip(grads, model.trainable_variables))
    l = loss(model(features), labels)
    print('epoch %d, loss: %f' % (epoch, l))
epoch 1, loss: 0.000264
epoch 2, loss: 0.000097
epoch 3, loss: 0.000097
Below we compare the learned model parameters with the true ones. We can obtain the weight and the bias through the model's get_weights(). The learned parameters are very close to the true ones.
true_w, model.get_weights()[0]
([2, -3.4],
array([[ 1.9998281],
[-3.3996763]], dtype=float32))
true_b, model.get_weights()[1]
(4.2, array([4.1998463], dtype=float32))
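To quantify how close they are, a small sketch computing the element-wise deviations (numpy is assumed available; it ships with TensorFlow):

import numpy as np

learned_w, learned_b = model.get_weights()
print(np.array(true_w) - learned_w.reshape(-1))  # deviation of the weights
print(true_b - learned_b[0])                     # deviation of the bias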
loss_history
[29.501896,
 37.45631,
 12.249255,
 23.72889,
 23.883945,
 25.070272,
 ...
 7.548153e-05,
 0.00014087862,
 0.00011540208]
(output truncated: 300 mini-batch losses in total, falling from about 30 at the start of training to about 1e-4)
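The raw list is easier to read as a curve. A minimal plotting sketch, assuming matplotlib is installed:

import matplotlib.pyplot as plt

plt.semilogy(loss_history)  # log scale: the loss spans several orders of magnitude
plt.xlabel('mini-batch step')
plt.ylabel('mean squared error')
plt.show()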
With TensorFlow, models can be implemented concisely: the tensorflow.data module provides tools for data processing, tensorflow.keras.layers defines a large number of neural network layers, tensorflow.initializers defines various initialization methods, and tensorflow.optimizers provides a variety of optimization algorithms for models.