TensorFlow:模型装配、训练与测试

在训练网络时,一般的流程是通过前向计算获得网络的输出值,再通过损失函数计算网络误差,然后通过自动求导工具计算梯度并更新,同时间隔性地测试网络的性能。对于这种常用的训练逻辑,可以直接通过Keras 提供的模型装配与训练高层接口实现,简洁清晰。

文章目录

    • 一、模型装配
    • 二、模型匹配
    • 三、模型测试

一、模型装配

在 Keras 中,有2 个比较特殊的类:keras.Modelkeras.layers.Layer 类。
其中Layer类是网络层的母类,定义了网络层的一些常见功能,如添加权值,管理权值列表等。
Model 类是网络的母类,除了具有Layer 类的功能,还添加了保存、加载模型,训练与测
试模型等便捷功能。Sequential 也是Model 的子类,因此具有Model 类的所有功能。

# 创建5 层的全连接层网络
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()

创建网络后,正常的流程是通过循环迭代数据集多遍,每次按批产生训练数据,前向计算,然后通过损失函数计算误差值,并反向传播自动计算梯度,更新网络参数。这一部分
逻辑由于非常通用,在keras 中提供了compile()fit()函数方便实现上述逻辑。首先通过compile 函数指定网络使用的优化器对象,损失函数,评价指标等:

# 导入优化器,损失函数模块
from tensorflow.keras import optimizers,losses
# 采用Adam 优化器,学习率为0.01;采用交叉熵损失函数,包含Softmax
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'] # 设置测量指标为准确率
)

我们在compile()函数中指定的优化器,损失函数等参数也是我们自行训练时需要使用的参
数,并没有什么特别之处,只不过keras 将这部分常用逻辑实现了,提高开发效率。

二、模型匹配

模型装配完成后,即可通过fit()函数送入待训练的数据和验证用的数据集:

# 指定训练集为train_db,验证集为val_db,训练5 个epochs,每2 个epoch 验证一次
# 返回训练信息保存在history 中
history = network.fit(train_db, epochs=5, validation_data=val_db,
validation_freq=2)

其中train_db 为tf.data.Dataset 对象,也可以传入Numpy Array 类型的数据;
epochs 指定训练迭代的epochs 数;validation_data 指定用于验证(测试)的数据集和验证的频率validation_freq。
运行上述代码即可实现网络的训练与验证的功能,fit 函数会返回训练过程的数据记录
history,其中history.history 为字典对象,包含了训练过程中的loss,测量指标等记录项:

history.history # 打印训练记录

可以看到通过 compile&fit 方式实现的代码非常简洁和高效,大大缩减了开发时间。但是因为接口非常高层,灵活性也降低了,是否使用需要用户自行判断。

三、模型测试

Model 基类除了可以便捷地完成网络的装配与训练、验证,还可以非常方便的预测和测试。关于验证和测试的区别,我们会在过拟合章节详细阐述,此处可以将验证和测试理
解为模型评估的一种方式。
通过 Model.predict(x) 方法即可完成模型的预测:

# 加载一个batch 的测试数据
x,y = next(iter(db_test))
print('predict x:', x.shape)
out = network.predict(x) # 模型预测
print(out) // 其中out 即为网络输出

如果只是简单的测试模型的性能,可以通过Model.evaluate(db)即可循环测试完db 数据集上所有样本,并打印出性能指标:

network.evaluate(db_test) # 模型测试

你可能感兴趣的:(tensorflow)