fit_generator()


在数据处理和网络定义完成后,跑模型时突然出现了错误:
OOM

刚开始也不知道哪里的问题,发现有可能是内存耗尽了,然后就放进去500张图片进行fit,然后问题就消失了,猜想应该是数据太大,内存开销不够。
发现官方文档中说可以使用fit_generator()分批训练。


官方文档如下:

fit_generator(self, generator, 
                    steps_per_epoch=None, 
                    epochs=1, 
                    verbose=1, 
                    callbacks=None, 
                    validation_data=None, 
                    validation_steps=None,  
                    class_weight=None,
                    max_queue_size=10,   
                    workers=1, 
                    use_multiprocessing=False, 
                    shuffle=True, 
                    initial_epoch=0)

通过Python generator产生一批批的数据用于训练模型。generator可以和模型并行运行,例如,可以使用CPU生成批数据同时在GPU上训练模型。

参数:
  • generator:一个generator或Sequence实例,为了避免在使用multiprocessing时直接复制数据。
  • steps_per_epoch:从generator产生的步骤的总数(样本批次总数)。通常情况下,应该等于数据集的样本数量除以批量的大小。
  • epochs:整数,在数据集上迭代的总数。
  • works:在使用基于进程的线程时,最多需要启动的进程数量。
  • use_multiprocessing:布尔值。当为True时,使用基于基于过程的线程。

例子:
datagen = ImageDataGenator(...)
model.fit_generator(datagen.flow(x_train, y_train,
                                 batch_size=batch_size),
                    epochs=epochs,
                    validation_data=(x_test, y_test),
                    workers=4)

你可能感兴趣的:(fit_generator())