gridsearchCV暴力训练多输入keras模型报错!

1. 问题描述

使用keras框架写了深度学习框架,尝试使用sklearn.gridsearchCV对模型的超参进行一波暴力调参,为达到该目的,需要以下几步:

  1. 将keras模型转换为grid seach可调参的形式。
    model = scikit_learn.KerasClassifier(build_fn=build_model(k,adam), verbose=0)
  2. 训练模型。本文中模型为多输入结构,类似于下图(盗图,此处仅为说清概念),输入分为input1和input2。
gridsearchCV暴力训练多输入keras模型报错!_第1张图片

模型训练fit(x_train, y_train),其中x_train为2维输入,如下。

x_train = [X_train1, X_train2]
grid_result = grid.fit(x_train, y_train)

运行代码,fit时报错为

found input variables with inconsistent numbers of samples

即要求x_train.shape[0] = y_train.shape[0]=样本数,但若为n个输入,x_train.shape[0] = n != 样本数,这个矛盾必不可解决啊!!(除非改库或者改变输入格式)

2. 解决方案

寻寻觅觅寻寻觅觅,尝试了n多种方法,最后终于找到了一个完美的解决方案(keras lambda层),哦吼吼。

keras.layers.core.Lambda(function, output_shape=None, mask=None, arguments=None)
  • keras lambda层:用以对上一层的输出施以任何Theano/TensorFlow表达式,对流经该层的数据做个变换,而这个变换本身没有什么需要学习的参数。
  • 在build_model中,input层之前添加lambda层(如下图),而fit(x_train, y_train)中的x_train不需要为n维输入,也就避免了第一部分所述矛盾。
gridsearchCV暴力训练多输入keras模型报错!_第2张图片

build_model中代码如下:

# 导入库
from keras.layers.core import Lambda

# 使用lambda对输入数据取值
# 其中X_train1为x_train的0:3列,X_train2为X_train的4:7列。
X_train1 = Lambda(slice,output_shape=(4,1),arguments={'index':0})(inputs)
X_train2 = Lambda(slice,output_shape=(4,1),arguments={'index':4})(inputs)

完结撒花!

你可能感兴趣的:(深度学习,学习,网格调参,gridsearchCV,keras模型)