keras.callback fit_generator

1.fit_generator

fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose=1, callbacks=[], validation_data=None, nb_val_samples=None, class_weight=None, max_q_size=10)

函数的参数是:

generator:生成器函数,生成器的输出应该为:
一个形如(inputs,targets)的tuple

一个形如(inputs, targets,sample_weight)的tuple。所有的返回值都应该包含相同数目的样本。生成器将无限在数据集上循环。每个epoch以经过模型的样本数达到samples_per_epoch时,记一个epoch结束

    def next_train(self):
        while 1:
            ret = self.get_batch(self.cur_train_index, self.minibatch_size, train=True)
            self.cur_train_index += self.minibatch_size
            if self.cur_train_index >= self.val_split:
                self.cur_train_index = self.cur_train_index % 32
                (self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
                    [self.X_text, self.Y_data, self.Y_len], self.val_split)
            yield ret

这里写成了一个死循环while True,因为model.fit_generator()在使用在个函数的时候, 并不会在每一个epoch之后重新调用,那么如果这时候generator自己结束了就会有问题。

samples_per_epoch:整数,当模型处理的样本达到此数目时计一个epoch结束,执行下一个epoch

verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录

validation_data:具有以下三种形式之一

生成验证集的生成器

一个形如(inputs,targets)的tuple

一个形如(inputs,targets,sample_weights)的tuple

nb_val_samples:仅当validation_data是生成器时使用,用以限制在每个epoch结束时用来验证模型的验证集样本数,功能类似于samples_per_epoch

max_q_size:生成器队列的最大容量

函数返回一个History对象

2.fit_generator 训练逻辑过程

model.fit_generator 训练入口函数(参考上面的函数原型定义)

   callbacks.on_train_begin()
     while epoch < epochs:
             callbacks.on_epoch_begin(epoch)
             while steps_done < steps_per_epoch:
             	#generator_output是一个死循环while True,因为model.fit_generator()在使用在个函数的时候, 并不会在每一个epoch之后重新调用,那么如果这时候generator自己结束了就会有问题。
                 generator_output = next(output_generator)       #生成器next函数取输入数据进行训练,每次取一个batch大小的量
                 callbacks.on_batch_begin(batch_index, batch_logs)
                 outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
                 callbacks.on_batch_end(batch_index, batch_logs)	
              end of while steps_done < steps_per_epoch	
              self.evaluate_generator(...)          #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估	
              callbacks.on_epoch_end(epoch, epoch_logs)          #在这个执行过程中实现模型数据的保存操作
     end of while epoch < epochs	
     callbacks.on_train_end()
``
# 回调函数
通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数。eras的回调函数是一个类

```python

keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类

3.类属性

params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)

model:keras.models.Model对象,为正在训练的模型的引用回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。

目前,模型的.fit()中有下列参数会被记录到logs中:

在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,acc和loss,如果指定了验证集,还会包含验证集正确率和误差val_acc)和val_loss,val_acc还额外需要在.compile中启用metrics=[‘accuracy’]。

在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数

在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc

你可能感兴趣的:(tensorflow)