两种方法都是自定义一个继承自 Callback 的类,并在训练(fit)模型时通过 callbacks 参数传入该类的实例,例如:
# Train with both custom callbacks attached.
# `fit_generator` has no `train_data` parameter (its first argument must be a
# generator / Sequence), so the original call fails with a TypeError. In-memory
# data goes through `fit`, with the batch size given directly instead of
# `steps_per_epoch`.
# NOTE(review): assumes X_patches_train / Y_labels_train are array-like rather
# than generators - confirm against the code that builds them.
model.fit(X_patches_train, Y_labels_train, batch_size=self.batch_size, epochs=self.nb_epoch,
          validation_data=(X_patches_valid, Y_labels_valid), verbose=1,
          callbacks=[checkpointer, SGDLearningRateTracker(), Mylosscallback(log_dir='./log')])
方法一:
利用 TensorBoard 可视化
import tensorflow as tf
class Mylosscallback(Callback):
    """Keras callback that writes training metrics to a TensorBoard event file.

    After every batch the training ``loss`` / ``acc`` values found in ``logs``
    are written under the tags ``losses`` / ``accuracy``. After every epoch the
    ``val_loss`` / ``val_acc`` values found in ``logs`` are written under the
    tags of the same name. All summaries use the running batch count as step.
    """

    def __init__(self, log_dir):
        """Create the event writer.

        Args:
            log_dir: Directory handed to ``tf.summary.FileWriter``.
        """
        # super() has to be given *this* class; naming the base class makes the
        # lookup start after it, so Callback.__init__ would never run.
        super(Mylosscallback, self).__init__()
        self.writer = tf.summary.FileWriter(log_dir)
        self.num = 0  # number of batches seen so far; used as the summary step

    def on_train_begin(self, logs=None):
        """Reset the "latest value" attributes at the start of training."""
        # These attributes only ever hold the most recent scalar, so they start
        # out empty instead of as history dicts that get overwritten.
        self.losses = None
        self.accuracy = None
        self.val_loss = None
        self.val_acc = None

    def on_batch_end(self, batch, logs=None):
        """Record the training loss/accuracy reported for the finished batch."""
        logs = logs or {}
        self.num += 1
        self.losses = logs.get('loss')
        self.accuracy = logs.get('acc')
        # Leftover debug trace, kept so console output is unchanged.
        print('debug success!!!')
        self._write_scalars((('losses', self.losses),
                             ('accuracy', self.accuracy)))

    def on_epoch_end(self, epoch, logs=None):
        """Record the validation metrics reported for the finished epoch.

        Validation metrics are read here rather than in ``on_batch_end``:
        Keras adds them to the epoch-level logs, so reading them per batch
        yields nothing to plot.
        """
        logs = logs or {}
        self.val_loss = logs.get('val_loss')
        self.val_acc = logs.get('val_acc')
        self._write_scalars((('val_loss', self.val_loss),
                             ('val_acc', self.val_acc)))

    def _write_scalars(self, tagged_values):
        """Write every (tag, value) pair whose value is present at step ``self.num``."""
        present = [(tag, value) for tag, value in tagged_values
                   if value is not None]
        if not present:
            return
        summary = tf.Summary()
        for tag, value in present:
            summary.value.add(tag=tag, simple_value=value)
        self.writer.add_summary(summary, self.num)
        self.writer.flush()
方法二:
利用 matplotlib.pyplot 画图。这种方法运行很慢,而且会出现 warning:Attribute Qt::AA_EnableHighDpiScaling must be set before QCoreApplication is created.
据网上搜索到的资料,该 warning 是 PyQt 引起的,尚待解决。
import keras.backend as K
import matplotlib.pyplot
class SGDLearningRateTracker(Callback):
    """Keras callback that records metric history and plots it after each epoch.

    Subclasses ``Callback`` and overrides its hooks. Four histories are kept
    (``losses``, ``accuracy``, ``val_loss``, ``val_acc``), each a dict with a
    ``'batch'`` list and an ``'epoch'`` list filled from the ``logs`` passed to
    the corresponding hook.
    """

    def on_train_begin(self, logs=None):
        """Start every history empty at the beginning of training."""
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def on_batch_end(self, batch, logs=None):
        """Append the finished batch's metrics to the per-batch histories."""
        logs = logs or {}
        self.losses['batch'].append(logs.get('loss'))
        self.accuracy['batch'].append(logs.get('acc'))
        # Entries are None whenever the key is missing from ``logs``.
        self.val_loss['batch'].append(logs.get('val_loss'))
        self.val_acc['batch'].append(logs.get('val_acc'))

    def on_epoch_end(self, epoch, logs=None):
        """Append the finished epoch's metrics and redraw the per-epoch plot."""
        logs = logs or {}
        self.losses['epoch'].append(logs.get('loss'))
        self.accuracy['epoch'].append(logs.get('acc'))
        self.val_loss['epoch'].append(logs.get('val_loss'))
        self.val_acc['epoch'].append(logs.get('val_acc'))
        self.loss_plot('epoch')

    def loss_plot(self, loss_type):
        """Plot the recorded accuracy/loss curves.

        Args:
            loss_type: ``'batch'`` or ``'epoch'``; selects which history to
                plot. Validation curves are drawn only for ``'epoch'``.
        """
        # Imported here under the ``plt`` alias this method relies on; the
        # file-level ``import matplotlib.pyplot`` does not bind that name.
        import matplotlib.pyplot as plt

        iters = range(len(self.losses[loss_type]))
        fig = plt.figure()
        # acc
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        # loss
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            # val_acc
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            # val_loss
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')
        plt.legend(loc="upper right")
        plt.show()
        # A new figure is created on every call; release it once show() has
        # returned so figures do not pile up over a long run. In interactive
        # mode show() does not block, so the window is left open.
        if not plt.isinteractive():
            plt.close(fig)
参考:https://blog.csdn.net/u014195530/article/details/82256333