Colab Tensorboard 批次(batch)级数据显示

前言

太坑了。。查了很多资料,终于解决了。

方法

首先,根据官方API的参数列表得知,在tf.keras.callbacks.TensorBoard中修改update_freq参数为batch或一个整数

update_freq='batch'

# update_freq=10
# 如果改为使用一个整数N的话,过N批次后更新一次数据,
# 这样可以避免由于更新过于频繁而降低网络训练速度

然后,根据这则帖子,由于TensorFlow 2.3做了一个优化,导致上面的方法在这里不管用。

解决方法是,除了TensorBoard之类的callback以外,再添加一个LambdaCallback,具体代码如下:

   def batchOutput(batch, logs):
       tf.summary.scalar('batch_loss', data=logs['loss'], step=batch)
       tf.summary.scalar('batch_accuracy', data=logs['accuracy'], step=batch)
       return batch
       
   batch_log_callback = callbacks.LambdaCallback(
       on_batch_end=batchOutput)

于是终于成功
Colab Tensorboard 批次(batch)级数据显示_第1张图片Colab Tensorboard 批次(batch)级数据显示_第2张图片

示例代码

改完后,我的训练部分的完整代码是这样的:

def train_model(save:bool=True):
    # load and compile model
    model = create_model()    
    model.compile(
        loss='mean_squared_error',
        optimizer='adam',
        metrics=['accuracy'])

    # prepare tensorflow dashboard
    logdir = os.path.join(
        'logs', 
        datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    tensorboard_callback = callbacks.TensorBoard(
        logdir, 
        histogram_freq=1, 
        write_images=True, 
        update_freq=10, # 查看批次级别的数据变化,需要结合LambdaCallback
        embeddings_freq=1,
        profile_batch=1)
        
	# 实现查看批次级别数据变化
    def batchOutput(batch, logs):
        tf.summary.scalar('batch_loss', data=logs['loss'], step=batch)
        tf.summary.scalar('batch_accuracy', data=logs['accuracy'], step=batch)
        return batch
    batch_log_callback = callbacks.LambdaCallback(
        on_batch_end=batchOutput)

    # prepare early stop
    early_stop_callback = callbacks.EarlyStopping(
        monitor='val_loss', 
        patience=0,
        restore_best_weights=True)

    # train model
    epochs_num = 4
    model.fit(x=X,
              y=X, 
              epochs=epochs_num, 
              batch_size=64, 
              validation_data=(X_eval, X_eval),
              verbose=1, # 0:silent, 1:progress bar, 2:one line per epoch
              callbacks=[tensorboard_callback, 
                         batch_log_callback,
                         early_stop_callback])

    # save model
    if save:
        MODEL_FOLDER = '/content/drive/MyDrive/A-Million-Headlines/pretrained'
        model_name = 'AutoEncoder-model-{}-epochs-{}.h5'.format(epochs_num, int(time.time()))
        joblib.dump(model, os.path.join(MODEL_FOLDER, model_name))
    
    return model

参考资料

  • 官方API - Tensorflow
  • Tensorboard not updating by batch in google colab - StackOverflow

你可能感兴趣的:(机器学习,batch,tensorflow,python,Colab,tensorboard)