需要注意迭代器 yeild返回不能是[x1,x2],y 这样,而是要完整的字典格式的: yield ({'input_1': x1, 'input_2': x2}, {'output': y}) 。
这也不算坑 追进去 fit_generator也能看到示例
def generate_batch(x_train,y_train,batch_size,x_train2,randomFlag=True):
ylen = len(y_train)
loopcount = ylen // batch_size
i=-1
while True:
if randomFlag:
i = random.randint(0,loopcount-1)
else:
i=i+1
i=i%loopcount
yield ({'lstmInput': x_train[i*batch_size:(i+1)*batch_size],
'bgInput': x_train2[i*batch_size:(i+1)*batch_size]},
{'prediction': y_train[i*batch_size:(i+1)*batch_size]})
ps: 因为要是tuple yield后的括号不能省
需注意的坑1是,validation data中如果用【】组成数组进行输入,是要按顺序的,按编译model前的设置model = Model(inputs=[simInput,lstmInput,bgInput], outputs=predictions),中数组的顺序来编译
需注意的坑2是,多输入input时,以后都用 inputs1=Input(batch_shape=(batchSize,TPeriod,dimIn,),name='input1LSTM')指定batchSize,不然跟stateful lstm结合时,会提示不匹配。
history=model.fit_generator(generate_batch(trainX,trainY,batchSize,trainX2),
steps_per_epoch=len(trainX)//batchSize,
validation_data=([testX,testX2],testY),
epochs=epochs,
callbacks=[tensorboard,checkpoint],initial_epoch=0,verbose=1) # Fit the LSTM network/拟合LSTM网络