keras使用callback造自己的monitor函数

fit_generator函数

在这里插入图片描述

callback类

keras使用callback造自己的monitor函数_第1张图片

keras.callbacks.ModelCheckpoint是一个常见的callback类,其重写了on_epoch_end函数,在每个epoch结束保存模型数据进入文件。

  • keras.callbacks.History类主要记录每一次epoch训练的结果,包含loss以及acc的值;
  • keras.callbacks.ProgbarLogger类实现训练中间状态数据信息的输出,主要涉及进度相关信息。
  1. 训练过程中,每次权重的更新都是在一个batch上进行一次,是基于batch量的数据为单位进行权重的更新;

  2. 基于生成器模型训练数据,可以提升效率,降低对物理服务器性能,尤其是内存的要求;

  3. 训练过程中,callback函数执行了大量的工作,包括loss、acc值的记录,以及训练中间结果的日志反馈,最重要的是模型数据的输出,也是通过callback的方式实现;

  4. 训练和验证的逻辑近乎一样,训练要更新权重,但是验证过程,仅仅更新网络状态,不涉及权重(loss以及acc参数)信息的更新;

  5. Keras采用了生成器,装饰器,回调等编程思想,另外,对矩阵运算,例如numpy.dot以及numpy.multiply的数学逻辑都有一定要求,对python编程要求还是比较高滴。

我的新的callback函数

class F1ScoreCallback(Callback):
    def __init__(self, predict_batch_size=1024, include_on_batch=False):
        super(F1ScoreCallback, self).__init__()
        self.predict_batch_size = predict_batch_size
        self.include_on_batch = include_on_batch

    def on_batch_begin(self, batch, logs={}):
        pass

    def on_train_begin(self, logs={}):
        if not ('avg_f1_score_val' in self.params['metrics']):
            self.params['metrics'].append('avg_f1_score_val')

    def on_batch_end(self, batch, logs={}):
        if (self.include_on_batch):
            logs['avg_f1_score_val'] = float('-inf')

    def on_epoch_end(self, epoch, logs={}):
        logs['avg_f1_score_val'] = float('-inf')
        if (self.validation_data):
            y_predict, predict2, predict3 = self.model.predict(self.validation_data[0],
                                           batch_size=self.predict_batch_size)
            y_predict[y_predict >= 0.5] = 1
            y_predict[y_predict < 0.5] = 0
            f1 = f1_score(self.validation_data[1], y_predict, average='macro')
            # print("macro f1_score %.4f " % f1)
            f2 = f1_score(self.validation_data[1], y_predict, average='micro')
            # print("micro f1_score %.4f " % f2)
            avgf1=(f1 + f2) / 2
            # print("avg_f1_score %.4f " % (avgf1))
            logs['avg_f1_score_val'] =avgf1

只要加入log中,然后在monitor选择你自己定义的log函数名字,就能在一个epoch后的val中运行

你可能感兴趣的:(深度学习)