keras在train_on_batch中使用generator

Generator是keras中很方便的数据输入方式,既可以节省内存空间,又自带数据增强的功能,一般用于fit_generator这种比较单一的训练方式,不适于train_on_batch这种拓展性较高的训练方式。但实际上generator是可以用于train_on_batch的,下面介绍具体方法:

理解generator

理解generator实际上理解yield关键词就够了,yield可以看作一个带指针的return,每次返回时指针指向程序停止的位置,因此下一次可以接着上一次运行。在外面调用generator只需要使用next()方法,具体可以参考这篇博客。

在train_on_batch中使用generator

因此在train_on_batch(x,y)中使用generator就相当于调用next(generator_x)和next(generator_y),代码如下:

x = xGenerator(128, dir_x)
y = yGenerator(128, dir_y)
for i in range(1,steps):
    loss = model.train_on_batch(next(x),next(y))

如代码所示,先声明两个迭代器x和y,然后构建一个循环,每次调用next(x)和next(y)就行了,其他与train_on_batch的普通使用方式相同,这样就可以在train_on_batch中使用方便的generator数据生成方式了。

你可能感兴趣的:(深度学习,深度学习,keras,迭代器)