fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose=1, callbacks=[], validation_data=None, nb_val_samples=None, class_weight=None, max_q_size=10)
函数的参数是:
generator:生成器函数,生成器的输出应该为:
一个形如(inputs,targets)的tuple
一个形如(inputs, targets,sample_weight)的tuple。所有的返回值都应该包含相同数目的样本。生成器将无限在数据集上循环。每个epoch以经过模型的样本数达到samples_per_epoch时,记一个epoch结束
def next_train(self):
while 1:
ret = self.get_batch(self.cur_train_index, self.minibatch_size, train=True)
self.cur_train_index += self.minibatch_size
if self.cur_train_index >= self.val_split:
self.cur_train_index = self.cur_train_index % 32
(self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
[self.X_text, self.Y_data, self.Y_len], self.val_split)
yield ret
这里写成了一个死循环while True,因为model.fit_generator()在使用在个函数的时候, 并不会在每一个epoch之后重新调用,那么如果这时候generator自己结束了就会有问题。
samples_per_epoch:整数,当模型处理的样本达到此数目时计一个epoch结束,执行下一个epoch
verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
validation_data:具有以下三种形式之一
生成验证集的生成器
一个形如(inputs,targets)的tuple
一个形如(inputs,targets,sample_weights)的tuple
nb_val_samples:仅当validation_data是生成器时使用,用以限制在每个epoch结束时用来验证模型的验证集样本数,功能类似于samples_per_epoch
max_q_size:生成器队列的最大容量
函数返回一个History对象
model.fit_generator 训练入口函数(参考上面的函数原型定义)
callbacks.on_train_begin()
while epoch < epochs:
callbacks.on_epoch_begin(epoch)
while steps_done < steps_per_epoch:
#generator_output是一个死循环while True,因为model.fit_generator()在使用在个函数的时候, 并不会在每一个epoch之后重新调用,那么如果这时候generator自己结束了就会有问题。
generator_output = next(output_generator) #生成器next函数取输入数据进行训练,每次取一个batch大小的量
callbacks.on_batch_begin(batch_index, batch_logs)
outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
callbacks.on_batch_end(batch_index, batch_logs)
end of while steps_done < steps_per_epoch
self.evaluate_generator(...) #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估
callbacks.on_epoch_end(epoch, epoch_logs) #在这个执行过程中实现模型数据的保存操作
end of while epoch < epochs
callbacks.on_train_end()
``
# 回调函数
通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数。eras的回调函数是一个类
```python
keras.callbacks.Callback()
这是回调函数的抽象类,定义新的回调函数必须继承自该类
params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)
model:keras.models.Model对象,为正在训练的模型的引用回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。
目前,模型的.fit()中有下列参数会被记录到logs中:
在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,acc和loss,如果指定了验证集,还会包含验证集正确率和误差val_acc)和val_loss,val_acc还额外需要在.compile中启用metrics=[‘accuracy’]。
在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数
在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc