TensorFlow2.0学习笔记——回调函数callbacks

这里的回调函数介绍得并不详细,只记录了笔者学习过程中用到的,后续随着学习,会逐渐补充。

1. 回调函数的理解

笔者目前的理解,回调函数是在程序运行中,满足某些要求,就会触发的函数。

2. tf.keras.callbacks

官网地址点击此处

2.1 TensorBoard

官网地址点击此处
源码点击此处
Tensorboard是TensorFlow内置的可视化工具,记录TensorBoard中的事件,包括Metrics summary plots(指标摘要图,指标即损失loss或准确度accuracy),Training graph visualization(训练图可视化),Activation histograms(激活直方图), Sampled profiling(采样分析)。
具体的,可实现如下功能:

  • 对如损失和准确度等指标进行跟踪并实现可视化。
  • 对模型图进行可视化,比如操作和层(ops and layers)
  • 查看权重(weight)、偏差(biases)或其他张量(tensor)随时间变化的直方图
  • Projecting embeddings to a lower dimensional space(将嵌入物投影到较低维度的空间)。这句还不能理解是什么意思
  • 展示图像、文本和音频数据
  • 分析一个TensorFlow程序
  • 其他~~~

2.1.1 启动

可在命令行中输入如下命令来启动TensorBoard。

tensorboard --logdir=path_to_your_logs

2.1.2 参数

__init__(
    log_dir='logs',
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    update_freq='epoch',
    profile_batch=2,
    embeddings_freq=0,
    embeddings_metadata=None,
    **kwargs
)
  • log_dir: 保存TensorBoard解析文件的路径。
  • histogram_freq: 计算模型各层的激活度和权重直方图的频率(每个周期中)。如果设置为0,将不计算直方图。必须为直方图可视化指定验证数据(或拆分)。
  • write_gragh: 是否在TensorBoard中可视化图形。当write_graph设置为True时,文件可能会变得很大。
  • write_images: 是否在TensorBoard中编写模型权重来实现可视化的图片。
  • update_freq: 输入“batch”或“epoch”或整数,在每一个batch或epoch或整数个数据(samples)结束后将损失(loss)和指标(metrics)添加到TensorBoard中。
  • profile_batch: 对要采样的批次进行分析,计算特征。默认为2,profile_batch为0时禁用。在eager mode中必须使用。
  • embeddings_freq: 嵌入层可视化的频率,被设置为0,则不可见
  • embeddings_metadata: 字典,它将图层名称映射到文件名,该文件名的文件保存着嵌入层的元数据。

2.2 示例

此处示例来自此处
以下代码为节选

# 回调函数需在拟合之前设置
# 回调函数,使用Tensorboard,earlystopping,ModelCheckpoint
# Tensorboard需要使用一个文件夹
# ModelCheckpoint需要一个文件名
logdir = './callbacks'
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir, 
                                 "fashion_mnist_model.h5")
callbacks = [
    keras.callbacks.TensorBoard(logdir),
    keras.callbacks.ModelCheckpoint(output_model_file,
                                   save_best_only=True),
    keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3)
]


history = model.fit(x_train, y_train, epochs=10,
                    validation_data=(x_valid, y_valid),
                    callbacks = callbacks)
# epochs设置数据遍历的次数,validation_data用来设置训练中检验模型的测试集

运行后,目录下出现tensorboard文件夹,在命令行中输入tensorboard --logdir=callbacks,其中,callbacks是我们创建的tensorboard文件夹名称。

tensorboard --logdir=callbacks

笔者的环境下输出结果如下

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.1.0 at http://localhost:6006/ (Press CTRL+C to quit)

打开浏览器,输入localhost:6006,进入.
在这里插入图片描述
界面如下:
TensorFlow2.0学习笔记——回调函数callbacks_第1张图片

你可能感兴趣的:(深度学习基础)