使用sklearn包装器实例(重点:sklearn的GridSearchCV方法如何使用)
Builds simple CNN models on MNIST and uses sklearn's GridSearchCV to find best model
基于MNIST(数据集)建立简单CNN(卷积神经网络)模型,使用sklearn的GridSearchCV方法找出最佳模型
GridSearchCV,意义是自动调参,参数输进去,就能给出最优化的结果和参数。但适合于小数据集,数据的量级大了,很难得出结果。
官网介绍
Keras实例目录
代码注释
'''Example of how to use sklearn wrapper
使用sklearn包装器实例(重点:sklearn的GridSearchCV方法如何使用)
Builds simple CNN models on MNIST and uses sklearn's GridSearchCV to find best model
基于MNIST(数据集)建立简单CNN(卷积神经网络)模型,使用sklearn的GridSearchCV方法找出最佳模型
GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数。但是这个方法适合于小数据集,
数据的量级上去了,很难得出结果。
http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
'''
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.wrappers.scikit_learn import KerasClassifier
from keras import backend as K
from sklearn.grid_search import GridSearchCV
num_classes = 10
# input image dimensions
# 输入图像维度
img_rows, img_cols = 28, 28
# load training data and do basic data normalization
# 加载训练集数据并做基础的数据归一化处理
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':# Theano框架,图像通道在前
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else: # TensorFlow框架,图像通道在后
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
# convert class vectors to binary class matrices
# 类别向量转为多分类矩阵
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
def make_model(dense_layer_sizes, filters, kernel_size, pool_size):
'''Creates model comprised of 2 convolutional layers followed by dense layers
创建模型组成:2个卷积层后跟一个全连接层
dense_layer_sizes: List of layer sizes.
This list has one number for each layer
dense_layer_sizes: 层大小列表
列表每个层都有一个数字
filters: Number of convolutional filters in each convolutional layer
filters: 每个卷积层的卷积核数量
kernel_size: Convolutional kernel size
kernel_size: 卷积核大小
pool_size: Size of pooling area for max pooling
pool_size: 最大池化的池化区域大小
'''
model = Sequential()
model.add(Conv2D(filters, kernel_size,
padding='valid',
input_shape=input_shape))
model.add(Activation('relu'))
model.add(Conv2D(filters, kernel_size))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=pool_size))
model.add(Dropout(0.25))
model.add(Flatten())
for layer_size in dense_layer_sizes:
model.add(Dense(layer_size))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='adadelta',
metrics=['accuracy'])
return model
dense_size_candidates = [[32], [64], [32, 32], [64, 64]]
my_classifier = KerasClassifier(make_model, batch_size=32)
validator = GridSearchCV(my_classifier,
param_grid={'dense_layer_sizes': dense_size_candidates,
# epochs is avail for tuning even when not
# an argument to model building function
# 即使在没有建立模型功能的参数时,也有必要进行调整。
'epochs': [3, 6],
'filters': [8],
'kernel_size': [3],
'pool_size': [2]},
scoring='neg_log_loss',
n_jobs=1)
validator.fit(x_train, y_train)
print('The parameters of the best model are: ')
print(validator.best_params_)
# validator.best_estimator_ returns sklearn-wrapped version of best model.
# validator.best_estimator_返回sklearn-wrapped版本的最佳模型。
# validator.best_estimator_.model returns the (unwrapped) keras model
# validator.best_estimator_.model 返回(展开的)keras模型
best_model = validator.best_estimator_.model
metric_names = best_model.metrics_names
metric_values = best_model.evaluate(x_test, y_test)
for metric, value in zip(metric_names, metric_values):
print(metric, ': ', value)
代码执行
Keras详细介绍
英文:https://keras.io/
中文:http://keras-cn.readthedocs.io/en/latest/
实例下载
https://github.com/keras-team/keras
https://github.com/keras-team/keras/tree/master/examples
完整项目下载
方便没积分童鞋,请加企鹅452205574,共享文件夹。
包括:代码、数据集合(图片)、已生成model、安装库文件等。