如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。
fit_generator(generator,
steps_per_epoch=None,
epochs=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0)
参数:
样例代码:
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
def process_x(path):
img = Image.open(path)
img = img.resize((96,96))
img = img.convert('RGB')
img = np.array(img)
img = np.asarray(img, np.float32) / 255.0
#也可以进行进行一些数据数据增强的处理
return img
count =1
def generate_arrays_from_file(x_y):
#x_y 是我们的训练集包括标签,每一行的第一个是我们的图片路径,后面的是我们的独热化后的标签
global count
batch_size = 8
while 1:
batch_x = x_y[(count - 1) * batch_size:count * batch_size, 0]
batch_y = x_y[(count - 1) * batch_size:count * batch_size, 1:]
batch_x = np.array([process_x(img_path) for img_path in batch_x])
batch_y = np.array(batch_y).astype(np.float32)
print("count:"+str(count))
count = count+1
yield (batch_x, batch_y)
model = Sequential()
model.add(Dense(units=1000, activation='relu', input_dim=2))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])
x_y = []
model.fit_generator(generate_arrays_from_file(x_y),steps_per_epoch=10, epochs=2,max_queue_size=1,workers=1)
在理解上面代码之前我们需要首先了解yield的用法。
yield关键字:
我们先通过一个例子看一下yield的用法:
def foo():
print("starting...")
while True:
res = yield 4
print("res:",res)
g = foo()
print(next(g))
print("----------")
print(next(g))
运行结果:
starting...
4
----------
res: None
4
带yield的函数是一个生成器,而不是一个函数。因为foo函数中有yield关键字,所以foo函数并不会真的执行,而是先得到一个生成器的实例,当我们第一次调用next函数的时候,foo函数才开始行,首先先执行foo函数中的print方法,然后进入while循环,循环执行到yield时,yield其实相当于return,函数返回4,程序停止。所以我们第一次调用next(g)的输出结果是前面两行。
然后当我们再次调用next(g)时,这个时候是从上一次停止的地方继续执行,也就是要执行res的赋值操作,因为4已经在上一次执行被return了,随意赋值res为None,然后执行print(“res:”,res)打印res: None,再次循环到yield返回4,程序停止。
所以yield关键字的作用就是我们能够从上一次程序停止的地方继续执行,这样我们用作生成器的时候,就避免一次性读入数据造成内存不足的情况。
现在看到上面的示例代码:
generate_arrays_from_file函数就是我们的生成器,每次循环读取一个batch大小的数据,然后处理数据,并返回。x_y是我们的把路径和标签合并后的训练集,类似于如下形式:
['data/img\\fimg_4092.jpg' '0' '1' '0' '0' '0' ]
至于格式不一定要这样,可以是自己的格式,至于怎么处理,根于自己的格式,在process_x进行处理,这里因为是存放的图片路径,所以在process_x函数的主要作用就是读取图片并进行归一化等操作,也可以在这里定义自己需要进行的操作,例如对图像进行实时数据增强。
示例代码:
class BaseSequence(Sequence):
"""
基础的数据流生成器,每次迭代返回一个batch
BaseSequence可直接用于fit_generator的generator参数
fit_generator会将BaseSequence再次封装为一个多进程的数据流生成器
而且能保证在多进程下的一个epoch中不会重复取相同的样本
"""
def __init__(self, img_paths, labels, batch_size, img_size):
#np.hstack在水平方向上平铺
self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels)))
self.batch_size = batch_size
self.img_size = img_size
def __len__(self):
#math.ceil表示向上取整
#调用len(BaseSequence)时返回,返回的是每个epoch我们需要读取数据的次数
return math.ceil(len(self.x_y) / self.batch_size)
def preprocess_img(self, img_path):
img = Image.open(img_path)
resize_scale = self.img_size[0] / max(img.size[:2])
img = img.resize((self.img_size[0], self.img_size[0]))
img = img.convert('RGB')
img = np.array(img)
# 数据归一化
img = np.asarray(img, np.float32) / 255.0
return img
def __getitem__(self, idx):
batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]
batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]
batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])
batch_y = np.array(batch_y).astype(np.float32)
print(batch_x.shape)
return batch_x, batch_y
#重写的父类Sequence中的on_epoch_end方法,在每次迭代完后调用。
def on_epoch_end(self):
#每次迭代后重新打乱训练集数据
np.random.shuffle(self.x_y)
在上面代码中,__len __和__getitem __,是我们重写的魔法方法,__len __是当我们调用len(BaseSequence)函数时调用,这里我们返回(样本总量/batch_size),供我们传入fit_generator中的steps_per_epoch参数;__getitem __可以让对象实现迭代功能,这样在将BaseSequence的对象传入fit_generator中后,不断执行generator就可循环的读取数据了。
举个例子说明一下getitem的作用:
class Animal:
def __init__(self, animal_list):
self.animals_name = animal_list
def __getitem__(self, index):
return self.animals_name[index]
animals = Animal(["dog","cat","fish"])
for animal in animals:
print(animal)
输出结果:
dog
cat
fish
并且使用Sequence类可以保证在多进程的情况下,每个epoch中的样本只会被训练一次。
参考文章:
Python.__getitem__方法
keras中文文档
python中yield的用法详解——最简单,最清晰的解释