TensorFlow: Visualizing the training process in TensorBoard when using `train_on_batch`

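Call the function `train_on_batch_TB` below. It trains the model batch by batch with `train_on_batch`. At the end of every epoch it writes the averaged training metrics and the validation loss to TensorBoard, and it saves the model whenever the validation loss improves. The code assumes `numpy` (as `np`), `tensorflow` (as `tf`), `numbers.Number` and a Keras `TensorBoard` callback are already imported. It uses TensorFlow 1.x summary APIs (`tf.Summary`, `callback.writer`).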

def train_on_batch_TB(nb_epoch, batch_size, model,
                      X, Y, validation_x, validation_y,
                      model_file_path, TB_log_dir='./logs', ):
    """
    Train ``model`` with ``train_on_batch`` and log progress to TensorBoard.

    Each epoch walks X/Y in consecutive, non-shuffled batches of exactly
    ``batch_size`` samples; a trailing partial batch is skipped. After the
    last batch of an epoch:

    * the per-batch values are averaged and written under
      ``model.metrics_names``;
    * the model is scored by ``validate`` and the result is written as
      ``val_loss``;
    * if ``val_loss`` is lower than the best value so far, the model is
      saved to ``model_file_path``;
    * ``model.reset_states()`` is called before the next epoch.

    :param nb_epoch: number of epochs
    :param batch_size: number of samples per batch; must be positive
    :param model: Keras-style model (uses ``train_on_batch``,
        ``metrics_names``, ``save`` and ``reset_states``)
    :param X: X shape is (samples, timesteps, features)
    :param Y: Y shape is (samples, timesteps)
    :param validation_x: validation inputs, batched the same way as X
    :param validation_y: validation targets, batched the same way as Y
    :param model_file_path: where the model with the lowest val_loss is saved
    :param TB_log_dir: folder of logs for tensorBoard
    :raises ValueError: if ``batch_size`` is not positive, or X holds fewer
        than ``batch_size`` samples (no batch could ever be trained)
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    # The number of full batches does not change between epochs.
    n_batch_per_epoch = len(X) // batch_size
    if n_batch_per_epoch == 0:
        raise ValueError(
            f'X has {len(X)} samples, fewer than batch_size={batch_size}; '
            'no batch could be trained')

    callback = TensorBoard(TB_log_dir)
    callback.set_model(model)
    min_val_loss = float('inf')  # store min value of validation loss
    for epoch_no in range(nb_epoch):
        print(f'{epoch_no}th train')
        # One row of reported values per batch; averaged at the end of the epoch.
        batches_logs = np.zeros((n_batch_per_epoch, len(model.metrics_names)))
        for batch_no in range(n_batch_per_epoch):
            start = batch_no * batch_size
            stop = start + batch_size
            logs = model.train_on_batch(X[start:stop], Y[start:stop])
            # train_on_batch may return a scalar or a list; store it as a 1-D row.
            batches_logs[batch_no] = np.atleast_1d(logs)

        # End-of-epoch bookkeeping runs once, after all batches are trained.
        write_log(callback, model.metrics_names, np.mean(batches_logs, axis=0), epoch_no)
        val_loss = validate(model, batch_size, validation_x, validation_y)
        write_log(callback, ['val_loss'], val_loss, epoch_no)
        # save model whenever the validation loss improves
        if isinstance(val_loss, Number) and val_loss < min_val_loss:
            print(f"val_loss {val_loss:.4f} < {min_val_loss:.4f}, save the model.")
            min_val_loss = val_loss
            model.save(model_file_path, overwrite=True, include_optimizer=True)
        # presumably clears recurrent state of a stateful model between epochs — confirm
        model.reset_states()


def validate(model, batch_size,
             validation_x, validation_y):
    n_batch = len(validation_x) // batch_size
    batches_logs = np.zeros((n_batch, len(model.metrics_names))) 
    for batch_no in range(n_batch):
        batch_start_index = batch_no * batch_size
        _x, _y = validation_x[batch_start_index:batch_start_index + batch_size], \
                 validation_y[batch_start_index:batch_start_index + batch_size]
        # train and validate
        logs = model.train_on_batch(_x, _y)
        batches_logs[batch_no] = [logs] if type(logs) != list else logs
    return np.mean(batches_logs, axis=0)[0]


def write_log(callback, names, logs, batch_no):
    """
    Write scalar values to TensorBoard through ``callback.writer``.

    :param callback: TensorBoard callback; its ``writer`` receives one summary
        per (name, value) pair and is flushed after each one
    :param names: tag for each value
    :param logs: a single number, or a sequence of values aligned with
        ``names``; pairs beyond the shorter of the two are ignored (``zip``)
    :param batch_no: step recorded with every value (the training loop in this
        file passes the epoch index here)
    """
    # A lone number is wrapped so it pairs with the single tag.
    values = [logs] if isinstance(logs, Number) else logs
    for tag, scalar in zip(names, values):
        event = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)])
        callback.writer.add_summary(event, batch_no)
        callback.writer.flush()

You may also be interested in: TensorFlow: visualization of train process using `train_on_batch`