Keras 2.4.3:在每个 epoch 结束时实时显示训练集与验证集的精度和损失曲线

一、编写用于绘制曲线的回调函数
from tensorflow.keras import callbacks
import matplotlib.pyplot as plt
import numpy as np
class LossHistory(callbacks.Callback):
    """Keras callback that records loss/accuracy and plots the curves each epoch.

    Per-batch and per-epoch values are kept in ``losses``, ``accuracy``,
    ``val_loss`` and ``val_acc`` (each a dict with ``'batch'`` and ``'epoch'``
    lists; missing values are stored as ``None``). Every epoch-end ``logs``
    entry is also accumulated in ``H``. From the second epoch on, the
    available curves are drawn, saved to ``figPath`` and shown with
    ``plt.show()``.

    Args:
        fig_path: File the figure is written to. Its parent directory is
            created if missing. Defaults to the previously hard-coded path.
    """

    # (candidate logs keys, legend label) for each curve, in drawing order.
    # TF2 Keras names the metric after the string passed to compile():
    # metrics=['accuracy'] -> 'accuracy'/'val_accuracy', metrics=['acc'] ->
    # 'acc'/'val_acc', so both spellings are accepted.
    _CURVES = (
        (("loss",), "train_loss"),
        (("val_loss",), "val_loss"),
        (("accuracy", "acc"), "train_acc"),
        (("val_accuracy", "val_acc"), "val_acc"),
    )

    def __init__(self, fig_path='./model_save/Training_and_testing_curve.png'):
        super().__init__()
        self._fig_path = fig_path

    @staticmethod
    def _lookup(mapping, keys):
        """Return the value of the first key in *keys* present in *mapping*, else None."""
        for key in keys:
            if key in mapping:
                return mapping[key]
        return None

    def on_train_begin(self, logs=None):
        """Reset all metric containers at the start of training."""
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}
        self.H = {}
        self.figPath = self._fig_path

    def _record(self, logs, bucket):
        """Append loss/accuracy values from *logs* to the *bucket* ('batch' or 'epoch') lists."""
        self.losses[bucket].append(self._lookup(logs, ("loss",)))
        self.accuracy[bucket].append(self._lookup(logs, ("accuracy", "acc")))
        self.val_loss[bucket].append(self._lookup(logs, ("val_loss",)))
        self.val_acc[bucket].append(self._lookup(logs, ("val_accuracy", "val_acc")))

    def on_batch_end(self, batch, logs=None):
        """Record the metrics reported after each training batch.

        NOTE(review): validation metrics are normally only present in
        epoch-end logs, so the per-batch val_* entries are expected to be None.
        """
        self._record(logs or {}, 'batch')

    def on_epoch_end(self, batch, logs=None):
        """Record epoch metrics and redraw the training curves.

        ``batch`` receives the epoch index from Keras; the parameter name is
        kept unchanged for backward compatibility.
        """
        logs = logs or {}
        self._record(logs, 'epoch')
        for key, value in logs.items():
            self.H.setdefault(key, []).append(value)
        n_epochs = len(self.H.get("loss", []))
        # A curve needs at least two points to be worth drawing.
        if n_epochs > 1:
            self._plot(n_epochs)

    def _plot(self, n_epochs):
        """Draw every available curve in ``H``, save it to ``figPath`` and show it."""
        import os

        directory = os.path.dirname(self.figPath)
        if directory:
            os.makedirs(directory, exist_ok=True)
        epochs = np.arange(0, n_epochs)
        # A style context avoids permanently changing the global matplotlib style.
        with plt.style.context("ggplot"):
            fig = plt.figure()
            try:
                for keys, label in self._CURVES:
                    values = self._lookup(self.H, keys)
                    # Skip curves that are absent (e.g. no validation data) or
                    # not logged every epoch (e.g. validation_freq > 1).
                    # Plotting those against `epochs` would fail on a length
                    # mismatch.
                    if values is None or len(values) != n_epochs:
                        continue
                    plt.plot(epochs, values, label=label)
                plt.title("Training Loss and Accuracy [Epoch {}]".format(n_epochs))
                plt.xlabel("Epoch #")
                plt.ylabel("Loss/Accuracy")
                plt.legend()
                plt.savefig(self.figPath)
                plt.show()  # display the figure
            finally:
                # Always release the figure, even if saving/showing fails.
                plt.close(fig)


二、在调用 model.fit 之前,创建回调实例 call_backlist:
# Instantiate the plotting callback; it is handed to model.fit via callbacks=[...].
call_backlist = LossHistory()
三、在 fit 中通过 callbacks 参数传入回调(若要绘制精度和验证曲线,需要在 compile 时设置 metrics=['accuracy'],并在 fit 中提供 validation_data):
# Train for 300 epochs with mini-batches of 64, validating on the test split
# after each epoch; the LossHistory callback redraws the curves every epoch.
model.fit(
    self.train_data,
    self.train_label,
    validation_data=(self.test_data, self.test_label),
    epochs=300,
    batch_size=64,
    callbacks=[call_backlist],
)

相关标签:TensorFlow、深度学习