fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
上面是 Keras 官方 fit_generator 函数的完整签名，各参数的英文说明可以参考 Keras 官方文档。这是一个封装得很完整的接口，但在平常使用时，我们一般只需要用到以下参数：generator、steps_per_epoch=None、epochs=1、verbose=1、callbacks=None、validation_data=None、validation_steps=None。
# Build the trainable model from the three input tensors and the loss tensor
# (all defined outside this excerpt; Model is presumably keras.models.Model).
# The loss tensor is the model's only output.
model_train = Model(
    inputs=[inputs_img, inputs_txt1, inputs_txt1no],
    outputs=loss,
)
# Samples per batch; the fit_generator call below divides the training-set
# length by this to get steps_per_epoch and validation_steps.
batch_size = 50
# Rotate an index list left by one; repeated calls build a shifted index order.
def ahead_one(input_list):
    """Rotate ``input_list`` left by one position, in place, and return it.

    The first element is moved to the end, e.g. ``[0, 1, 2] -> [1, 2, 0]``.
    Applying this ``k`` times to ``list(range(n))`` yields a list whose
    ``i``-th entry is ``(i + k) % n``.

    Args:
        input_list: the list to rotate; it is mutated.

    Returns:
        The same list object, rotated. An empty list is returned unchanged.
    """
    # Re-append the popped head so the list keeps its length; dropping it
    # shrank the list on every call until pop() raised IndexError.
    if input_list:
        input_list.append(input_list.pop(0))
    return input_list
# Pack one batch as an (inputs, targets) tuple: three named inputs, one output.
def wrap_in_dictionary(image_feature, text_feature_train, text_feature_train2, y_label):
    """Return one batch in the ``(inputs, targets)`` form a Keras generator yields.

    Args:
        image_feature: image batch, fed to the ``'inputs_img'`` input.
        text_feature_train: text batch, fed to the ``'inputs_txt'`` input.
        text_feature_train2: second text batch, fed to ``'inputs_txt_no'``.
        y_label: target batch for the ``'all_loss'`` output.

    Returns:
        A 2-tuple ``(inputs_dict, targets_dict)``. Keras documents generator
        output as a tuple; some tf.keras versions treat a list as inputs only.

    NOTE(review): the dict keys must equal the ``name=`` of the model's Input
    layers and output layer; assumes they were created with these names.
    """
    inputs = {'inputs_img': image_feature,
              'inputs_txt': text_feature_train,
              'inputs_txt_no': text_feature_train2}
    targets = {'all_loss': y_label}
    return inputs, targets
# Allocate zero-filled batch buffers; sizes default to the original 50 x 1000.
def make_empty_batch(batch_size=50, image_dim=1000, text_dim=1000):
    """Return three new zero-filled float64 arrays for one batch.

    Args:
        batch_size: number of rows in each array (default 50).
        image_dim: columns of the image batch (default 1000); set it to your
            image feature width.
        text_dim: columns of both text batches (default 1000); set it to your
            text feature width.

    Returns:
        ``(image_batch, text_batch, textno_batch)`` with shapes
        ``(batch_size, image_dim)``, ``(batch_size, text_dim)`` and
        ``(batch_size, text_dim)``. Each call allocates fresh arrays.
    """
    image_batch = np.zeros((batch_size, image_dim))
    text_batch = np.zeros((batch_size, text_dim))
    textno_batch = np.zeros((batch_size, text_dim))
    return image_batch, text_batch, textno_batch
# Infinite batch generator used for both training and validation.
def flow(image_feature, text_feature_train, mode, batch_size=50):
    """Yield Keras batches forever, cycling through the samples in order.

    Each batch holds ``batch_size`` consecutive samples. When the end of the
    data is reached, the next batch continues from the first sample, so a
    batch may span the end and the start of the arrays. Every batch consists of:

    * ``image_batch``  - the image feature rows,
    * ``text_batch``   - the matching text feature rows,
    * ``textno_batch`` - ``text_batch`` rotated by ``batch_size // 2`` rows,
      so row ``i`` holds the text of batch sample
      ``(i + batch_size // 2) % batch_size``. For ``batch_size >= 2``, each
      image is therefore paired with a text from a different sample.
    * ``y_label``      - an all-zero target column of shape ``(batch_size, 1)``.

    Args:
        image_feature: 2-D array-like, one image feature row per sample.
        text_feature_train: 2-D array-like with the same number of rows as
            ``image_feature``; row ``i`` belongs to sample ``i``.
        mode: ``'train'`` or ``'val'``. Both modes iterate the samples in
            their original order; the argument is kept for existing callers.
        batch_size: samples per batch (default 50, the script's batch size).

    Yields:
        ``wrap_in_dictionary(image_batch, text_batch, textno_batch, y_label)``
        for each batch. All four arrays are newly allocated float64 arrays.

    Raises:
        ValueError: on the first ``next()`` if there are no samples, the two
            feature arrays have different row counts, or ``batch_size < 1``.
    """
    image_feature = np.asarray(image_feature)
    text_feature_train = np.asarray(text_feature_train)
    n_samples = len(image_feature)
    if n_samples == 0:
        # With no samples the old loop spun forever without yielding.
        raise ValueError('flow() needs at least one sample')
    if len(text_feature_train) != n_samples:
        raise ValueError(
            'image_feature and text_feature_train must have the same number '
            'of rows, got %d and %d' % (n_samples, len(text_feature_train)))
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

    # The old nested ahead_one loop aimed at a net rotation of
    # 1 + 2 + ... + 49 = 1225, i.e. 25 rows for a batch of 50. batch_size // 2
    # gives the same 25 and stays non-zero for odd batch sizes, where the old
    # formula would have been 0 (texts paired with themselves).
    shift = batch_size // 2
    offsets = np.arange(batch_size)
    start = 0
    while True:
        # Row indices of this batch, wrapping past the end of the data. Each
        # batch fills all batch_size rows; the old reset-then-increment left
        # row 0 of every batch after the first as zeros.
        rows = (start + offsets) % n_samples
        start = (start + batch_size) % n_samples
        # Fancy indexing + astype produce new arrays, so a yielded batch is
        # never overwritten by a later one.
        image_batch = image_feature[rows].astype(np.float64)
        text_batch = text_feature_train[rows].astype(np.float64)
        textno_batch = np.roll(text_batch, -shift, axis=0)
        y_label = np.zeros((batch_size, 1))
        yield wrap_in_dictionary(image_batch, text_batch, textno_batch, y_label)
# Train for 500 epochs; both the training epoch and the validation pass run
# len(features_train) // batch_size batches from the flow() generators.
# NOTE(review): validation_data is built from the *training* arrays
# (features_train / text_feature_train), so validation metrics are not measured
# on held-out data -- confirm this is intended.
# NOTE(review): on tf.keras >= 2.1 fit_generator is deprecated; Model.fit
# accepts generators directly.
model_train.fit_generator(
    generator=flow(features_train, text_feature_train, 'train'),
    steps_per_epoch=len(features_train) // batch_size,
    epochs=500,
    verbose=1,
    validation_data=flow(features_train, text_feature_train, 'val'),
    validation_steps=len(features_train) // batch_size,
)