Keras中fit_generator 的多个分支输入时,需注意generator的格式 以及 输入序列的顺序

需要注意迭代器 yeild返回不能是[x1,x2],y 这样,而是要完整的字典格式的: yield ({'input_1': x1, 'input_2': x2}, {'output': y})  。

这也不算坑 追进去 fit_generator也能看到示例

def generate_batch(x_train,y_train,batch_size,x_train2,randomFlag=True):
    ylen = len(y_train)
    loopcount = ylen // batch_size
    i=-1
    while True:
        if randomFlag:
            i = random.randint(0,loopcount-1)
        else:
            i=i+1
            i=i%loopcount

        yield ({'lstmInput': x_train[i*batch_size:(i+1)*batch_size], 
                'bgInput': x_train2[i*batch_size:(i+1)*batch_size]}, 
            {'prediction': y_train[i*batch_size:(i+1)*batch_size]})  

ps: 因为要是tuple yield后的括号不能省   
 

需注意的坑1是,validation data中如果用【】组成数组进行输入,是要按顺序的,按编译model前的设置model = Model(inputs=[simInput,lstmInput,bgInput], outputs=predictions),中数组的顺序来编译

需注意的坑2是,多输入input时,以后都用 inputs1=Input(batch_shape=(batchSize,TPeriod,dimIn,),name='input1LSTM')指定batchSize,不然跟stateful lstm结合时,会提示不匹配。

history=model.fit_generator(generate_batch(trainX,trainY,batchSize,trainX2),
            steps_per_epoch=len(trainX)//batchSize,
            validation_data=([testX,testX2],testY),
            epochs=epochs,
           callbacks=[tensorboard,checkpoint],initial_epoch=0,verbose=1)  # Fit the LSTM network/拟合LSTM网络
 

你可能感兴趣的:(keras)