When I first wrote Keras code I always called model.fit() directly. Later I realized this does not save memory, and a generator is especially useful when the input data itself is not large but has to be expanded internally by permutations and combinations. Here is a record of how to use fit_generator:
fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
The line above is the official signature; the official documentation explains each argument in detail in English. fit_generator is a very complete wrapper, but in everyday use the parameters we actually need are generator, steps_per_epoch, epochs, verbose, callbacks, validation_data and validation_steps. validation_data defaults to None, but in practice you usually want to provide some validation data, so the example I give here does include validation_data.
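Before the full example, here is a minimal sketch of the generator contract (the toy model, arrays and batch size below are invented purely for illustration): the generator must yield (inputs, targets) pairs in an infinite loop, and steps_per_epoch tells Keras how many of those yields make up one epoch.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# toy data and model, only to show the shape of the generator contract
x = np.random.rand(200, 10)
y = np.random.rand(200, 1)
toy_model = Sequential([Dense(1, input_shape=(10,))])
toy_model.compile(optimizer='sgd', loss='mse')

def simple_flow(x, y, batch_size):
    while True:  # must never terminate; Keras counts steps_per_epoch yields per epoch
        for start in range(0, len(x), batch_size):
            yield x[start:start + batch_size], y[start:start + batch_size]

toy_model.fit_generator(generator=simple_flow(x, y, 20),
                        steps_per_epoch=len(x) // 20,
                        epochs=2,
                        verbose=1)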
model_train = Model(inputs=[inputs_img, inputs_txt1, inputs_txt1no], outputs=loss)
This is my model: three inputs and one output. The concrete values inside are omitted here; it is only an abstract illustration of the structure.
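For context, below is a hedged sketch of how such a model could be wired up. The feature sizes, the Dense layers and the Lambda that computes the loss are my own placeholders; the point to notice is that the generator's dictionary keys ('inputs_img', 'inputs_txt', 'inputs_txt_no', 'all_loss') have to match the name= of the Input layers and of the output layer, not the Python variable names.

from keras.layers import Input, Dense, Lambda
from keras.models import Model
import keras.backend as K

# assumed feature size and layers, only to make the naming explicit
inputs_img = Input(shape=(1000,), name='inputs_img')
inputs_txt1 = Input(shape=(1000,), name='inputs_txt')
inputs_txt1no = Input(shape=(1000,), name='inputs_txt_no')

img_emb = Dense(256)(inputs_img)
txt_emb = Dense(256)(inputs_txt1)
txtno_emb = Dense(256)(inputs_txt1no)

# placeholder ranking-style loss on the matched vs. mismatched pair
loss = Lambda(lambda t: K.relu(1.0 + K.sum(t[0] * t[2] - t[0] * t[1], axis=-1, keepdims=True)),
              name='all_loss')([img_emb, txt_emb, txtno_emb])

model_train = Model(inputs=[inputs_img, inputs_txt1, inputs_txt1no], outputs=loss)
# the model's output already is the loss value, so train with an identity-style loss
model_train.compile(optimizer='adam', loss=lambda y_true, y_pred: y_pred)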
import random

import numpy as np

batch_size = 50

# rotate a list one position to the left (used to misalign text and image rows)
def ahead_one(input_list):
    b = input_list.pop(0)
    input_list.append(b)
    return input_list

# wrap one batch as the (inputs, targets) pair fit_generator expects: three inputs, one output;
# the dictionary keys must match the name= given to the corresponding layers
def wrap_in_dictionary(image_feature, text_feature_train, text_feature_train2, y_label):
    return [{'inputs_img': image_feature,
             'inputs_txt': text_feature_train,
             'inputs_txt_no': text_feature_train2},
            {'all_loss': y_label}]

# allocate empty arrays to store one batch (change the feature size for your own data)
def make_empty_batch():
    image_batch = np.zeros((batch_size, 1000))
    text_batch = np.zeros((batch_size, 1000))
    textno_batch = np.zeros((batch_size, 1000))
    return image_batch, text_batch, textno_batch

# generator that feeds batches of data to train the network
def flow(image_feature, text_feature_train, mode):
    if mode == 'train':
        pass  # keep the original sample order for training
    if mode == 'val':
        # shuffle the samples for validation
        idx = [i for i in range(len(image_feature))]
        random.shuffle(idx)
        image_feature = image_feature[idx]
        text_feature_train = text_feature_train[idx]
    image_batch, text_batch, textno_batch = make_empty_batch()
    batch_counter = 0
    while True:  # fit_generator expects the generator to yield batches forever
        for i in range(len(image_feature)):
            image_batch[batch_counter, :] = image_feature[i, :]
            text_batch[batch_counter, :] = text_feature_train[i, :]
            if batch_counter == batch_size - 1:
                # build a cyclically shifted index so textno_batch no longer
                # lines up with image_batch (the mismatched text input)
                index = [j_i for j_i in range(len(text_batch))]
                for j in range(len(text_batch)):
                    if j != 0:
                        for j_i in range(j):
                            index = ahead_one(index)
                textno_batch = text_batch[index]
                # dummy targets: the model's output is already the loss
                y_label = np.zeros((len(text_batch), 1))
                yield_dictionary = wrap_in_dictionary(image_batch, text_batch, textno_batch, y_label)
                yield yield_dictionary
                image_batch, text_batch, textno_batch = make_empty_batch()
                batch_counter = 0
                continue  # the next sample must go into row 0, so skip the increment below
            batch_counter = batch_counter + 1
# note the format of the call: only six of the parameters are used
model_train.fit_generator(generator=flow(features_train, text_feature_train, 'train'),
                          steps_per_epoch=int(len(features_train) / batch_size),
                          epochs=500,
                          verbose=1,
                          validation_data=flow(features_train, text_feature_train, 'val'),
                          validation_steps=int(len(features_train) / batch_size))
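Before kicking off a 500-epoch run, it can be worth pulling a single batch out of the generator by hand to confirm that the keys and shapes line up with the model; a quick check could look like this (features_train and text_feature_train are the same arrays as in the call above).

# pull one batch from a fresh generator and inspect it
check_gen = flow(features_train, text_feature_train, 'train')
inputs, targets = next(check_gen)
for key, value in inputs.items():
    print(key, value.shape)                   # expect (50, 1000) for each input
print('all_loss', targets['all_loss'].shape)  # expect (50, 1)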