TensorFlow 2.x tf.keras.callbacks Analysis


  • Preface
    • Overview of the Callback class
    • 1. The Callback abstract
    • 2. Execution order of the Callback methods
    • 3. Using the Callback methods: writing a custom callback (TensorBoard as an example)
      • 3.1 Constructor arguments
      • 3.2 The set_model() method

Preface

Hello! While doing deep-learning research with TensorFlow, I found the waters here surprisingly deep. The latest TensorFlow release at the time of writing is 2.3.0, but when I tried upgrading to TensorFlow 2.2.0 I ran into so many problems that I had no choice but to fall back to 2.1.0. I have been using Anaconda virtual environments throughout, since they make it easy to manage Python environments and let multiple deep-learning libraries coexist. This post looks at the tf.keras.callbacks.xxx classes used in TensorFlow 2.x; I recently switched from TensorFlow 1.1x to TensorFlow 2.x.

A lot has changed: the slim module is gone, which is a real pitfall for anyone whose network architectures were built on it. On the other hand, 2.x introduces many features that greatly simplify programming and lower the barrier to entry for TensorFlow.

Among those features, the classes that ship with the callbacks module are genuinely convenient. They are listed below (official documentation link):

class BaseLogger: callback that accumulates epoch averages of metrics.

class CSVLogger: callback that streams epoch results to a CSV file.

class Callback: abstract base class used to build new callbacks. (This is what custom callbacks are built on.)

class EarlyStopping: stops training when a monitored quantity has stopped improving.

class History: callback that records events into a History object.

class LambdaCallback: callback for creating simple, custom callbacks on the fly.

class LearningRateScheduler: learning-rate scheduler.

class ModelCheckpoint: callback that saves the model (or its weights) at a chosen frequency.

class ProgbarLogger: callback that prints metrics to stdout.

class ReduceLROnPlateau: reduces the learning rate when a metric has stopped improving.

class RemoteMonitor: callback used to stream events to a server.

class TensorBoard: enables visualizations for TensorBoard. (This one is really powerful.)

class TerminateOnNaN: callback that terminates training when a NaN loss is encountered.
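
Before digging into the base class, here is a minimal sketch of how several of these built-in callbacks are typically combined in model.fit(). This is my own example, not from the docs; the dummy data and the 'best_model.h5' file name are placeholders.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')

callbacks = [
    # stop when val_loss has not improved for 3 epochs
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3),
    # keep only the best checkpoint on disk
    tf.keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True),
    # halve the learning rate when val_loss plateaus
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2),
    # write logs for TensorBoard visualization
    tf.keras.callbacks.TensorBoard(log_dir='logs'),
]

x, y = np.random.rand(100, 10), np.random.rand(100, 1)  # dummy data
model.fit(x, y, epochs=20, validation_split=0.2, callbacks=callbacks)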

Overview of the Callback class

To make it easier to study how callbacks work, the skeleton of the Callback base class is reproduced below; every callback subclass builds its methods on top of it.

@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
      params: dict. Training parameters
          (eg. verbosity, batch size, number of epochs...).
      model: instance of `keras.models.Model`.
          Reference of the model being trained.
      validation_data: Deprecated. Do not use.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch.

  Currently, the `.fit()` method of the `Model` class
  will include the following quantities in the `logs` that
  it passes to its callbacks:

      on_epoch_end: logs include `acc` and `loss`, and
          optionally include `val_loss`
          (if validation is enabled in `fit`), and `val_acc`
          (if validation and accuracy monitoring are enabled).
      on_batch_begin: logs include `size`,
          the number of samples in the current batch.
      on_batch_end: logs include `loss`, and optionally `acc`
          (if accuracy monitoring is enabled).
  """

  def __init__(self):
    self.validation_data = None
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  @doc_controls.for_subclass_implementers
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  @doc_controls.for_subclass_implementers
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.
    """

  @doc_controls.for_subclass_implementers
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  def on_test_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `evaluate` methods.
    """

  @doc_controls.for_subclass_implementers
  def on_test_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `evaluate` methods.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `predict` methods.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `predict` methods.
    """

  @doc_controls.for_subclass_implementers
  def on_train_begin(self, logs=None):
    """Called at the beginning of training.
    """

  @doc_controls.for_subclass_implementers
  def on_train_end(self, logs=None):
    """Called at the end of training.
    """

  @doc_controls.for_subclass_implementers
  def on_test_begin(self, logs=None):
    """Called at the beginning of evaluation or validation.
    """

  @doc_controls.for_subclass_implementers
  def on_test_end(self, logs=None):
    """Called at the end of evaluation or validation.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_begin(self, logs=None):
    """Called at the beginning of prediction.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_end(self, logs=None):
    """Called at the end of prediction.
    """

Next, let's look at what each method means, what arguments it receives, and when it is invoked. To make the analysis easier to unpack, the methods of the Callback class are listed on their own below:

@keras_export('keras.callbacks.Callback')
class Callback(object):
  def __init__(self):
  def set_params(self, params):
  def set_model(self, model):
  def on_batch_begin(self, batch, logs=None):
  def on_batch_end(self, batch, logs=None):
  def on_epoch_begin(self, epoch, logs=None):
  def on_epoch_end(self, epoch, logs=None):
  def on_train_batch_begin(self, batch, logs=None):
  def on_train_batch_end(self, batch, logs=None):
  def on_test_batch_begin(self, batch, logs=None):
  def on_test_batch_end(self, batch, logs=None):
  def on_predict_batch_begin(self, batch, logs=None):
  def on_predict_batch_end(self, batch, logs=None):
  def on_train_begin(self, logs=None):
  def on_train_end(self, logs=None):
  def on_test_begin(self, logs=None):
  def on_test_end(self, logs=None):
  def on_predict_begin(self, logs=None):
  def on_predict_end(self, logs=None):

1. The Callback abstract

Below, each method's purpose and usage is examined in detail. The base class source opens with the abstract (the class docstring) of Callback, reproduced here. (What follows is my own interpretation; if anything is wrong, criticism and corrections are welcome, and I will update the post so we can learn together.)

  """Abstract base class used to build new callbacks.

  Attributes:
      params: dict. Training parameters
          (eg. verbosity, batch size, number of epochs...).
      model: instance of `keras.models.Model`.
          Reference of the model being trained.
      validation_data: Deprecated. Do not use.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch.

  Currently, the `.fit()` method of the `Model` class
  will include the following quantities in the `logs` that
  it passes to its callbacks:

      on_epoch_end: logs include `acc` and `loss`, and
          optionally include `val_loss`
          (if validation is enabled in `fit`), and `val_acc`
          (if validation and accuracy monitoring are enabled).
      on_batch_begin: logs include `size`,
          the number of samples in the current batch.
      on_batch_end: logs include `loss`, and optionally `acc`
          (if accuracy monitoring is enabled).
  """

Attributes:
    1. Like any class method that works on external data, the callback methods need the relevant objects passed in. Since callbacks are a keras submodule invoked during network training, they naturally need access to training-time information.
    2. The Attributes section lists params and model, i.e. the objects available inside a callback. My current understanding: as the later analysis shows, these are passed in automatically (Keras calls set_params() and set_model() for you), for the base class and for subclasses alike; we are free to use them, but only the attributes listed here are guaranteed. For example, in tf.keras.callbacks.TensorBoard, computing histograms requires the weights of every layer, and therefore the whole model; in practice the single assignment self.model = model is all it takes to hand the externally trained model to the callback, after which self.model is an attribute of the callback instance.
    3. logs is a dictionary argument that, as the analysis above shows, nearly every callback method receives, always defaulting to None. I am not yet entirely clear on this mechanism, but the docstring explains that when model.fit() runs, Keras fills logs and passes it to the callbacks automatically, and it spells out what logs contains at each stage of training (see the sketch after this list):

1. on_epoch_end:   logs include acc and loss, and optionally include val_loss
2. on_batch_begin: logs include size
3. on_batch_end:   logs include loss, and optionally acc
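
As a concrete illustration of logs, here is a minimal custom callback (my own sketch, not from the Keras source) that reads quantities out of logs at the end of each epoch:

import tensorflow as tf

class LossPrinter(tf.keras.callbacks.Callback):
    """Sketch: read training quantities out of the `logs` dict."""

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}  # logs defaults to None, so guard before reading
        print('epoch %d: loss=%s val_loss=%s'
              % (epoch, logs.get('loss'), logs.get('val_loss')))

Passing callbacks=[LossPrinter()] to model.fit() is all that is needed; Keras fills in epoch and logs automatically on every call.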


2. Execution order of the Callback methods

First, a training run proceeds roughly as follows. (Prediction is normally kept separate from training, so the four on_predict_* methods are omitted here to keep the outline readable.)


1. Initialization
__init__(self):
set_params(self, params):

2. Training starts
set_model(self, model):
on_train_begin(self, logs=None):
   # loop over each epoch
   on_epoch_begin(self, epoch, logs=None):
   -   # loop over each train_batch
   -    on_train_batch_begin(self, batch, logs=None): -> on_batch_begin(self, batch, logs=None) <- (not required; called from inside on_train_batch_begin for backwards compatibility)
   -    ''' the actual training step runs here, outside the callback '''
   -    on_train_batch_end(self, batch, logs=None):   -> on_batch_end(self, batch, logs=None) <- (not required; called from inside on_train_batch_end for backwards compatibility)
   -   # validation: on_test_begin/on_test_end wrap a loop over each test_batch
   -    on_test_begin(self, logs=None):
   -    on_test_batch_begin(self, batch, logs=None):
   -    on_test_batch_end(self, batch, logs=None):
   -    on_test_end(self, logs=None):
   -
   on_epoch_end(self, epoch, logs=None):
on_train_end(self, logs=None):
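
A quick way to verify this order yourself is a tracing callback (again my own sketch) that prints every hook as Keras fires it:

import numpy as np
import tensorflow as tf

class Tracer(tf.keras.callbacks.Callback):
    """Print each hook as it is invoked, to observe the call order."""
    def on_train_begin(self, logs=None): print('on_train_begin')
    def on_epoch_begin(self, epoch, logs=None): print('on_epoch_begin', epoch)
    def on_train_batch_begin(self, batch, logs=None): print('  on_train_batch_begin', batch)
    def on_train_batch_end(self, batch, logs=None): print('  on_train_batch_end', batch)
    def on_test_begin(self, logs=None): print('  on_test_begin')
    def on_test_batch_begin(self, batch, logs=None): print('    on_test_batch_begin', batch)
    def on_test_batch_end(self, batch, logs=None): print('    on_test_batch_end', batch)
    def on_test_end(self, logs=None): print('  on_test_end')
    def on_epoch_end(self, epoch, logs=None): print('on_epoch_end', epoch)
    def on_train_end(self, logs=None): print('on_train_end')

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')
x, y = np.random.rand(8, 4), np.random.rand(8, 1)  # dummy data
model.fit(x, y, epochs=2, batch_size=4, validation_split=0.25,
          callbacks=[Tracer()], verbose=0)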

3. Using the Callback methods: writing a custom callback (TensorBoard as an example)

3.1 Constructor arguments:

Arguments:
log_dir: the path of the directory where to save the log files to be parsed by TensorBoard.
histogram_freq: frequency (in epochs) at which to compute activation and weight histograms for the layers of
        the model. If set to 0, histograms won't be computed. Validation data
        (or split) must be specified for histogram visualizations.
write_graph: whether to visualize the graph in TensorBoard. The log file can become quite large when
        write_graph is set to True.
write_images: whether to write model weights to visualize as an image in TensorBoard.
update_freq: 'batch' or 'epoch' or integer. When using 'batch',
        writes the losses and metrics to TensorBoard after each batch. The same
        applies for 'epoch'. If using an integer, say 1000, the
        callback will write the metrics and losses to TensorBoard every 1000
        batches. Note that writing too frequently to TensorBoard can slow down your training.
profile_batch: profile the batch to sample compute characteristics. By default, it will profile the second batch.
        Set profile_batch=0 to disable profiling. Must run in TensorFlow eager mode.
embeddings_freq: frequency (in epochs) at which embedding layers will be visualized. If set to 0, embeddings
        won't be visualized.
embeddings_metadata: a dictionary which maps layer name to a file name in which metadata for this
        embedding layer is saved. See the TensorBoard documentation for details about the
        metadata file format. If the same metadata file is to be used for all
        embedding layers, a single string can be passed.

  def __init__(self,
               log_dir='logs',
               histogram_freq=0,
               write_graph=True,
               write_images=False,
               update_freq='epoch',
               profile_batch=2,
               embeddings_freq=0,
               embeddings_metadata=None,
               **kwargs):
    super(TensorBoard, self).__init__()
    self._validate_kwargs(kwargs)

    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    self.write_graph = write_graph
    self.write_images = write_images
    if update_freq == 'batch':
      self.update_freq = 1
    else:
      self.update_freq = update_freq
    self.embeddings_freq = embeddings_freq
    self.embeddings_metadata = embeddings_metadata

    self._samples_seen = 0
    self._samples_seen_at_last_write = 0
    self._current_batch = 0

    # A collection of file writers currently in use, to be closed when
    # training ends for this callback. Writers are keyed by the
    # directory name under the root logdir: e.g., "train" or
    # "validation".
    self._train_run_name = 'train'
    self._validation_run_name = 'validation'
    self._writers = {}

    self._profile_batch = profile_batch
    # True when a trace is running.
    self._is_tracing = False

3.2 The set_model() method

  def set_model(self, model):
    """Sets Keras model and writes graph if specified."""
    self.model = model  # model is the network being trained; Keras passes it in automatically, and this line stores it on the callback instance

    # TensorBoard callback involves writing a summary file in a
    # possibly distributed settings.
    self._log_write_dir = distributed_file_utils.write_dirpath(
        self.log_dir, self.model._get_distribution_strategy())  # pylint: disable=protected-access

    with context.eager_mode():
      self._close_writers()
      if self.write_graph:
        with self._get_writer(self._train_run_name).as_default():
          with summary_ops_v2.always_record_summaries():
            if not model.run_eagerly:
              summary_ops_v2.graph(K.get_graph(), step=0)

            summary_writable = (
                self.model._is_graph_network or  # pylint: disable=protected-access
                self.model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
            if summary_writable:
              summary_ops_v2.keras_model('keras', self.model, step=0)

    if self.embeddings_freq:
      self._configure_embeddings()

    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    self._prev_summary_recording = summary_state.is_recording
    self._prev_summary_writer = summary_state.writer
    self._prev_summary_step = summary_state.step
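
The same pattern carries over to your own callbacks: once Keras has called set_model(), the whole network is reachable through self.model. A minimal sketch (my own example) that logs each layer's weight shapes when training starts:

import tensorflow as tf

class WeightShapeLogger(tf.keras.callbacks.Callback):
    """Sketch: use self.model (assigned via set_model) inside a hook."""

    def on_train_begin(self, logs=None):
        # self.model was set by Keras before training started
        for layer in self.model.layers:
            print(layer.name, [w.shape for w in layer.get_weights()])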
