在 TensorBoard 中显示 matplotlib.pyplot 绘制的图

查阅了不少资料后,这里介绍一种方法:在使用 tensorflow.keras 训练神经网络时,编写一个 Keras Callback 的子类(代码中从 tensorflow.python.keras.callbacks 导入 Callback),在每个或每隔几个 epoch 结束后,把 matplotlib.pyplot 生成的图片写入 TensorBoard,并在其 Images 面板中显示。代码基于 TensorFlow 1.x 的 API(如 FileWriter)。

import io
import tensorflow as tf
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.summary import summary as tf_summary
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import numpy as np

class MyCallbacks(Callback):
    """Keras callback that logs a matplotlib scatter plot to TensorBoard.

    Every ``period`` epochs the model is run on ``val_data``. The predictions
    are drawn as a scatter plot by ``gen_plot``, and the resulting PNG is
    written as an image summary (tag ``"plot"``) to an event file in
    ``logdir``. Uses the TF 1.x summary API (``tf.Summary`` / ``FileWriter``).

    Args:
        logdir: Directory the summary ``FileWriter`` writes event files to.
        period: Log a plot once every ``period`` epochs.
        val_data: Model input passed to ``model.predict`` to produce the plot.
        predict_steps: ``steps`` argument forwarded to ``model.predict``
            (defaults to 32, the value previously hard-coded).
    """

    def __init__(self, logdir, period, val_data, predict_steps=32):
        super(MyCallbacks, self).__init__()
        self.logdir = logdir
        self.period = period
        self.predict_steps = predict_steps
        # Number of epochs seen since the last plot was logged.
        self.last_rcd = 0
        self.writer = tf_summary.FileWriter(self.logdir)
        # NOTE(review): kept under our own attribute name on purpose. Keras'
        # fit() (Keras 2.x / tf.keras 1.x) may assign its own
        # [x, y, sample_weights] list to ``callback.validation_data`` when
        # validation_data is passed to fit(), which would replace this input.
        self.val_data = val_data

    def gen_plot(self, y_predict):
        """Scatter-plot predictions and return the figure as a PNG buffer.

        Args:
            y_predict: Model predictions. A complex array is plotted as
                real vs. imaginary part. A real array whose last axis has
                size 2 is plotted as ``[..., 0]`` vs. ``[..., 1]``. Anything
                else falls back to plotting the flattened values against
                themselves.

        Returns:
            io.BytesIO positioned at offset 0 and holding PNG bytes.
        """
        values = np.asarray(y_predict)
        if np.iscomplexobj(values):
            real_part = np.real(values).reshape(-1)
            imag_part = np.imag(values).reshape(-1)
        elif values.ndim >= 2 and values.shape[-1] == 2:
            # NOTE(review): assumes the last axis holds (real, imag) pairs,
            # as the variable names suggest. Confirm against the model output.
            real_part = values[..., 0].reshape(-1)
            imag_part = values[..., 1].reshape(-1)
        else:
            real_part = values.reshape(-1)
            imag_part = values.reshape(-1)

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111)
            ax.scatter(real_part, imag_part)
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
        finally:
            # Close the figure explicitly. pyplot otherwise keeps every
            # figure alive and memory grows with each logged epoch.
            plt.close(fig)
        buf.seek(0)
        return buf

    @staticmethod
    def _image_summary(tag, png):
        """Wrap encoded PNG bytes in a ``tf.Summary`` without creating graph ops.

        Building the protocol buffer directly avoids adding new ops to the
        default graph (and opening a new ``tf.Session``) on every call.
        """
        import struct
        # PNG layout: 8-byte signature, then the IHDR chunk. IHDR data starts
        # at byte 16 with big-endian uint32 width followed by height.
        width, height = struct.unpack('>II', png[16:24])
        image = tf.Summary.Image(height=height,
                                 width=width,
                                 colorspace=4,  # 4 channels (RGBA)
                                 encoded_image_string=png)
        return tf.Summary(value=[tf.Summary.Value(tag=tag, image=image)])

    def on_train_begin(self, logs=None):
        # on_train_end closes the writer. Reopen it so that a second fit()
        # with the same callback instance still records images.
        self.writer.reopen()

    def on_epoch_end(self, epoch, logs=None):
        """Every ``period`` epochs, predict on ``val_data`` and log the plot."""
        self.last_rcd += 1
        if self.last_rcd < self.period:
            return
        self.last_rcd = 0

        y_predict = self.model.predict(self.val_data, steps=self.predict_steps)
        plot_buf = self.gen_plot(y_predict)
        summary = self._image_summary("plot", plot_buf.getvalue())
        # Pass the epoch as the step so TensorBoard shows one image per epoch
        # instead of stacking every image on the same step.
        self.writer.add_summary(summary, global_step=epoch)
        # Flush so plots become visible while training is still running.
        self.writer.flush()

    def on_train_end(self, logs=None):
        """Close the summary writer when training finishes."""
        self.writer.close()

使用 tf.keras 训练时,按如下方式注册并调用该回调:

callbacks = [
    # Write TensorBoard logs to the `log_dir` directory.
    # NOTE(review): write_grads is documented to take effect only when
    # histogram_freq > 0, which is not set here.
    tf.keras.callbacks.TensorBoard(log_dir=log_dir, write_graph=True,
                                   write_grads=True, write_images=False),
    # Save the model weights every 50 epochs.
    tf.keras.callbacks.ModelCheckpoint(checkpoint_path, verbose=1,
                                       save_weights_only=True, period=50),
    # Log the prediction scatter plot every 10 epochs using the MyCallbacks
    # class defined above.
    MyCallbacks(logdir=log_dir, period=10, val_data=my_val_data),
]

model.fit(training_data, training_label, epochs=2000, batch_size=256,
          shuffle=True, callbacks=callbacks,
          validation_data=(my_val_data, my_val_label))

 

你可能感兴趣的:(AI)