【keras】可视化loss曲线

两种方法都是自定义一个 Callback 子类,并在调用 fit 训练模型时,把该类的实例通过 callbacks 参数传入:

# Train on in-memory arrays and attach both visualisation callbacks.
# `fit_generator` expects a Python generator as its first argument and has no
# `train_data` keyword, so with plain arrays `fit` (with an explicit batch size)
# is the call that actually works.
model.fit(X_patches_train, Y_labels_train,
          batch_size=self.batch_size,
          epochs=self.nb_epoch,
          validation_data=(X_patches_valid, Y_labels_valid),
          verbose=1,
          callbacks=[checkpointer,
                     SGDLearningRateTracker(),
                     Mylosscallback(log_dir='./log')])

方法一:

利用 TensorBoard 可视化(使用 TensorFlow 1.x 的 tf.summary.FileWriter / tf.Summary 接口)

import tensorflow as tf

class Mylosscallback(Callback):
    """Keras callback that streams training metrics to TensorBoard.

    Uses the TensorFlow 1.x summary API (``tf.summary.FileWriter`` /
    ``tf.Summary``). Training loss/accuracy are written after every batch.
    Validation loss/accuracy only appear in Keras' epoch-end logs, so they are
    written after every epoch. All points share one global step counter (the
    number of batches processed so far), so the curves line up in TensorBoard.

    Args:
        log_dir: Directory the TensorBoard event file is written to.
    """

    def __init__(self, log_dir):
        # Initialise the Callback base class itself; ``super(Callback, self)``
        # would start the MRO lookup *after* Callback and skip its __init__.
        super(Mylosscallback, self).__init__()
        self.writer = tf.summary.FileWriter(log_dir)
        self.num = 0  # global step: number of batches processed so far
        # Most recently logged values (None until first reported).
        self.losses = None
        self.accuracy = None
        self.val_loss = None
        self.val_acc = None

    @staticmethod
    def _metric(logs, *keys):
        """Return the first non-None value in ``logs`` under one of ``keys``.

        Older Keras reports accuracy as 'acc'/'val_acc', newer versions as
        'accuracy'/'val_accuracy'; trying both keeps the callback working.
        """
        for key in keys:
            value = logs.get(key)
            if value is not None:
                return value
        return None

    def _write(self, items):
        """Write the (tag, value) pairs whose value is not None at step ``self.num``."""
        summary = tf.Summary()
        for tag, value in items:
            # Skip metrics Keras did not report instead of writing an empty point.
            if value is not None:
                summary.value.add(tag=tag, simple_value=float(value))
        if summary.value:
            self.writer.add_summary(summary, self.num)
            self.writer.flush()

    def on_train_begin(self, logs=None):
        # Allows the same callback instance to be reused across several fit()
        # calls after on_train_end closed the writer; no-op if still open.
        self.writer.reopen()

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.num += 1
        self.losses = logs.get('loss')
        self.accuracy = self._metric(logs, 'acc', 'accuracy')
        self._write([('losses', self.losses), ('accuracy', self.accuracy)])

    def on_epoch_end(self, epoch, logs=None):
        # Validation metrics are only present in the epoch-end logs.
        logs = logs or {}
        self.val_loss = logs.get('val_loss')
        self.val_acc = self._metric(logs, 'val_acc', 'val_accuracy')
        self._write([('val_loss', self.val_loss), ('val_acc', self.val_acc)])

    def on_train_end(self, logs=None):
        # Release the event-file handle once training finishes.
        self.writer.close()

方法二:

利用 matplotlib.pyplot 画图。这种方法运行较慢(每个 epoch 结束时都会新建并弹出一个图窗),而且会出现警告:Attribute Qt::AA_EnableHighDpiScaling must be set before QCoreApplication is created.

据网上资料,该警告与 PyQt(matplotlib 的 Qt 后端)有关,尚待解决。

import keras.backend as K
import matplotlib.pyplot

class SGDLearningRateTracker(Callback):
    """Keras callback that records loss/accuracy and plots them with matplotlib.

    Despite its name, this class does not track the learning rate. It subclasses
    Keras' Callback and does the following:
      * collects training loss/accuracy after every batch and every epoch;
      * collects validation loss/accuracy after every epoch;
      * draws an accuracy/loss figure at the end of every epoch.

    Call ``loss_plot('batch')`` to plot the per-batch history instead.
    Note: ``plt.show()`` may block training under an interactive backend.
    """

    def on_train_begin(self, logs=None):
        # History per granularity: 'batch' lists grow every batch,
        # 'epoch' lists grow every epoch.
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    @staticmethod
    def _metric(logs, *keys):
        """Return the first non-None value in ``logs`` under one of ``keys``.

        Older Keras reports accuracy as 'acc'/'val_acc', newer versions as
        'accuracy'/'val_accuracy'; trying both keeps the plot populated.
        """
        for key in keys:
            value = logs.get(key)
            if value is not None:
                return value
        return None

    def _record(self, loss_type, logs):
        """Append the current metrics from ``logs`` to the ``loss_type`` history."""
        logs = logs or {}
        self.losses[loss_type].append(logs.get('loss'))
        self.accuracy[loss_type].append(self._metric(logs, 'acc', 'accuracy'))
        self.val_loss[loss_type].append(logs.get('val_loss'))
        self.val_acc[loss_type].append(
            self._metric(logs, 'val_acc', 'val_accuracy'))

    def on_batch_end(self, batch, logs=None):
        # Batch logs carry no validation metrics, so val_* 'batch' lists hold None.
        self._record('batch', logs)

    def on_epoch_end(self, epoch, logs=None):
        self._record('epoch', logs)
        self.loss_plot('epoch')

    def loss_plot(self, loss_type):
        """Plot the recorded accuracy/loss history.

        Args:
            loss_type: 'batch' or 'epoch'; validation curves are only drawn
                for 'epoch', since only epoch logs contain them.
        """
        # Imported with the alias the plotting code uses.
        import matplotlib.pyplot as plt

        iters = range(len(self.losses[loss_type]))
        fig = plt.figure()
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')
        plt.legend(loc="upper right")
        plt.show()
        # Free the figure; otherwise one figure per epoch accumulates in pyplot.
        plt.close(fig)

参考:https://blog.csdn.net/u014195530/article/details/82256333

你可能感兴趣的:(Keras)