Keras fit_generator: a memory-saving example

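When I wrote Keras code before, I always called model.fit() directly. Later I realized this wastes memory, because model.fit() needs the whole training set in memory at once. A generator pays off especially when the input data itself is small but the training pairs are built by permuting and combining it internally: the generator builds each batch on the fly instead of materializing every combination up front. Here is how to use fit_generator: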
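To put a rough number on the saving, here is a back-of-the-envelope sketch. It assumes (hypothetically) 10,000 training samples, plus the 1000-dimensional float64 features, batch size of 50 and three inputs used in the code below, where every batch of 50 is expanded into 49 shifted batches. The helper names are illustrative only, and the raw features, which both approaches keep in memory, are not counted.

```python
def precomputed_bytes(n_samples, batch_size, dim, n_inputs=3, itemsize=8):
    """Memory needed if every expanded batch were built up front (as with model.fit)."""
    full_batches = n_samples // batch_size
    # each full batch is expanded into batch_size - 1 shifted batches
    rows = full_batches * (batch_size - 1) * batch_size
    return rows * dim * n_inputs * itemsize


def generator_bytes(batch_size, dim, n_inputs=3, itemsize=8):
    """Memory for the single batch a generator builds at a time."""
    return batch_size * dim * n_inputs * itemsize


print(precomputed_bytes(10000, 50, 1000))  # 11760000000 bytes, about 11.8 GB
print(generator_bytes(50, 1000))           # 1200000 bytes, about 1.2 MB
```

Keras also buffers up to max_queue_size batches from the generator, so the real generator-side figure is a small multiple of this, still orders of magnitude below the precomputed version.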

fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

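Above is the official signature; the official documentation explains every argument in English. It is a very complete wrapper, but in everyday use the parameters we actually need are generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None and validation_steps=None. validation_data defaults to None, but you usually want to supply some validation data, so my example below includes it.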
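fit_generator expects a Python generator that never terminates and yields one (inputs, targets) tuple per step; Keras draws steps_per_epoch items per training epoch, and validation_steps items from validation_data. The pure-Python sketch below mimics that contract without Keras: the generator loops over the data forever, and itertools.islice stands in for Keras pulling a fixed number of steps. batch_generator and run_epoch are hypothetical names. (In recent TensorFlow 2.x releases fit_generator is deprecated and Model.fit accepts such a generator directly.)

```python
from itertools import islice


def batch_generator(xs, ys, batch_size):
    """Yield (inputs, targets) batches forever, cycling over the data."""
    while True:  # a Keras generator must never stop on its own
        # leftover samples that do not fill a whole batch are skipped
        for start in range(0, len(xs) - batch_size + 1, batch_size):
            yield xs[start:start + batch_size], ys[start:start + batch_size]


def run_epoch(generator, steps_per_epoch):
    """Mimic one epoch: take exactly steps_per_epoch batches from the generator."""
    return list(islice(generator, steps_per_epoch))


xs = list(range(10))
ys = [x * 2 for x in xs]
gen = batch_generator(xs, ys, batch_size=4)
epoch1 = run_epoch(gen, steps_per_epoch=len(xs) // 4)
epoch2 = run_epoch(gen, steps_per_epoch=len(xs) // 4)  # same generator, keeps cycling
print(epoch1)
print(epoch2 == epoch1)
```

The same generator object is reused across epochs, which is why it has to loop with while True rather than end after one pass.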

model_train = Model(inputs=[inputs_img, inputs_txt1, inputs_txt1no], outputs=loss)

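This is my model: three inputs and one output (the output is the loss value itself). The internal details are omitted here; this is only an abstract illustration of the structure.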
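When a generator yields dictionaries, Keras matches each key against the name of an Input layer (and of the output layer, for the targets), not against the Python variable names. A quick sanity check on one yielded batch catches key or batch-size mismatches before training starts. The sketch below is pure Python; check_batch is a hypothetical helper, and the expected names are assumed to be the keys used in wrap_in_dictionary below.

```python
def check_batch(batch, input_names, output_names):
    """Assert that a generator item is an (inputs, targets) tuple of dicts
    with the expected keys and one shared batch dimension; return that size."""
    assert isinstance(batch, tuple) and len(batch) == 2, "expected (inputs, targets)"
    inputs, targets = batch
    assert set(inputs) == set(input_names), f"input keys {sorted(inputs)}"
    assert set(targets) == set(output_names), f"target keys {sorted(targets)}"
    sizes = {len(v) for v in list(inputs.values()) + list(targets.values())}
    assert len(sizes) == 1, f"inconsistent batch sizes {sizes}"
    return sizes.pop()


# hypothetical batch of 2 samples, with lists standing in for NumPy arrays
batch = ({'inputs_img': [[0.1], [0.2]],
          'inputs_txt': [[1.0], [2.0]],
          'inputs_txt_no': [[2.0], [1.0]]},
         {'all_loss': [[0.0], [0.0]]})
size = check_batch(batch,
                   input_names=['inputs_img', 'inputs_txt', 'inputs_txt_no'],
                   output_names=['all_loss'])
print(size)  # 2
```

With the real code, passing next(flow(...)) to such a check before calling fit_generator exercises the actual generator output.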

import random

import numpy as np

batch_size = 50
# Rotate a list left by one position (in place): [0, 1, 2] -> [1, 2, 0].
# Used to build shifted, mismatched text batches.
def ahead_one(input_list):
    b = input_list.pop(0)
    input_list.append(b)
    return input_list

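# Wrap one batch as an (inputs, targets) tuple: three inputs, one output.
# The dictionary keys must match the name= given to the corresponding
# Input layers and to the output layer of the model.
def wrap_in_dictionary(image_feature, text_feature_train, text_feature_train2, y_label):
    return ({'inputs_img': image_feature,
             'inputs_txt': text_feature_train,
             'inputs_txt_no': text_feature_train2},
            {'all_loss': y_label})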

# Allocate empty arrays to hold one batch
def make_empty_batch():
    # 1000 is the feature dimension in this example; change it to match your data
    image_batch = np.zeros((batch_size, 1000))
    text_batch = np.zeros((batch_size, 1000))
    textno_batch = np.zeros((batch_size, 1000))
    return image_batch, text_batch, textno_batch

# Feed batches to the network forever (a Keras generator must never terminate)
def flow(image_feature, text_feature_train, mode):
    # training data keeps its original order; validation data is shuffled once
    if mode == 'val':
        idx = list(range(len(image_feature)))
        random.shuffle(idx)
        image_feature = image_feature[idx]
        text_feature_train = text_feature_train[idx]
    image_batch, text_batch, textno_batch = make_empty_batch()
    batch_counter = 0
    while True:
        for i in range(len(image_feature)):
            image_batch[batch_counter, :] = image_feature[i, :]
            text_batch[batch_counter, :] = text_feature_train[i, :]
            batch_counter += 1
            if batch_counter == batch_size:
                # Turn one batch of matching (image, text) pairs into batch_size - 1
                # batches of mismatched pairs: for shift j, image k is paired
                # with text (k + j) % batch_size.
                index = list(range(batch_size))
                for j in range(1, batch_size):
                    # rotate by exactly one more position per shift; rotating j
                    # extra times here would accumulate to j*(j+1)/2 and wrap
                    # back to 0 (text paired with itself) at j = 24
                    index = ahead_one(index)
                    textno_batch = text_batch[index]
                    # the model outputs its loss directly, so the target is a dummy
                    y_label = np.zeros((batch_size, 1))
                    yield wrap_in_dictionary(image_batch, text_batch, textno_batch, y_label)
                # start a fresh batch (new arrays, so already-yielded ones are not overwritten)
                image_batch, text_batch, textno_batch = make_empty_batch()
                batch_counter = 0
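# Note the call format: only six arguments are used.
# Each full batch of batch_size samples yields batch_size - 1 training batches,
# so one pass over the data takes (len // batch_size) * (batch_size - 1) steps.
steps = (len(features_train) // batch_size) * (batch_size - 1)
# Validation here reuses the training features; pass held-out data in practice.
model_train.fit_generator(generator=flow(features_train, text_feature_train, 'train'),
                          steps_per_epoch=steps,
                          epochs=500,
                          verbose=1,
                          validation_data=flow(features_train, text_feature_train, 'val'),
                          validation_steps=steps)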
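The core of flow is how it turns one batch of matching (image, text) pairs into batch_size - 1 batches of mismatched pairs. The pure-Python sketch below uses integer IDs in place of feature rows and a hypothetical shifted_batches helper (with a copy of ahead_one) to check three things: rotating once per shift never pairs an image with its own text, rotating j extra times per shift without resetting does (the shift accumulates to j*(j+1)/2, which is 0 mod 50 at j = 24), and one pass over N samples yields (N // batch_size) * (batch_size - 1) batches, the number steps_per_epoch must cover for an epoch to mean one pass.

```python
def ahead_one(input_list):
    """Rotate a list left by one position in place, e.g. [0, 1, 2] -> [1, 2, 0]."""
    input_list.append(input_list.pop(0))
    return input_list


def shifted_batches(batch, cumulative=False):
    """Yield (images, mismatched_texts) for shifts 1..len(batch)-1.

    cumulative=False rotates once per shift, giving shift j.
    cumulative=True rotates j times per shift without resetting the index,
    which accumulates to a shift of j*(j+1)/2.
    """
    index = list(range(len(batch)))
    for j in range(1, len(batch)):
        for _ in range(j if cumulative else 1):
            index = ahead_one(index)
        yield batch, [batch[k] for k in index]


def count_pass(n_samples, batch_size):
    """Count the batches yielded by one pass over n_samples IDs (leftovers dropped)."""
    ids = list(range(n_samples))
    count = 0
    for start in range(0, n_samples - batch_size + 1, batch_size):
        for _ in shifted_batches(ids[start:start + batch_size]):
            count += 1
    return count


batch = list(range(50))
# does any shift pair an image with its own text?
fixed = any(a == b for imgs, txts in shifted_batches(batch)
            for a, b in zip(imgs, txts))
buggy = any(a == b for imgs, txts in shifted_batches(batch, cumulative=True)
            for a, b in zip(imgs, txts))
print(fixed, buggy)          # False True
print(count_pass(1000, 50))  # 980 = (1000 // 50) * 49
```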
