通过Keras的包装类,借助Scikit-Learn的网格搜索算法评估神经网络模型的不同配置,并找到最佳评估性能的参数组合。
在Scikit-Learn中的GridSearchCV需要一个字典类型的字段作为需要调参的参数,默认采用3折交叉验证的方法来评估算法。
这里有四个参数需要调参,因此会产生4*3个模型。
代码如下:
"""
通过Scikit-learn中的GridSearchCV进行自动调参
耗时很久,很多情况下不常用
"""
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasClassifier
#构建模型
#这里的参数必须要有init才可以!!!不然会报错。
def create_model(optimizer='rmsprop', init='glorot_uniform'):
#构建模型
model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer=init, activation='relu'))
model.add(Dense(8, kernel_initializer=init, activation='relu'))
model.add(Dense(1, kernel_initializer=init, activation='sigmoid'))
#编译模型
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
return model
seed = 7 #设置随机种子
np.random.seed(seed)
#导入数据
dataset = np.loadtxt(r'F:\Python\pycharm\keras_deeplearning\datasets\PimaIndiansdiabetes.csv', delimiter=',', skiprows=1)
#分割输入变量x和输出变量Y
x = dataset[:, 0:8]
Y = dataset[:, 8]
#创建模型,,迭代——参数为(模型,时期,批处理大小,verbose=0作用:关闭模型的fit()和evaluate()的详细输出)
model = KerasClassifier(build_fn=create_model, verbose=0)
#创建需要调参的参数
param_grid = {}
param_grid['optimizer'] = ['rmsprop', 'adam']
param_grid['init'] = ['glorot_uniform', 'normal', 'uniform']
param_grid['epochs'] = [50,100,150,200]
param_grid['batch_size'] = [5,10,20]
#调参
grid = GridSearchCV(estimator=model, param_grid=param_grid)
results = grid.fit(x,Y)
#输出结果
print('Best: %f using %s' % (results.best_score_, results.best_params_))
means = results.cv_results_['mean_test_score']
stds = results.cv_results_['std_test_score']
params = results.cv_results_['params']
for mean, std, param in zip(means,stds,params):
print('%f (%f) with: %r' % (mean, std, param))
这里简要科普一下zip()函数:
备注:zip():
>>>a = [1,2,3]
>>> b = [4,5,6]
>>> c = [4,5,6,7,8]
>>> zipped = zip(a,b) # 打包为元组的列表
[(1, 4), (2, 5), (3, 6)]
>>> zip(a,c) # 元素个数与最短的列表一致
[(1, 4), (2, 5), (3, 6)]
>>> zip(*zipped) # 与 zip 相反,*zipped 可理解为解压,返回二维矩阵式
[(1, 2, 3), (4, 5, 6)]